diff --git a/changelog.md b/changelog.md index 8fe144056a..bc3070f7e1 100644 --- a/changelog.md +++ b/changelog.md @@ -51,6 +51,7 @@ * [1879](https://github.com/zeta-chain/node/pull/1879) - full coverage for messages in types packages * [1899](https://github.com/zeta-chain/node/pull/1899) - add empty test files so packages are included in coverage * [1903](https://github.com/zeta-chain/node/pull/1903) - common package tests +* [1961](https://github.com/zeta-chain/node/pull/1961) - improve observer module coverage * [1967](https://github.com/zeta-chain/node/pull/1967) - improve crosschain module coverage * [1955](https://github.com/zeta-chain/node/pull/1955) - improve emissions module coverage diff --git a/proto/observer/ballot.proto b/proto/observer/ballot.proto index c89b50bb83..aa871cff35 100644 --- a/proto/observer/ballot.proto +++ b/proto/observer/ballot.proto @@ -20,6 +20,7 @@ enum BallotStatus { BallotInProgress = 2; } +// https://github.com/zeta-chain/node/issues/939 message Ballot { string index = 1; string ballot_identifier = 2; diff --git a/testutil/keeper/mocks/observer/staking.go b/testutil/keeper/mocks/observer/staking.go index 90007b6c35..72bf99599f 100644 --- a/testutil/keeper/mocks/observer/staking.go +++ b/testutil/keeper/mocks/observer/staking.go @@ -71,6 +71,11 @@ func (_m *ObserverStakingKeeper) GetValidator(ctx types.Context, addr types.ValA return r0, r1 } +// SetDelegation provides a mock function with given fields: ctx, delegation +func (_m *ObserverStakingKeeper) SetDelegation(ctx types.Context, delegation stakingtypes.Delegation) { + _m.Called(ctx, delegation) +} + // SetValidator provides a mock function with given fields: ctx, validator func (_m *ObserverStakingKeeper) SetValidator(ctx types.Context, validator stakingtypes.Validator) { _m.Called(ctx, validator) diff --git a/typescript/observer/ballot_pb.d.ts b/typescript/observer/ballot_pb.d.ts index 100a1a0f1d..1eb7d2f01d 100644 --- a/typescript/observer/ballot_pb.d.ts +++ b/typescript/observer/ballot_pb.d.ts @@ -50,6 +50,8 @@ export declare enum BallotStatus { } /** + * https://github.com/zeta-chain/node/issues/939 + * * @generated from message zetachain.zetacore.observer.Ballot */ export declare class Ballot extends Message { diff --git a/x/observer/abci.go b/x/observer/abci.go index f219bb833d..41121ced19 100644 --- a/x/observer/abci.go +++ b/x/observer/abci.go @@ -21,10 +21,6 @@ func BeginBlocker(ctx sdk.Context, k keeper.Keeper) { return } totalObserverCountCurrentBlock := allObservers.LenUint() - if totalObserverCountCurrentBlock < 0 { - ctx.Logger().Error("TotalObserverCount is negative at height", ctx.BlockHeight()) - return - } // #nosec G701 always in range if totalObserverCountCurrentBlock == lastBlockObserverCount.Count { return diff --git a/x/observer/abci_test.go b/x/observer/abci_test.go new file mode 100644 index 0000000000..02a161d6b6 --- /dev/null +++ b/x/observer/abci_test.go @@ -0,0 +1,98 @@ +package observer_test + +import ( + "math" + "testing" + + "github.com/stretchr/testify/require" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" + "github.com/zeta-chain/zetacore/testutil/sample" + "github.com/zeta-chain/zetacore/x/observer" + "github.com/zeta-chain/zetacore/x/observer/types" +) + +func TestBeginBlocker(t *testing.T) { + t.Run("should not update LastObserverCount if not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + observer.BeginBlocker(ctx, *k) + + _, found := k.GetLastObserverCount(ctx) + require.False(t, found) + + _, found = k.GetKeygen(ctx) + require.False(t, found) + }) + + t.Run("should not update LastObserverCount if observer set not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + count := 1 + k.SetLastObserverCount(ctx, &types.LastObserverCount{ + Count: uint64(count), + }) + + observer.BeginBlocker(ctx, *k) + + lastObserverCount, found := k.GetLastObserverCount(ctx) + require.True(t, found) + require.Equal(t, uint64(count), lastObserverCount.Count) + require.Equal(t, int64(0), lastObserverCount.LastChangeHeight) + + _, found = k.GetKeygen(ctx) + require.False(t, found) + }) + + t.Run("should not update LastObserverCount if observer set count equal last observed count", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + count := 1 + os := sample.ObserverSet(count) + k.SetObserverSet(ctx, os) + k.SetLastObserverCount(ctx, &types.LastObserverCount{ + Count: uint64(count), + }) + + observer.BeginBlocker(ctx, *k) + + lastObserverCount, found := k.GetLastObserverCount(ctx) + require.True(t, found) + require.Equal(t, uint64(count), lastObserverCount.Count) + require.Equal(t, int64(0), lastObserverCount.LastChangeHeight) + + _, found = k.GetKeygen(ctx) + require.False(t, found) + }) + + t.Run("should update LastObserverCount", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + observeSetLen := 10 + count := 1 + os := sample.ObserverSet(observeSetLen) + k.SetObserverSet(ctx, os) + k.SetLastObserverCount(ctx, &types.LastObserverCount{ + Count: uint64(count), + }) + + keygen, found := k.GetKeygen(ctx) + require.False(t, found) + require.Equal(t, types.Keygen{}, keygen) + + observer.BeginBlocker(ctx, *k) + + keygen, found = k.GetKeygen(ctx) + require.True(t, found) + require.Empty(t, keygen.GranteePubkeys) + require.Equal(t, types.KeygenStatus_PendingKeygen, keygen.Status) + require.Equal(t, int64(math.MaxInt64), keygen.BlockNumber) + + inboundEnabled := k.IsInboundEnabled(ctx) + require.False(t, inboundEnabled) + + lastObserverCount, found := k.GetLastObserverCount(ctx) + require.True(t, found) + require.Equal(t, uint64(observeSetLen), lastObserverCount.Count) + require.Equal(t, ctx.BlockHeight(), lastObserverCount.LastChangeHeight) + }) +} diff --git a/x/observer/genesis.go b/x/observer/genesis.go index 7e3ff44311..b6d85757ff 100644 --- a/x/observer/genesis.go +++ b/x/observer/genesis.go @@ -16,7 +16,7 @@ func InitGenesis(ctx sdk.Context, k keeper.Keeper, genState types.GenesisState) observerCount = uint64(len(genState.Observers.ObserverList)) } - // if chian params are defined set them + // if chain params are defined set them if len(genState.ChainParamsList.ChainParams) > 0 { k.SetChainParamsList(ctx, genState.ChainParamsList) } else { diff --git a/x/observer/genesis_test.go b/x/observer/genesis_test.go index 7482f24f95..62e442e9ba 100644 --- a/x/observer/genesis_test.go +++ b/x/observer/genesis_test.go @@ -12,45 +12,151 @@ import ( ) func TestGenesis(t *testing.T) { - params := types.DefaultParams() - tss := sample.Tss() - genesisState := types.GenesisState{ - Params: ¶ms, - Tss: &tss, - BlameList: sample.BlameRecordsList(t, 10), - Ballots: []*types.Ballot{ - sample.Ballot(t, "0"), - sample.Ballot(t, "1"), - sample.Ballot(t, "2"), - }, - Observers: sample.ObserverSet(3), - NodeAccountList: []*types.NodeAccount{ - sample.NodeAccount(), - sample.NodeAccount(), - sample.NodeAccount(), - }, - CrosschainFlags: types.DefaultCrosschainFlags(), - Keygen: sample.Keygen(t), - ChainParamsList: sample.ChainParamsList(), - LastObserverCount: sample.LastObserverCount(10), - TssFundMigrators: []types.TssFundMigratorInfo{sample.TssFundsMigrator(1), sample.TssFundsMigrator(2)}, - ChainNonces: []types.ChainNonces{ - sample.ChainNonces(t, "0"), - sample.ChainNonces(t, "1"), - sample.ChainNonces(t, "2"), - }, - PendingNonces: sample.PendingNoncesList(t, "sample", 20), - NonceToCctx: sample.NonceToCctxList(t, "sample", 20), - } - - // Init and export - k, ctx, _, _ := keepertest.ObserverKeeper(t) - observer.InitGenesis(ctx, *k, genesisState) - got := observer.ExportGenesis(ctx, *k) - require.NotNil(t, got) - - // Compare genesis after init and export - nullify.Fill(&genesisState) - nullify.Fill(got) - require.Equal(t, genesisState, *got) + t.Run("genState fields defined", func(t *testing.T) { + params := types.DefaultParams() + tss := sample.Tss() + genesisState := types.GenesisState{ + Params: ¶ms, + Tss: &tss, + BlameList: sample.BlameRecordsList(t, 10), + Ballots: []*types.Ballot{ + sample.Ballot(t, "0"), + sample.Ballot(t, "1"), + sample.Ballot(t, "2"), + }, + Observers: sample.ObserverSet(3), + NodeAccountList: []*types.NodeAccount{ + sample.NodeAccount(), + sample.NodeAccount(), + sample.NodeAccount(), + }, + CrosschainFlags: types.DefaultCrosschainFlags(), + Keygen: sample.Keygen(t), + ChainParamsList: sample.ChainParamsList(), + LastObserverCount: sample.LastObserverCount(10), + TssFundMigrators: []types.TssFundMigratorInfo{sample.TssFundsMigrator(1), sample.TssFundsMigrator(2)}, + ChainNonces: []types.ChainNonces{ + sample.ChainNonces(t, "0"), + sample.ChainNonces(t, "1"), + sample.ChainNonces(t, "2"), + }, + PendingNonces: sample.PendingNoncesList(t, "sample", 20), + NonceToCctx: sample.NonceToCctxList(t, "sample", 20), + TssHistory: []types.TSS{sample.Tss()}, + } + + // Init and export + k, ctx, _, _ := keepertest.ObserverKeeper(t) + observer.InitGenesis(ctx, *k, genesisState) + got := observer.ExportGenesis(ctx, *k) + require.NotNil(t, got) + + // Compare genesis after init and export + nullify.Fill(&genesisState) + nullify.Fill(got) + require.Equal(t, genesisState, *got) + }) + + t.Run("genState fields not defined", func(t *testing.T) { + genesisState := types.GenesisState{} + + k, ctx, _, _ := keepertest.ObserverKeeper(t) + observer.InitGenesis(ctx, *k, genesisState) + got := observer.ExportGenesis(ctx, *k) + require.NotNil(t, got) + + defaultParams := types.DefaultParams() + btcChainParams := types.GetDefaultBtcRegtestChainParams() + btcChainParams.IsSupported = true + goerliChainParams := types.GetDefaultGoerliLocalnetChainParams() + goerliChainParams.IsSupported = true + zetaPrivnetChainParams := types.GetDefaultZetaPrivnetChainParams() + zetaPrivnetChainParams.IsSupported = true + localnetChainParams := types.ChainParamsList{ + ChainParams: []*types.ChainParams{ + btcChainParams, + goerliChainParams, + zetaPrivnetChainParams, + }, + } + expectedGenesisState := types.GenesisState{ + Params: &defaultParams, + CrosschainFlags: types.DefaultCrosschainFlags(), + ChainParamsList: localnetChainParams, + Tss: &types.TSS{}, + Keygen: &types.Keygen{}, + LastObserverCount: &types.LastObserverCount{}, + NodeAccountList: []*types.NodeAccount{}, + } + + require.Equal(t, expectedGenesisState, *got) + }) + + t.Run("genState fields not defined except tss", func(t *testing.T) { + tss := sample.Tss() + genesisState := types.GenesisState{ + Tss: &tss, + } + + k, ctx, _, _ := keepertest.ObserverKeeper(t) + observer.InitGenesis(ctx, *k, genesisState) + got := observer.ExportGenesis(ctx, *k) + require.NotNil(t, got) + + defaultParams := types.DefaultParams() + btcChainParams := types.GetDefaultBtcRegtestChainParams() + btcChainParams.IsSupported = true + goerliChainParams := types.GetDefaultGoerliLocalnetChainParams() + goerliChainParams.IsSupported = true + zetaPrivnetChainParams := types.GetDefaultZetaPrivnetChainParams() + zetaPrivnetChainParams.IsSupported = true + localnetChainParams := types.ChainParamsList{ + ChainParams: []*types.ChainParams{ + btcChainParams, + goerliChainParams, + zetaPrivnetChainParams, + }, + } + pendingNonces, err := k.GetAllPendingNonces(ctx) + require.NoError(t, err) + require.NotEmpty(t, pendingNonces) + expectedGenesisState := types.GenesisState{ + Params: &defaultParams, + CrosschainFlags: types.DefaultCrosschainFlags(), + ChainParamsList: localnetChainParams, + Tss: &tss, + Keygen: &types.Keygen{}, + LastObserverCount: &types.LastObserverCount{}, + NodeAccountList: []*types.NodeAccount{}, + PendingNonces: pendingNonces, + } + + require.Equal(t, expectedGenesisState, *got) + }) + + t.Run("export without init", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + got := observer.ExportGenesis(ctx, *k) + require.NotNil(t, got) + + params := k.GetParamsIfExists(ctx) + expectedGenesisState := types.GenesisState{ + Params: ¶ms, + CrosschainFlags: types.DefaultCrosschainFlags(), + ChainParamsList: types.ChainParamsList{}, + Tss: &types.TSS{}, + Keygen: &types.Keygen{}, + LastObserverCount: &types.LastObserverCount{}, + NodeAccountList: []*types.NodeAccount{}, + Ballots: k.GetAllBallots(ctx), + TssHistory: k.GetAllTSS(ctx), + TssFundMigrators: k.GetAllTssFundMigrators(ctx), + BlameList: k.GetAllBlame(ctx), + ChainNonces: k.GetAllChainNonces(ctx), + NonceToCctx: k.GetAllNonceToCctx(ctx), + } + + require.Equal(t, expectedGenesisState, *got) + }) } diff --git a/x/observer/keeper/ballot_test.go b/x/observer/keeper/ballot_test.go index 13415cdaac..8d71aaa720 100644 --- a/x/observer/keeper/ballot_test.go +++ b/x/observer/keeper/ballot_test.go @@ -1,23 +1,162 @@ -package keeper +package keeper_test import ( + "math" "testing" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" + "github.com/zeta-chain/zetacore/testutil/sample" "github.com/zeta-chain/zetacore/x/observer/types" ) func TestKeeper_GetBallot(t *testing.T) { - k, ctx := SetupKeeper(t) - identifier := "0x9ea007f0f60e32d58577a8cf25678942d2b10791c2a34f48e237b76a7e998e4d" - k.SetBallot(ctx, &types.Ballot{ - Index: "", - BallotIdentifier: identifier, - VoterList: nil, - ObservationType: 0, - BallotThreshold: sdk.Dec{}, - BallotStatus: 0, + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + identifier := sample.ZetaIndex(t) + b := &types.Ballot{ + Index: "123", + BallotIdentifier: identifier, + VoterList: nil, + ObservationType: 0, + BallotThreshold: sdk.Dec{}, + BallotStatus: 0, + BallotCreationHeight: 1, + } + _, found := k.GetBallot(ctx, identifier) + require.False(t, found) + + k.SetBallot(ctx, b) + + ballot, found := k.GetBallot(ctx, identifier) + require.True(t, found) + require.Equal(t, *b, ballot) + + // overwrite existing ballot + b = &types.Ballot{ + Index: "123", + BallotIdentifier: identifier, + VoterList: nil, + ObservationType: 1, + BallotThreshold: sdk.Dec{}, + BallotStatus: 1, + BallotCreationHeight: 2, + } + _, found = k.GetBallot(ctx, identifier) + require.True(t, found) + + k.SetBallot(ctx, b) + + ballot, found = k.GetBallot(ctx, identifier) + require.True(t, found) + require.Equal(t, *b, ballot) +} + +func TestKeeper_GetBallotList(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + identifier := sample.ZetaIndex(t) + b := &types.Ballot{ + Index: "", + BallotIdentifier: identifier, + VoterList: nil, + ObservationType: 0, + BallotThreshold: sdk.Dec{}, + BallotStatus: 0, + BallotCreationHeight: 1, + } + _, found := k.GetBallotList(ctx, 1) + require.False(t, found) + + k.AddBallotToList(ctx, *b) + list, found := k.GetBallotList(ctx, 1) + require.True(t, found) + require.Equal(t, 1, len(list.BallotsIndexList)) + require.Equal(t, identifier, list.BallotsIndexList[0]) +} + +func TestKeeper_GetAllBallots(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + identifier := sample.ZetaIndex(t) + b := &types.Ballot{ + Index: "", + BallotIdentifier: identifier, + VoterList: nil, + ObservationType: 0, + BallotThreshold: sdk.Dec{}, + BallotStatus: 0, + BallotCreationHeight: 1, + } + ballots := k.GetAllBallots(ctx) + require.Empty(t, ballots) + + k.SetBallot(ctx, b) + ballots = k.GetAllBallots(ctx) + require.Equal(t, 1, len(ballots)) + require.Equal(t, b, ballots[0]) +} + +func TestKeeper_GetMaturedBallotList(t *testing.T) { + t.Run("should return if maturity blocks less than height", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + identifier := sample.ZetaIndex(t) + b := &types.Ballot{ + Index: "", + BallotIdentifier: identifier, + VoterList: nil, + ObservationType: 0, + BallotThreshold: sdk.Dec{}, + BallotStatus: 0, + BallotCreationHeight: 1, + } + list := k.GetMaturedBallotList(ctx) + require.Empty(t, list) + ctx = ctx.WithBlockHeight(101) + k.AddBallotToList(ctx, *b) + list = k.GetMaturedBallotList(ctx) + require.Equal(t, 1, len(list)) + require.Equal(t, identifier, list[0]) + }) + + t.Run("should return empty for max maturity blocks", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + identifier := sample.ZetaIndex(t) + b := &types.Ballot{ + Index: "", + BallotIdentifier: identifier, + VoterList: nil, + ObservationType: 0, + BallotThreshold: sdk.Dec{}, + BallotStatus: 0, + BallotCreationHeight: 1, + } + k.SetParams(ctx, types.Params{ + BallotMaturityBlocks: math.MaxInt64, + }) + list := k.GetMaturedBallotList(ctx) + require.Empty(t, list) + k.AddBallotToList(ctx, *b) + list = k.GetMaturedBallotList(ctx) + require.Empty(t, list) }) - k.GetBallot(ctx, identifier) + t.Run("should return empty if maturity blocks greater than height", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + identifier := sample.ZetaIndex(t) + b := &types.Ballot{ + Index: "", + BallotIdentifier: identifier, + VoterList: nil, + ObservationType: 0, + BallotThreshold: sdk.Dec{}, + BallotStatus: 0, + BallotCreationHeight: 1, + } + list := k.GetMaturedBallotList(ctx) + require.Empty(t, list) + ctx = ctx.WithBlockHeight(1) + k.AddBallotToList(ctx, *b) + list = k.GetMaturedBallotList(ctx) + require.Empty(t, list) + }) } diff --git a/x/observer/keeper/blame_test.go b/x/observer/keeper/blame_test.go new file mode 100644 index 0000000000..420e6f7d8f --- /dev/null +++ b/x/observer/keeper/blame_test.go @@ -0,0 +1,110 @@ +package keeper_test + +import ( + "sort" + "testing" + + "github.com/cosmos/cosmos-sdk/types/query" + "github.com/stretchr/testify/require" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" + "github.com/zeta-chain/zetacore/testutil/sample" + "github.com/zeta-chain/zetacore/x/observer/types" +) + +func TestKeeper_GetBlame(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + var chainId int64 = 97 + var nonce uint64 = 101 + digest := sample.ZetaIndex(t) + + index := types.GetBlameIndex(chainId, nonce, digest, 123) + + k.SetBlame(ctx, types.Blame{ + Index: index, + FailureReason: "failed to join party", + Nodes: nil, + }) + + blameRecords, found := k.GetBlame(ctx, index) + require.True(t, found) + require.Equal(t, index, blameRecords.Index) +} + +func TestKeeper_GetBlameByChainAndNonce(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + var chainId int64 = 97 + var nonce uint64 = 101 + digest := sample.ZetaIndex(t) + + index := types.GetBlameIndex(chainId, nonce, digest, 123) + + k.SetBlame(ctx, types.Blame{ + Index: index, + FailureReason: "failed to join party", + Nodes: nil, + }) + + blameRecords, found := k.GetBlamesByChainAndNonce(ctx, chainId, int64(nonce)) + require.True(t, found) + require.Equal(t, 1, len(blameRecords)) + require.Equal(t, index, blameRecords[0].Index) +} + +func TestKeeper_BlameAll(t *testing.T) { + t.Run("GetBlameRecord by limit ", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + blameList := sample.BlameRecordsList(t, 10) + for _, record := range blameList { + k.SetBlame(ctx, record) + } + sort.Slice(blameList, func(i, j int) bool { + return blameList[i].Index < blameList[j].Index + }) + rst, pageRes, err := k.GetAllBlamePaginated(ctx, &query.PageRequest{Limit: 10, CountTotal: true}) + require.NoError(t, err) + sort.Slice(rst, func(i, j int) bool { + return rst[i].Index < rst[j].Index + }) + require.Equal(t, blameList, rst) + require.Equal(t, len(blameList), int(pageRes.Total)) + }) + t.Run("GetBlameRecord by offset ", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + blameList := sample.BlameRecordsList(t, 20) + offset := 10 + for _, record := range blameList { + k.SetBlame(ctx, record) + } + sort.Slice(blameList, func(i, j int) bool { + return blameList[i].Index < blameList[j].Index + }) + rst, pageRes, err := k.GetAllBlamePaginated(ctx, &query.PageRequest{Offset: uint64(offset), CountTotal: true}) + require.NoError(t, err) + sort.Slice(rst, func(i, j int) bool { + return rst[i].Index < rst[j].Index + }) + require.Subset(t, blameList, rst) + require.Equal(t, len(blameList)-offset, len(rst)) + require.Equal(t, len(blameList), int(pageRes.Total)) + }) + t.Run("GetAllBlameRecord", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + blameList := sample.BlameRecordsList(t, 100) + for _, record := range blameList { + k.SetBlame(ctx, record) + } + rst := k.GetAllBlame(ctx) + sort.Slice(rst, func(i, j int) bool { + return rst[i].Index < rst[j].Index + }) + sort.Slice(blameList, func(i, j int) bool { + return blameList[i].Index < blameList[j].Index + }) + require.Equal(t, blameList, rst) + }) + t.Run("Get no records if nothing is set", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + rst := k.GetAllBlame(ctx) + require.Len(t, rst, 0) + }) +} diff --git a/x/observer/keeper/block_header_test.go b/x/observer/keeper/block_header_test.go new file mode 100644 index 0000000000..5ab61ed20f --- /dev/null +++ b/x/observer/keeper/block_header_test.go @@ -0,0 +1,32 @@ +package keeper_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/pkg/proofs" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" + "github.com/zeta-chain/zetacore/testutil/sample" +) + +func TestKeeper_GetBlockHeader(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + blockHash := sample.Hash().Bytes() + _, found := k.GetBlockHeader(ctx, blockHash) + require.False(t, found) + + bh := proofs.BlockHeader{ + Height: 1, + Hash: blockHash, + ParentHash: sample.Hash().Bytes(), + ChainId: 1, + Header: proofs.HeaderData{}, + } + k.SetBlockHeader(ctx, bh) + _, found = k.GetBlockHeader(ctx, blockHash) + require.True(t, found) + + k.RemoveBlockHeader(ctx, blockHash) + _, found = k.GetBlockHeader(ctx, blockHash) + require.False(t, found) +} diff --git a/x/observer/keeper/chain_params_test.go b/x/observer/keeper/chain_params_test.go index 3464c950fa..84f28e5cff 100644 --- a/x/observer/keeper/chain_params_test.go +++ b/x/observer/keeper/chain_params_test.go @@ -23,7 +23,6 @@ func TestKeeper_GetSupportedChainFromChainID(t *testing.T) { // chain params list but chain not supported chainParams := sample.ChainParams(getValidEthChainIDWithIndex(t, 0)) - chainParams.IsSupported = false k.SetChainParamsList(ctx, types.ChainParamsList{ ChainParams: []*types.ChainParams{chainParams}, }) @@ -40,6 +39,35 @@ func TestKeeper_GetSupportedChainFromChainID(t *testing.T) { }) } +func TestKeeper_GetChainParamsByChainID(t *testing.T) { + t.Run("return false if chain params not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + _, found := k.GetChainParamsByChainID(ctx, getValidEthChainIDWithIndex(t, 0)) + require.False(t, found) + }) + + t.Run("return true if found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + chainParams := sample.ChainParams(getValidEthChainIDWithIndex(t, 0)) + k.SetChainParamsList(ctx, types.ChainParamsList{ + ChainParams: []*types.ChainParams{chainParams}, + }) + res, found := k.GetChainParamsByChainID(ctx, getValidEthChainIDWithIndex(t, 0)) + require.True(t, found) + require.Equal(t, chainParams, res) + }) + + t.Run("return false if chain not found in params", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + chainParams := sample.ChainParams(getValidEthChainIDWithIndex(t, 0)) + k.SetChainParamsList(ctx, types.ChainParamsList{ + ChainParams: []*types.ChainParams{chainParams}, + }) + _, found := k.GetChainParamsByChainID(ctx, getValidEthChainIDWithIndex(t, 1)) + require.False(t, found) + }) +} func TestKeeper_GetSupportedChains(t *testing.T) { t.Run("return empty list if no core params list", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) diff --git a/x/observer/keeper/crosschain_flags_test.go b/x/observer/keeper/crosschain_flags_test.go new file mode 100644 index 0000000000..4d3f003b92 --- /dev/null +++ b/x/observer/keeper/crosschain_flags_test.go @@ -0,0 +1,83 @@ +package keeper_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" + "github.com/zeta-chain/zetacore/x/observer/types" +) + +func TestKeeper_IsInboundEnabled(t *testing.T) { + t.Run("should return false if flags not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + enabled := k.IsInboundEnabled(ctx) + require.False(t, enabled) + }) + + t.Run("should return if flags found and set", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + k.SetCrosschainFlags(ctx, types.CrosschainFlags{ + IsInboundEnabled: false, + }) + enabled := k.IsInboundEnabled(ctx) + require.False(t, enabled) + + k.SetCrosschainFlags(ctx, types.CrosschainFlags{ + IsInboundEnabled: true, + }) + + enabled = k.IsInboundEnabled(ctx) + require.True(t, enabled) + }) +} + +func TestKeeper_IsOutboundEnabled(t *testing.T) { + t.Run("should return false if flags not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + enabled := k.IsOutboundEnabled(ctx) + require.False(t, enabled) + }) + + t.Run("should return if flags found and set", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + k.SetCrosschainFlags(ctx, types.CrosschainFlags{ + IsOutboundEnabled: false, + }) + enabled := k.IsOutboundEnabled(ctx) + require.False(t, enabled) + + k.SetCrosschainFlags(ctx, types.CrosschainFlags{ + IsOutboundEnabled: true, + }) + + enabled = k.IsOutboundEnabled(ctx) + require.True(t, enabled) + }) +} + +func TestKeeper_DisableInboundOnly(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + k.DisableInboundOnly(ctx) + enabled := k.IsOutboundEnabled(ctx) + require.True(t, enabled) + + enabled = k.IsInboundEnabled(ctx) + require.False(t, enabled) + + k.SetCrosschainFlags(ctx, types.CrosschainFlags{ + IsOutboundEnabled: false, + }) + + k.DisableInboundOnly(ctx) + enabled = k.IsOutboundEnabled(ctx) + require.False(t, enabled) + + enabled = k.IsInboundEnabled(ctx) + require.False(t, enabled) +} diff --git a/x/observer/keeper/grpc_query_ballot_test.go b/x/observer/keeper/grpc_query_ballot_test.go new file mode 100644 index 0000000000..32bbe6265d --- /dev/null +++ b/x/observer/keeper/grpc_query_ballot_test.go @@ -0,0 +1,136 @@ +package keeper_test + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" + "github.com/zeta-chain/zetacore/testutil/sample" + "github.com/zeta-chain/zetacore/x/observer/types" +) + +func TestKeeper_HasVoted(t *testing.T) { + t.Run("should error if req is nil", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.HasVoted(wctx, nil) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should return false if ballot not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.HasVoted(wctx, &types.QueryHasVotedRequest{ + BallotIdentifier: "test", + }) + require.NoError(t, err) + require.Equal(t, &types.QueryHasVotedResponse{ + HasVoted: false, + }, res) + }) + + t.Run("should return true if ballot found and voted", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + voter := sample.AccAddress() + ballot := types.Ballot{ + Index: "index", + BallotIdentifier: "index", + VoterList: []string{voter}, + Votes: []types.VoteType{types.VoteType_SuccessObservation}, + BallotStatus: types.BallotStatus_BallotInProgress, + } + k.SetBallot(ctx, &ballot) + + res, err := k.HasVoted(wctx, &types.QueryHasVotedRequest{ + BallotIdentifier: "index", + VoterAddress: voter, + }) + require.NoError(t, err) + require.Equal(t, &types.QueryHasVotedResponse{ + HasVoted: true, + }, res) + }) + + t.Run("should return false if ballot found and not voted", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + voter := sample.AccAddress() + ballot := types.Ballot{ + Index: "index", + BallotIdentifier: "index", + VoterList: []string{voter}, + Votes: []types.VoteType{types.VoteType_SuccessObservation}, + BallotStatus: types.BallotStatus_BallotInProgress, + } + k.SetBallot(ctx, &ballot) + + res, err := k.HasVoted(wctx, &types.QueryHasVotedRequest{ + BallotIdentifier: "index", + VoterAddress: sample.AccAddress(), + }) + require.NoError(t, err) + require.Equal(t, &types.QueryHasVotedResponse{ + HasVoted: false, + }, res) + }) +} + +func TestKeeper_BallotByIdentifier(t *testing.T) { + t.Run("should error if req is nil", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.BallotByIdentifier(wctx, nil) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should return nil if ballot not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.BallotByIdentifier(wctx, &types.QueryBallotByIdentifierRequest{ + BallotIdentifier: "test", + }) + require.Error(t, err) + require.Nil(t, res) + }) + + t.Run("should return ballot if exists", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + voter := sample.AccAddress() + ballot := types.Ballot{ + Index: "index", + BallotIdentifier: "index", + VoterList: []string{voter}, + Votes: []types.VoteType{types.VoteType_SuccessObservation}, + BallotStatus: types.BallotStatus_BallotInProgress, + } + k.SetBallot(ctx, &ballot) + + res, err := k.BallotByIdentifier(wctx, &types.QueryBallotByIdentifierRequest{ + BallotIdentifier: "index", + }) + require.NoError(t, err) + require.Equal(t, &types.QueryBallotByIdentifierResponse{ + BallotIdentifier: ballot.BallotIdentifier, + Voters: []*types.VoterList{ + { + VoterAddress: voter, + VoteType: types.VoteType_SuccessObservation, + }, + }, + ObservationType: ballot.ObservationType, + BallotStatus: ballot.BallotStatus, + }, res) + }) +} diff --git a/x/observer/keeper/grpc_query_blame_test.go b/x/observer/keeper/grpc_query_blame_test.go index 36d95ed23e..728d1748d0 100644 --- a/x/observer/keeper/grpc_query_blame_test.go +++ b/x/observer/keeper/grpc_query_blame_test.go @@ -1,10 +1,9 @@ package keeper_test import ( - "sort" "testing" - "github.com/cosmos/cosmos-sdk/types/query" + sdk "github.com/cosmos/cosmos-sdk/types" "github.com/stretchr/testify/require" keepertest "github.com/zeta-chain/zetacore/testutil/keeper" "github.com/zeta-chain/zetacore/testutil/sample" @@ -12,99 +11,126 @@ import ( ) func TestKeeper_BlameByIdentifier(t *testing.T) { - k, ctx, _, _ := keepertest.ObserverKeeper(t) - var chainId int64 = 97 - var nonce uint64 = 101 - digest := "85f5e10431f69bc2a14046a13aabaefc660103b6de7a84f75c4b96181d03f0b5" - - index := types.GetBlameIndex(chainId, nonce, digest, 123) + t.Run("should error if req is nil", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) - k.SetBlame(ctx, types.Blame{ - Index: index, - FailureReason: "failed to join party", - Nodes: nil, + res, err := k.BlameByIdentifier(wctx, nil) + require.Nil(t, res) + require.Error(t, err) }) - blameRecords, found := k.GetBlame(ctx, index) - require.True(t, found) - require.Equal(t, index, blameRecords.Index) -} - -func TestKeeper_BlameByChainAndNonce(t *testing.T) { - k, ctx, _, _ := keepertest.ObserverKeeper(t) - var chainId int64 = 97 - var nonce uint64 = 101 - digest := "85f5e10431f69bc2a14046a13aabaefc660103b6de7a84f75c4b96181d03f0b5" - - index := types.GetBlameIndex(chainId, nonce, digest, 123) + t.Run("should error if blame not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) - k.SetBlame(ctx, types.Blame{ - Index: index, - FailureReason: "failed to join party", - Nodes: nil, + res, err := k.BlameByIdentifier(wctx, &types.QueryBlameByIdentifierRequest{ + BlameIdentifier: "test", + }) + require.Error(t, err) + require.Nil(t, res) }) - blameRecords, found := k.GetBlamesByChainAndNonce(ctx, chainId, int64(nonce)) - require.True(t, found) - require.Equal(t, 1, len(blameRecords)) - require.Equal(t, index, blameRecords[0].Index) -} - -func TestKeeper_BlameAll(t *testing.T) { - t.Run("GetBlameRecord by limit ", func(t *testing.T) { + t.Run("should return blame info if found", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) - blameList := sample.BlameRecordsList(t, 10) - for _, record := range blameList { - k.SetBlame(ctx, record) + wctx := sdk.WrapSDKContext(ctx) + var chainId int64 = 97 + var nonce uint64 = 101 + digest := sample.ZetaIndex(t) + + index := types.GetBlameIndex(chainId, nonce, digest, 123) + blame := types.Blame{ + Index: index, + FailureReason: "failed to join party", + Nodes: nil, } - sort.Slice(blameList, func(i, j int) bool { - return blameList[i].Index < blameList[j].Index + k.SetBlame(ctx, blame) + + res, err := k.BlameByIdentifier(wctx, &types.QueryBlameByIdentifierRequest{ + BlameIdentifier: index, }) - rst, pageRes, err := k.GetAllBlamePaginated(ctx, &query.PageRequest{Limit: 10, CountTotal: true}) require.NoError(t, err) - sort.Slice(rst, func(i, j int) bool { - return rst[i].Index < rst[j].Index - }) - require.Equal(t, blameList, rst) - require.Equal(t, len(blameList), int(pageRes.Total)) + require.Equal(t, &types.QueryBlameByIdentifierResponse{ + BlameInfo: &blame, + }, res) }) - t.Run("GetBlameRecord by offset ", func(t *testing.T) { +} + +func TestKeeper_GetAllBlameRecords(t *testing.T) { + t.Run("should error if req is nil", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) - blameList := sample.BlameRecordsList(t, 20) - offset := 10 - for _, record := range blameList { - k.SetBlame(ctx, record) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.GetAllBlameRecords(wctx, nil) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should return all if found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + var chainId int64 = 97 + var nonce uint64 = 101 + digest := sample.ZetaIndex(t) + + index := types.GetBlameIndex(chainId, nonce, digest, 123) + blame := types.Blame{ + Index: index, + FailureReason: "failed to join party", + Nodes: nil, } - sort.Slice(blameList, func(i, j int) bool { - return blameList[i].Index < blameList[j].Index - }) - rst, pageRes, err := k.GetAllBlamePaginated(ctx, &query.PageRequest{Offset: uint64(offset), CountTotal: true}) + k.SetBlame(ctx, blame) + + res, err := k.GetAllBlameRecords(wctx, &types.QueryAllBlameRecordsRequest{}) require.NoError(t, err) - sort.Slice(rst, func(i, j int) bool { - return rst[i].Index < rst[j].Index - }) - require.Subset(t, blameList, rst) - require.Equal(t, len(blameList)-offset, len(rst)) - require.Equal(t, len(blameList), int(pageRes.Total)) + require.Equal(t, []types.Blame{blame}, res.BlameInfo) }) - t.Run("GetAllBlameRecord", func(t *testing.T) { +} + +func TestKeeper_BlamesByChainAndNonce(t *testing.T) { + t.Run("should error if req is nil", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) - blameList := sample.BlameRecordsList(t, 100) - for _, record := range blameList { - k.SetBlame(ctx, record) - } - rst := k.GetAllBlame(ctx) - sort.Slice(rst, func(i, j int) bool { - return rst[i].Index < rst[j].Index - }) - sort.Slice(blameList, func(i, j int) bool { - return blameList[i].Index < blameList[j].Index + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.BlamesByChainAndNonce(wctx, nil) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should error if blame not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.BlamesByChainAndNonce(wctx, &types.QueryBlameByChainAndNonceRequest{ + ChainId: 1, + Nonce: 1, }) - require.Equal(t, blameList, rst) + require.Error(t, err) + require.Nil(t, res) }) - t.Run("Get no records if nothing is set", func(t *testing.T) { + + t.Run("should return blame info if found", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) - rst := k.GetAllBlame(ctx) - require.Len(t, rst, 0) + wctx := sdk.WrapSDKContext(ctx) + var chainId int64 = 97 + var nonce uint64 = 101 + digest := sample.ZetaIndex(t) + + index := types.GetBlameIndex(chainId, nonce, digest, 123) + blame := types.Blame{ + Index: index, + FailureReason: "failed to join party", + Nodes: nil, + } + k.SetBlame(ctx, blame) + + res, err := k.BlamesByChainAndNonce(wctx, &types.QueryBlameByChainAndNonceRequest{ + ChainId: chainId, + Nonce: int64(nonce), + }) + require.NoError(t, err) + require.Equal(t, &types.QueryBlameByChainAndNonceResponse{ + BlameInfo: []*types.Blame{&blame}, + }, res) }) } diff --git a/x/observer/keeper/grpc_query_block_header_test.go b/x/observer/keeper/grpc_query_block_header_test.go new file mode 100644 index 0000000000..9cf50f07a4 --- /dev/null +++ b/x/observer/keeper/grpc_query_block_header_test.go @@ -0,0 +1,119 @@ +package keeper_test + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/pkg/proofs" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" + "github.com/zeta-chain/zetacore/testutil/sample" + "github.com/zeta-chain/zetacore/x/observer/types" +) + +func TestKeeper_GetAllBlockHeaders(t *testing.T) { + t.Run("should error if req is nil", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.GetAllBlockHeaders(wctx, nil) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should return if block header is found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + bh := proofs.BlockHeader{ + Height: 1, + Hash: sample.Hash().Bytes(), + ParentHash: sample.Hash().Bytes(), + ChainId: 1, + Header: proofs.HeaderData{}, + } + k.SetBlockHeader(ctx, bh) + + res, err := k.GetAllBlockHeaders(wctx, &types.QueryAllBlockHeaderRequest{}) + require.NoError(t, err) + require.Equal(t, &bh, res.BlockHeaders[0]) + }) +} + +func TestKeeper_GetBlockHeaderByHash(t *testing.T) { + t.Run("should error if req is nil", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.GetBlockHeaderByHash(wctx, nil) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should error if not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.GetBlockHeaderByHash(wctx, &types.QueryGetBlockHeaderByHashRequest{ + BlockHash: sample.Hash().Bytes(), + }) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should return if block header is found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + bh := proofs.BlockHeader{ + Height: 1, + Hash: sample.Hash().Bytes(), + ParentHash: sample.Hash().Bytes(), + ChainId: 1, + Header: proofs.HeaderData{}, + } + k.SetBlockHeader(ctx, bh) + + res, err := k.GetBlockHeaderByHash(wctx, &types.QueryGetBlockHeaderByHashRequest{ + BlockHash: bh.Hash, + }) + require.NoError(t, err) + require.Equal(t, &bh, res.BlockHeader) + }) +} + +func TestKeeper_GetBlockHeaderStateByChain(t *testing.T) { + t.Run("should error if req is nil", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.GetBlockHeaderStateByChain(wctx, nil) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should error if not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.GetBlockHeaderStateByChain(wctx, &types.QueryGetBlockHeaderStateRequest{ + ChainId: 1, + }) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should return if block header state is found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + bhs := types.BlockHeaderState{ + ChainId: 1, + } + k.SetBlockHeaderState(ctx, bhs) + + res, err := k.GetBlockHeaderStateByChain(wctx, &types.QueryGetBlockHeaderStateRequest{ + ChainId: 1, + }) + require.NoError(t, err) + require.Equal(t, &bhs, res.BlockHeaderState) + }) +} diff --git a/x/observer/keeper/grpc_query_chain_params_test.go b/x/observer/keeper/grpc_query_chain_params_test.go new file mode 100644 index 0000000000..0e7145e5cd --- /dev/null +++ b/x/observer/keeper/grpc_query_chain_params_test.go @@ -0,0 +1,97 @@ +package keeper_test + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/pkg/chains" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" + "github.com/zeta-chain/zetacore/x/observer/types" +) + +func TestKeeper_GetChainParamsForChain(t *testing.T) { + t.Run("should error if req is nil", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.GetChainParamsForChain(wctx, nil) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should error if chain params not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.GetChainParamsForChain(wctx, &types.QueryGetChainParamsForChainRequest{ + ChainId: 987, + }) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should return if chain params found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + list := types.ChainParamsList{ + ChainParams: []*types.ChainParams{ + { + ChainId: chains.ZetaPrivnetChain().ChainId, + IsSupported: false, + }, + }, + } + k.SetChainParamsList(ctx, list) + + res, err := k.GetChainParamsForChain(wctx, &types.QueryGetChainParamsForChainRequest{ + ChainId: chains.ZetaPrivnetChain().ChainId, + }) + require.NoError(t, err) + require.Equal(t, &types.QueryGetChainParamsForChainResponse{ + ChainParams: list.ChainParams[0], + }, res) + }) +} + +func TestKeeper_GetChainParams(t *testing.T) { + t.Run("should error if req is nil", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.GetChainParams(wctx, nil) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should error if chain params not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.GetChainParams(wctx, &types.QueryGetChainParamsRequest{}) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should return if chain params found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + list := types.ChainParamsList{ + ChainParams: []*types.ChainParams{ + { + ChainId: chains.ZetaPrivnetChain().ChainId, + IsSupported: false, + }, + }, + } + k.SetChainParamsList(ctx, list) + + res, err := k.GetChainParams(wctx, &types.QueryGetChainParamsRequest{}) + require.NoError(t, err) + require.Equal(t, &types.QueryGetChainParamsResponse{ + ChainParams: &list, + }, res) + }) +} diff --git a/x/observer/keeper/grpc_query_crosschain_flags_test.go b/x/observer/keeper/grpc_query_crosschain_flags_test.go new file mode 100644 index 0000000000..172df715d9 --- /dev/null +++ b/x/observer/keeper/grpc_query_crosschain_flags_test.go @@ -0,0 +1,47 @@ +package keeper_test + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" + "github.com/zeta-chain/zetacore/x/observer/types" +) + +func TestKeeper_CrosschainFlags(t *testing.T) { + t.Run("should error if req is nil", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.CrosschainFlags(wctx, nil) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should error if crosschain flags not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.CrosschainFlags(wctx, &types.QueryGetCrosschainFlagsRequest{}) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should return if crosschain flags found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + flags := types.CrosschainFlags{ + IsInboundEnabled: false, + } + k.SetCrosschainFlags(ctx, flags) + + res, err := k.CrosschainFlags(wctx, &types.QueryGetCrosschainFlagsRequest{}) + + require.NoError(t, err) + require.Equal(t, &types.QueryGetCrosschainFlagsResponse{ + CrosschainFlags: flags, + }, res) + }) +} diff --git a/x/observer/keeper/grpc_query_keygen_test.go b/x/observer/keeper/grpc_query_keygen_test.go index f4c61aeabd..f081382452 100644 --- a/x/observer/keeper/grpc_query_keygen_test.go +++ b/x/observer/keeper/grpc_query_keygen_test.go @@ -1,4 +1,4 @@ -package keeper +package keeper_test import ( "github.com/stretchr/testify/require" @@ -6,33 +6,32 @@ import ( "testing" sdk "github.com/cosmos/cosmos-sdk/types" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" "github.com/zeta-chain/zetacore/x/observer/types" ) -func TestKeygenQuery(t *testing.T) { - keeper, ctx := SetupKeeper(t) - wctx := sdk.WrapSDKContext(ctx) - item := createTestKeygen(keeper, ctx) - for _, tc := range []struct { - desc string - request *types.QueryGetKeygenRequest - response *types.QueryGetKeygenResponse - err error - }{ - { - desc: "First", - request: &types.QueryGetKeygenRequest{}, - response: &types.QueryGetKeygenResponse{Keygen: &item}, - }, - } { - tc := tc - t.Run(tc.desc, func(t *testing.T) { - response, err := keeper.Keygen(wctx, tc.request) - if tc.err != nil { - require.ErrorIs(t, err, tc.err) - } else { - require.Equal(t, tc.response, response) - } - }) - } +func TestKeeper_Keygen(t *testing.T) { + t.Run("should error if keygen not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.Keygen(wctx, &types.QueryGetKeygenRequest{}) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should return if keygen found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + keygen := types.Keygen{ + BlockNumber: 10, + } + k.SetKeygen(ctx, keygen) + + res, err := k.Keygen(wctx, &types.QueryGetKeygenRequest{}) + require.NoError(t, err) + require.Equal(t, &types.QueryGetKeygenResponse{ + Keygen: &keygen, + }, res) + }) } diff --git a/x/observer/keeper/grpc_query_node_account_test.go b/x/observer/keeper/grpc_query_node_account_test.go index 03facb08b4..8645357a0c 100644 --- a/x/observer/keeper/grpc_query_node_account_test.go +++ b/x/observer/keeper/grpc_query_node_account_test.go @@ -1,4 +1,4 @@ -package keeper +package keeper_test import ( "testing" @@ -6,15 +6,17 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/query" "github.com/stretchr/testify/require" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" "github.com/zeta-chain/zetacore/x/observer/types" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) func TestNodeAccountQuerySingle(t *testing.T) { - keeper, ctx := SetupKeeper(t) + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) - msgs := createNNodeAccount(keeper, ctx, 2) + msgs := createNNodeAccount(k, ctx, 2) for _, tc := range []struct { desc string request *types.QueryGetNodeAccountRequest @@ -43,7 +45,7 @@ func TestNodeAccountQuerySingle(t *testing.T) { } { tc := tc t.Run(tc.desc, func(t *testing.T) { - response, err := keeper.NodeAccount(wctx, tc.request) + response, err := k.NodeAccount(wctx, tc.request) if tc.err != nil { require.ErrorIs(t, err, tc.err) } else { @@ -54,9 +56,9 @@ func TestNodeAccountQuerySingle(t *testing.T) { } func TestNodeAccountQueryPaginated(t *testing.T) { - keeper, ctx := SetupKeeper(t) + k, ctx, _, _ := keepertest.ObserverKeeper(t) wctx := sdk.WrapSDKContext(ctx) - msgs := createNNodeAccount(keeper, ctx, 5) + msgs := createNNodeAccount(k, ctx, 5) request := func(next []byte, offset, limit uint64, total bool) *types.QueryAllNodeAccountRequest { return &types.QueryAllNodeAccountRequest{ @@ -71,7 +73,7 @@ func TestNodeAccountQueryPaginated(t *testing.T) { t.Run("ByOffset", func(t *testing.T) { step := 2 for i := 0; i < len(msgs); i += step { - resp, err := keeper.NodeAccountAll(wctx, request(nil, uint64(i), uint64(step), false)) + resp, err := k.NodeAccountAll(wctx, request(nil, uint64(i), uint64(step), false)) require.NoError(t, err) for j := i; j < len(msgs) && j < i+step; j++ { require.Equal(t, &msgs[j], resp.NodeAccount[j-i]) @@ -82,7 +84,7 @@ func TestNodeAccountQueryPaginated(t *testing.T) { step := 2 var next []byte for i := 0; i < len(msgs); i += step { - resp, err := keeper.NodeAccountAll(wctx, request(next, 0, uint64(step), false)) + resp, err := k.NodeAccountAll(wctx, request(next, 0, uint64(step), false)) require.NoError(t, err) for j := i; j < len(msgs) && j < i+step; j++ { require.Equal(t, &msgs[j], resp.NodeAccount[j-i]) @@ -91,12 +93,12 @@ func TestNodeAccountQueryPaginated(t *testing.T) { } }) t.Run("Total", func(t *testing.T) { - resp, err := keeper.NodeAccountAll(wctx, request(nil, 0, 0, true)) + resp, err := k.NodeAccountAll(wctx, request(nil, 0, 0, true)) require.NoError(t, err) require.Equal(t, len(msgs), int(resp.Pagination.Total)) }) t.Run("InvalidRequest", func(t *testing.T) { - _, err := keeper.NodeAccountAll(wctx, nil) + _, err := k.NodeAccountAll(wctx, nil) require.ErrorIs(t, err, status.Error(codes.InvalidArgument, "invalid request")) }) } diff --git a/x/observer/keeper/grpc_query_nonces_test.go b/x/observer/keeper/grpc_query_nonces_test.go index 657eb59008..5a66eebfbf 100644 --- a/x/observer/keeper/grpc_query_nonces_test.go +++ b/x/observer/keeper/grpc_query_nonces_test.go @@ -108,3 +108,84 @@ func TestChainNoncesQueryPaginated(t *testing.T) { require.ErrorIs(t, err, status.Error(codes.InvalidArgument, "invalid request")) }) } + +func TestPendingNoncesQuerySingle(t *testing.T) { + t.Run("should error if req is nil", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.PendingNoncesByChain(wctx, nil) + require.Error(t, err) + require.Nil(t, res) + }) + + t.Run("should error if tss not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.PendingNoncesByChain(wctx, &types.QueryPendingNoncesByChainRequest{ + ChainId: 1, + }) + require.Error(t, err) + require.Nil(t, res) + }) + + t.Run("should error if not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + tss := sample.Tss() + k.SetTSS(ctx, tss) + res, err := k.PendingNoncesByChain(wctx, &types.QueryPendingNoncesByChainRequest{ + ChainId: 1, + }) + require.Error(t, err) + require.Nil(t, res) + }) + + t.Run("should return if found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + tss := sample.Tss() + k.SetTSS(ctx, tss) + pendingNonces := sample.PendingNoncesList(t, "sample", 5) + pendingNonces[1].Tss = tss.TssPubkey + for _, nonce := range pendingNonces { + k.SetPendingNonces(ctx, nonce) + } + res, err := k.PendingNoncesByChain(wctx, &types.QueryPendingNoncesByChainRequest{ + ChainId: 1, + }) + require.NoError(t, err) + require.Equal(t, pendingNonces[1], res.PendingNonces) + }) +} + +func TestPendingNoncesQueryPaginated(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + pendingNonces := sample.PendingNoncesList(t, "sample", 5) + for _, nonce := range pendingNonces { + k.SetPendingNonces(ctx, nonce) + } + + request := func(next []byte, offset, limit uint64, total bool) *types.QueryAllPendingNoncesRequest { + return &types.QueryAllPendingNoncesRequest{ + Pagination: &query.PageRequest{ + Key: next, + Offset: offset, + Limit: limit, + CountTotal: total, + }, + } + } + + t.Run("Total", func(t *testing.T) { + resp, err := k.PendingNoncesAll(wctx, request(nil, 0, 0, true)) + require.NoError(t, err) + require.Equal(t, len(pendingNonces), int(resp.Pagination.Total)) + }) + t.Run("InvalidRequest", func(t *testing.T) { + _, err := k.PendingNoncesAll(wctx, nil) + require.ErrorIs(t, err, status.Error(codes.InvalidArgument, "invalid request")) + }) +} diff --git a/x/observer/keeper/grpc_query_observer_test.go b/x/observer/keeper/grpc_query_observer_test.go new file mode 100644 index 0000000000..73675fae0d --- /dev/null +++ b/x/observer/keeper/grpc_query_observer_test.go @@ -0,0 +1,78 @@ +package keeper_test + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" + "github.com/zeta-chain/zetacore/testutil/sample" + "github.com/zeta-chain/zetacore/x/observer/types" +) + +func TestKeeper_ShowObserverCount(t *testing.T) { + t.Run("should error if req is nil", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.ShowObserverCount(wctx, nil) + require.Error(t, err) + require.Nil(t, res) + }) + + t.Run("should error if not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.ShowObserverCount(wctx, &types.QueryShowObserverCountRequest{}) + require.Error(t, err) + require.Nil(t, res) + }) + + t.Run("should return if found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + count := 1 + loc := &types.LastObserverCount{ + Count: uint64(count), + } + k.SetLastObserverCount(ctx, loc) + + res, err := k.ShowObserverCount(wctx, &types.QueryShowObserverCountRequest{}) + require.NoError(t, err) + require.Equal(t, loc, res.LastObserverCount) + }) +} + +func TestKeeper_ObserverSet(t *testing.T) { + t.Run("should error if req is nil", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.ObserverSet(wctx, nil) + require.Error(t, err) + require.Nil(t, res) + }) + + t.Run("should error if not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.ObserverSet(wctx, &types.QueryObserverSet{}) + require.Error(t, err) + require.Nil(t, res) + }) + + t.Run("should return if found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + os := sample.ObserverSet(10) + k.SetObserverSet(ctx, os) + + res, err := k.ObserverSet(wctx, &types.QueryObserverSet{}) + require.NoError(t, err) + require.Equal(t, os.ObserverList, res.Observers) + }) +} diff --git a/x/observer/keeper/grpc_query_params_test.go b/x/observer/keeper/grpc_query_params_test.go index 4cd534fa59..a1de6c63b2 100644 --- a/x/observer/keeper/grpc_query_params_test.go +++ b/x/observer/keeper/grpc_query_params_test.go @@ -1,20 +1,21 @@ -package keeper +package keeper_test import ( "testing" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/stretchr/testify/require" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" "github.com/zeta-chain/zetacore/x/observer/types" ) func TestParamsQuery(t *testing.T) { - keeper, ctx := SetupKeeper(t) + k, ctx, _, _ := keepertest.ObserverKeeper(t) wctx := sdk.WrapSDKContext(ctx) params := types.DefaultParams() - keeper.SetParams(ctx, params) + k.SetParams(ctx, params) - response, err := keeper.Params(wctx, &types.QueryParamsRequest{}) + response, err := k.Params(wctx, &types.QueryParamsRequest{}) require.NoError(t, err) require.Equal(t, &types.QueryParamsResponse{Params: params}, response) } diff --git a/x/observer/keeper/grpc_query_prove_test.go b/x/observer/keeper/grpc_query_prove_test.go new file mode 100644 index 0000000000..7eb11b00f5 --- /dev/null +++ b/x/observer/keeper/grpc_query_prove_test.go @@ -0,0 +1,72 @@ +package keeper_test + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/pkg/proofs" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" + "github.com/zeta-chain/zetacore/testutil/sample" + "github.com/zeta-chain/zetacore/x/observer/types" +) + +func TestKeeper_Prove(t *testing.T) { + t.Run("should error if req is nil", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.Prove(wctx, nil) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should error if invalid hash", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.Prove(wctx, &types.QueryProveRequest{ + ChainId: 987, + BlockHash: sample.Hash().String(), + }) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should error if header not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.Prove(wctx, &types.QueryProveRequest{ + ChainId: 5, + BlockHash: sample.Hash().String(), + }) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should error if proof not valid", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + hash := sample.Hash() + bh := proofs.BlockHeader{ + Height: 1, + Hash: hash.Bytes(), + ParentHash: sample.Hash().Bytes(), + ChainId: 1, + Header: proofs.HeaderData{}, + } + k.SetBlockHeader(ctx, bh) + + res, err := k.Prove(wctx, &types.QueryProveRequest{ + ChainId: 5, + BlockHash: hash.String(), + Proof: &proofs.Proof{}, + }) + require.Nil(t, res) + require.Error(t, err) + }) + + // TODO: // https://github.com/zeta-chain/node/issues/1875 add more tests +} diff --git a/x/observer/keeper/grpc_query_supported_chain_test.go b/x/observer/keeper/grpc_query_supported_chain_test.go new file mode 100644 index 0000000000..50acd5703a --- /dev/null +++ b/x/observer/keeper/grpc_query_supported_chain_test.go @@ -0,0 +1,21 @@ +package keeper_test + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/pkg/chains" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" +) + +func TestKeeper_SupportedChains(t *testing.T) { + t.Run("should return supported chains", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.SupportedChains(wctx, nil) + require.NoError(t, err) + require.Equal(t, []*chains.Chain{}, res.Chains) + }) +} diff --git a/x/observer/keeper/grpc_query_tss_test.go b/x/observer/keeper/grpc_query_tss_test.go new file mode 100644 index 0000000000..105d7c1cf9 --- /dev/null +++ b/x/observer/keeper/grpc_query_tss_test.go @@ -0,0 +1,230 @@ +package keeper_test + +import ( + "math/rand" + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/pkg/chains" + "github.com/zeta-chain/zetacore/pkg/crypto" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" + "github.com/zeta-chain/zetacore/testutil/sample" + "github.com/zeta-chain/zetacore/x/observer/types" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestTSSQuerySingle(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + tss := sample.Tss() + wctx := sdk.WrapSDKContext(ctx) + + for _, tc := range []struct { + desc string + request *types.QueryGetTSSRequest + response *types.QueryGetTSSResponse + skipSettingTss bool + err error + }{ + { + desc: "Skip setting tss", + request: &types.QueryGetTSSRequest{}, + skipSettingTss: true, + err: status.Error(codes.InvalidArgument, "not found"), + }, + { + desc: "InvalidRequest", + err: status.Error(codes.InvalidArgument, "invalid request"), + }, + { + desc: "Should return tss", + request: &types.QueryGetTSSRequest{}, + response: &types.QueryGetTSSResponse{TSS: tss}, + }, + } { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + if !tc.skipSettingTss { + k.SetTSS(ctx, tss) + } + response, err := k.TSS(wctx, tc.request) + if tc.err != nil { + require.ErrorIs(t, err, tc.err) + } else { + require.Equal(t, tc.response, response) + } + }) + } +} + +func TestTSSQueryHistory(t *testing.T) { + keeper, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + for _, tc := range []struct { + desc string + tssCount int + foundPrevious bool + err error + }{ + { + desc: "1 Tss addresses", + tssCount: 1, + foundPrevious: false, + err: nil, + }, + { + desc: "10 Tss addresses", + tssCount: 10, + foundPrevious: true, + err: nil, + }, + } { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + tssList := sample.TssList(tc.tssCount) + for _, tss := range tssList { + keeper.SetTSS(ctx, tss) + keeper.SetTSSHistory(ctx, tss) + } + request := &types.QueryTssHistoryRequest{} + response, err := keeper.TssHistory(wctx, request) + if tc.err != nil { + require.ErrorIs(t, err, tc.err) + } else { + require.Equal(t, len(tssList), len(response.TssList)) + prevTss, found := keeper.GetPreviousTSS(ctx) + require.Equal(t, tc.foundPrevious, found) + if found { + require.Equal(t, tssList[len(tssList)-2], prevTss) + } + } + }) + } +} + +func TestKeeper_GetTssAddress(t *testing.T) { + t.Run("should error if req is nil", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.GetTssAddress(wctx, nil) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should error if not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.GetTssAddress(wctx, &types.QueryGetTssAddressRequest{ + BitcoinChainId: 1, + }) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should error if invalid chain id", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + tss := sample.Tss() + k.SetTSS(ctx, tss) + + res, err := k.GetTssAddress(wctx, &types.QueryGetTssAddressRequest{ + BitcoinChainId: 987, + }) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should return if valid chain id", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + tss := sample.Tss() + k.SetTSS(ctx, tss) + + res, err := k.GetTssAddress(wctx, &types.QueryGetTssAddressRequest{ + BitcoinChainId: chains.BtcRegtestChain().ChainId, + }) + require.NoError(t, err) + expectedBitcoinParams, err := chains.BitcoinNetParamsFromChainID(chains.BtcRegtestChain().ChainId) + require.NoError(t, err) + expectedBtcAddress, err := crypto.GetTssAddrBTC(tss.TssPubkey, expectedBitcoinParams) + require.NoError(t, err) + expectedEthAddress, err := crypto.GetTssAddrEVM(tss.TssPubkey) + require.NoError(t, err) + require.NotNil(t, &types.QueryGetTssAddressByFinalizedHeightResponse{ + Eth: expectedEthAddress.String(), + Btc: expectedBtcAddress, + }, res) + }) +} + +func TestKeeper_GetTssAddressByFinalizedHeight(t *testing.T) { + t.Run("should error if req is nil", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.GetTssAddressByFinalizedHeight(wctx, nil) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should error if not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + res, err := k.GetTssAddressByFinalizedHeight(wctx, &types.QueryGetTssAddressByFinalizedHeightRequest{ + BitcoinChainId: 1, + }) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should error if invalid chain id", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + tssList := sample.TssList(100) + r := rand.Intn((len(tssList)-1)-0) + 0 + for _, tss := range tssList { + k.SetTSSHistory(ctx, tss) + } + + res, err := k.GetTssAddressByFinalizedHeight(wctx, &types.QueryGetTssAddressByFinalizedHeightRequest{ + BitcoinChainId: 987, + FinalizedZetaHeight: tssList[r].FinalizedZetaHeight, + }) + require.Nil(t, res) + require.Error(t, err) + }) + + t.Run("should return if valid chain id", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + + tssList := sample.TssList(100) + r := rand.Intn((len(tssList)-1)-0) + 0 + for _, tss := range tssList { + k.SetTSSHistory(ctx, tss) + } + + res, err := k.GetTssAddressByFinalizedHeight(wctx, &types.QueryGetTssAddressByFinalizedHeightRequest{ + BitcoinChainId: chains.BtcRegtestChain().ChainId, + FinalizedZetaHeight: tssList[r].FinalizedZetaHeight, + }) + require.NoError(t, err) + expectedBitcoinParams, err := chains.BitcoinNetParamsFromChainID(chains.BtcRegtestChain().ChainId) + require.NoError(t, err) + expectedBtcAddress, err := crypto.GetTssAddrBTC(tssList[r].TssPubkey, expectedBitcoinParams) + require.NoError(t, err) + expectedEthAddress, err := crypto.GetTssAddrEVM(tssList[r].TssPubkey) + require.NoError(t, err) + require.NotNil(t, &types.QueryGetTssAddressByFinalizedHeightResponse{ + Eth: expectedEthAddress.String(), + Btc: expectedBtcAddress, + }, res) + }) +} diff --git a/x/observer/keeper/hooks_test.go b/x/observer/keeper/hooks_test.go new file mode 100644 index 0000000000..ec782663ee --- /dev/null +++ b/x/observer/keeper/hooks_test.go @@ -0,0 +1,213 @@ +package keeper_test + +import ( + "math/rand" + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" + "github.com/stretchr/testify/require" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" + "github.com/zeta-chain/zetacore/testutil/sample" + "github.com/zeta-chain/zetacore/x/observer/types" +) + +func TestKeeper_NotImplementedHooks(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + hooks := k.Hooks() + require.Nil(t, hooks.AfterValidatorCreated(ctx, nil)) + require.Nil(t, hooks.BeforeValidatorModified(ctx, nil)) + require.Nil(t, hooks.AfterValidatorBonded(ctx, nil, nil)) + require.Nil(t, hooks.BeforeDelegationCreated(ctx, nil, nil)) + require.Nil(t, hooks.BeforeDelegationSharesModified(ctx, nil, nil)) + require.Nil(t, hooks.BeforeDelegationRemoved(ctx, nil, nil)) +} + +func TestKeeper_AfterValidatorRemoved(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + // #nosec G404 test purpose - weak randomness is not an issue here + r := rand.New(rand.NewSource(1)) + valAddr := sample.ValAddress(r) + accAddress, err := types.GetAccAddressFromOperatorAddress(valAddr.String()) + require.NoError(t, err) + os := types.ObserverSet{ + ObserverList: []string{accAddress.String()}, + } + k.SetObserverSet(ctx, os) + + hooks := k.Hooks() + err = hooks.AfterValidatorRemoved(ctx, nil, valAddr) + require.NoError(t, err) + + os, found := k.GetObserverSet(ctx) + require.True(t, found) + // observer for validator is removed from set + require.Empty(t, os.ObserverList) +} + +func TestKeeper_AfterValidatorBeginUnbonding(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + r := rand.New(rand.NewSource(9)) + validator := sample.Validator(t, r) + validator.DelegatorShares = sdk.NewDec(100) + k.GetStakingKeeper().SetValidator(ctx, validator) + accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress) + require.NoError(t, err) + + k.GetStakingKeeper().SetDelegation(ctx, stakingtypes.Delegation{ + DelegatorAddress: accAddressOfValidator.String(), + ValidatorAddress: validator.GetOperator().String(), + Shares: sdk.NewDec(10), + }) + + k.SetObserverSet(ctx, types.ObserverSet{ + ObserverList: []string{accAddressOfValidator.String()}, + }) + + hooks := k.Hooks() + err = hooks.AfterValidatorBeginUnbonding(ctx, nil, validator.GetOperator()) + require.NoError(t, err) + + os, found := k.GetObserverSet(ctx) + require.True(t, found) + require.Empty(t, os.ObserverList) +} + +func TestKeeper_AfterDelegationModified(t *testing.T) { + t.Run("should not clean observer if not self delegation", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + r := rand.New(rand.NewSource(9)) + validator := sample.Validator(t, r) + validator.DelegatorShares = sdk.NewDec(100) + k.GetStakingKeeper().SetValidator(ctx, validator) + accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress) + require.NoError(t, err) + + k.GetStakingKeeper().SetDelegation(ctx, stakingtypes.Delegation{ + DelegatorAddress: accAddressOfValidator.String(), + ValidatorAddress: validator.GetOperator().String(), + Shares: sdk.NewDec(10), + }) + + k.SetObserverSet(ctx, types.ObserverSet{ + ObserverList: []string{accAddressOfValidator.String()}, + }) + + hooks := k.Hooks() + err = hooks.AfterDelegationModified(ctx, sdk.AccAddress(sample.AccAddress()), validator.GetOperator()) + require.NoError(t, err) + + os, found := k.GetObserverSet(ctx) + require.True(t, found) + require.Equal(t, 1, len(os.ObserverList)) + require.Equal(t, accAddressOfValidator.String(), os.ObserverList[0]) + }) + + t.Run("should clean observer if self delegation", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + r := rand.New(rand.NewSource(9)) + validator := sample.Validator(t, r) + validator.DelegatorShares = sdk.NewDec(100) + k.GetStakingKeeper().SetValidator(ctx, validator) + accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress) + require.NoError(t, err) + + k.GetStakingKeeper().SetDelegation(ctx, stakingtypes.Delegation{ + DelegatorAddress: accAddressOfValidator.String(), + ValidatorAddress: validator.GetOperator().String(), + Shares: sdk.NewDec(10), + }) + + k.SetObserverSet(ctx, types.ObserverSet{ + ObserverList: []string{accAddressOfValidator.String()}, + }) + + hooks := k.Hooks() + err = hooks.AfterDelegationModified(ctx, accAddressOfValidator, validator.GetOperator()) + require.NoError(t, err) + + os, found := k.GetObserverSet(ctx) + require.True(t, found) + require.Empty(t, os.ObserverList) + }) +} + +func TestKeeper_BeforeValidatorSlashed(t *testing.T) { + t.Run("should error if validator not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + r := rand.New(rand.NewSource(9)) + validator := sample.Validator(t, r) + + hooks := k.Hooks() + err := hooks.BeforeValidatorSlashed(ctx, validator.GetOperator(), sdk.NewDec(1)) + require.Error(t, err) + }) + + t.Run("should not error if observer set not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + r := rand.New(rand.NewSource(9)) + validator := sample.Validator(t, r) + validator.DelegatorShares = sdk.NewDec(100) + k.GetStakingKeeper().SetValidator(ctx, validator) + + hooks := k.Hooks() + err := hooks.BeforeValidatorSlashed(ctx, validator.GetOperator(), sdk.NewDec(1)) + require.NoError(t, err) + }) + + t.Run("should remove from observer set", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + r := rand.New(rand.NewSource(9)) + validator := sample.Validator(t, r) + validator.DelegatorShares = sdk.NewDec(100) + validator.Tokens = sdk.NewInt(100) + k.GetStakingKeeper().SetValidator(ctx, validator) + accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress) + require.NoError(t, err) + + k.SetObserverSet(ctx, types.ObserverSet{ + ObserverList: []string{accAddressOfValidator.String()}, + }) + + hooks := k.Hooks() + err = hooks.BeforeValidatorSlashed(ctx, validator.GetOperator(), sdk.MustNewDecFromStr("0.1")) + require.NoError(t, err) + + os, found := k.GetObserverSet(ctx) + require.True(t, found) + require.Empty(t, os.ObserverList) + }) + + t.Run("should not remove from observer set", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + r := rand.New(rand.NewSource(9)) + validator := sample.Validator(t, r) + validator.DelegatorShares = sdk.NewDec(100) + validator.Tokens = sdk.NewInt(1000000000000000000) + k.GetStakingKeeper().SetValidator(ctx, validator) + accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress) + require.NoError(t, err) + + k.SetObserverSet(ctx, types.ObserverSet{ + ObserverList: []string{accAddressOfValidator.String()}, + }) + + hooks := k.Hooks() + err = hooks.BeforeValidatorSlashed(ctx, validator.GetOperator(), sdk.MustNewDecFromStr("0")) + require.NoError(t, err) + + os, found := k.GetObserverSet(ctx) + require.True(t, found) + require.Equal(t, 1, len(os.ObserverList)) + require.Equal(t, accAddressOfValidator.String(), os.ObserverList[0]) + }) +} diff --git a/x/observer/keeper/keeper_test.go b/x/observer/keeper/keeper_test.go deleted file mode 100644 index c34cdb5392..0000000000 --- a/x/observer/keeper/keeper_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package keeper - -import ( - "testing" - - "github.com/cosmos/cosmos-sdk/codec" - codectypes "github.com/cosmos/cosmos-sdk/codec/types" - "github.com/cosmos/cosmos-sdk/store" - storetypes "github.com/cosmos/cosmos-sdk/store/types" - sdk "github.com/cosmos/cosmos-sdk/types" - typesparams "github.com/cosmos/cosmos-sdk/x/params/types" - slashingkeeper "github.com/cosmos/cosmos-sdk/x/slashing/keeper" - stakingkeeper "github.com/cosmos/cosmos-sdk/x/staking/keeper" - "github.com/stretchr/testify/require" - tmproto "github.com/tendermint/tendermint/proto/tendermint/types" - tmdb "github.com/tendermint/tm-db" - authoritykeeper "github.com/zeta-chain/zetacore/x/authority/keeper" - "github.com/zeta-chain/zetacore/x/observer/types" -) - -func SetupKeeper(t testing.TB) (*Keeper, sdk.Context) { - storeKey := sdk.NewKVStoreKey(types.StoreKey) - memStoreKey := storetypes.NewMemoryStoreKey(types.MemStoreKey) - - db := tmdb.NewMemDB() - stateStore := store.NewCommitMultiStore(db) - stateStore.MountStoreWithDB(storeKey, storetypes.StoreTypeIAVL, db) - stateStore.MountStoreWithDB(memStoreKey, storetypes.StoreTypeMemory, nil) - require.NoError(t, stateStore.LoadLatestVersion()) - - registry := codectypes.NewInterfaceRegistry() - cdc := codec.NewProtoCodec(registry) - - paramsSubspace := typesparams.NewSubspace(cdc, - types.Amino, - storeKey, - memStoreKey, - "ZetaObsParams", - ) - - k := NewKeeper( - codec.NewProtoCodec(registry), - storeKey, - memStoreKey, - paramsSubspace, - stakingkeeper.Keeper{}, - slashingkeeper.Keeper{}, - authoritykeeper.Keeper{}, - ) - - ctx := sdk.NewContext(stateStore, tmproto.Header{}, false, nil) - return k, ctx -} diff --git a/x/observer/keeper/keygen_test.go b/x/observer/keeper/keygen_test.go index f14fa3aeef..ef9a5e6991 100644 --- a/x/observer/keeper/keygen_test.go +++ b/x/observer/keeper/keygen_test.go @@ -1,4 +1,4 @@ -package keeper +package keeper_test import ( "testing" @@ -6,11 +6,13 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" "github.com/stretchr/testify/require" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" + "github.com/zeta-chain/zetacore/x/observer/keeper" "github.com/zeta-chain/zetacore/x/observer/types" ) // Keeper Tests -func createTestKeygen(keeper *Keeper, ctx sdk.Context) types.Keygen { +func createTestKeygen(keeper *keeper.Keeper, ctx sdk.Context) types.Keygen { item := types.Keygen{ BlockNumber: 10, } @@ -19,16 +21,16 @@ func createTestKeygen(keeper *Keeper, ctx sdk.Context) types.Keygen { } func TestKeygenGet(t *testing.T) { - keeper, ctx := SetupKeeper(t) - item := createTestKeygen(keeper, ctx) - rst, found := keeper.GetKeygen(ctx) + k, ctx, _, _ := keepertest.ObserverKeeper(t) + item := createTestKeygen(k, ctx) + rst, found := k.GetKeygen(ctx) require.True(t, found) require.Equal(t, item, rst) } func TestKeygenRemove(t *testing.T) { - keeper, ctx := SetupKeeper(t) - createTestKeygen(keeper, ctx) - keeper.RemoveKeygen(ctx) - _, found := keeper.GetKeygen(ctx) + k, ctx, _, _ := keepertest.ObserverKeeper(t) + createTestKeygen(k, ctx) + k.RemoveKeygen(ctx) + _, found := k.GetKeygen(ctx) require.False(t, found) } diff --git a/x/observer/keeper/msg_server_add_blame_vote_test.go b/x/observer/keeper/msg_server_add_blame_vote_test.go new file mode 100644 index 0000000000..28ada869b0 --- /dev/null +++ b/x/observer/keeper/msg_server_add_blame_vote_test.go @@ -0,0 +1,182 @@ +package keeper_test + +import ( + "math/rand" + "testing" + "time" + + sdk "github.com/cosmos/cosmos-sdk/types" + slashingtypes "github.com/cosmos/cosmos-sdk/x/slashing/types" + "github.com/stretchr/testify/require" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" + "github.com/zeta-chain/zetacore/testutil/sample" + "github.com/zeta-chain/zetacore/x/observer/keeper" + "github.com/zeta-chain/zetacore/x/observer/types" +) + +func TestMsgServer_AddBlameVote(t *testing.T) { + t.Run("should error if supported chain not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + srv := keeper.NewMsgServerImpl(*k) + + res, err := srv.AddBlameVote(ctx, &types.MsgAddBlameVote{ + ChainId: 1, + }) + require.Error(t, err) + require.Nil(t, res) + }) + + t.Run("should error if not tombstoned observer", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + srv := keeper.NewMsgServerImpl(*k) + + chainId := getValidEthChainIDWithIndex(t, 0) + setSupportedChain(ctx, *k, chainId) + + res, err := srv.AddBlameVote(ctx, &types.MsgAddBlameVote{ + ChainId: chainId, + }) + require.Error(t, err) + require.Nil(t, res) + }) + + t.Run("should return response and set blame if finalizing vote", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + srv := keeper.NewMsgServerImpl(*k) + + chainId := getValidEthChainIDWithIndex(t, 0) + setSupportedChain(ctx, *k, chainId) + + r := rand.New(rand.NewSource(9)) + // Set validator in the store + validator := sample.Validator(t, r) + k.GetStakingKeeper().SetValidator(ctx, validator) + consAddress, err := validator.GetConsAddr() + require.NoError(t, err) + k.GetSlashingKeeper().SetValidatorSigningInfo(ctx, consAddress, slashingtypes.ValidatorSigningInfo{ + Address: consAddress.String(), + StartHeight: 0, + JailedUntil: ctx.BlockHeader().Time.Add(1000000 * time.Second), + Tombstoned: false, + MissedBlocksCounter: 1, + }) + + accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress) + require.NoError(t, err) + + k.SetObserverSet(ctx, types.ObserverSet{ + ObserverList: []string{accAddressOfValidator.String()}, + }) + + blameInfo := sample.BlameRecord(t, "index") + res, err := srv.AddBlameVote(ctx, &types.MsgAddBlameVote{ + Creator: accAddressOfValidator.String(), + ChainId: chainId, + BlameInfo: blameInfo, + }) + require.NoError(t, err) + require.Equal(t, &types.MsgAddBlameVoteResponse{}, res) + + blame, found := k.GetBlame(ctx, blameInfo.Index) + require.True(t, found) + require.Equal(t, blameInfo, blame) + }) + + t.Run("should error if add vote fails", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + srv := keeper.NewMsgServerImpl(*k) + + chainId := getValidEthChainIDWithIndex(t, 1) + setSupportedChain(ctx, *k, chainId) + + r := rand.New(rand.NewSource(9)) + // Set validator in the store + validator := sample.Validator(t, r) + k.GetStakingKeeper().SetValidator(ctx, validator) + consAddress, err := validator.GetConsAddr() + require.NoError(t, err) + k.GetSlashingKeeper().SetValidatorSigningInfo(ctx, consAddress, slashingtypes.ValidatorSigningInfo{ + Address: consAddress.String(), + StartHeight: 0, + JailedUntil: ctx.BlockHeader().Time.Add(1000000 * time.Second), + Tombstoned: false, + MissedBlocksCounter: 1, + }) + + accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress) + require.NoError(t, err) + + k.SetObserverSet(ctx, types.ObserverSet{ + ObserverList: []string{accAddressOfValidator.String(), "Observer2"}, + }) + blameInfo := sample.BlameRecord(t, "index") + vote := &types.MsgAddBlameVote{ + Creator: accAddressOfValidator.String(), + ChainId: chainId, + BlameInfo: blameInfo, + } + ballot := types.Ballot{ + Index: vote.Digest(), + BallotIdentifier: vote.Digest(), + VoterList: []string{accAddressOfValidator.String()}, + Votes: []types.VoteType{types.VoteType_SuccessObservation}, + BallotStatus: types.BallotStatus_BallotInProgress, + BallotThreshold: sdk.NewDec(2), + } + k.SetBallot(ctx, &ballot) + + _, err = srv.AddBlameVote(ctx, vote) + require.Error(t, err) + }) + + t.Run("should return response and not set blame if not finalizing vote", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + srv := keeper.NewMsgServerImpl(*k) + + chainId := getValidEthChainIDWithIndex(t, 1) + setSupportedChain(ctx, *k, chainId) + + r := rand.New(rand.NewSource(9)) + // Set validator in the store + validator := sample.Validator(t, r) + k.GetStakingKeeper().SetValidator(ctx, validator) + consAddress, err := validator.GetConsAddr() + require.NoError(t, err) + k.GetSlashingKeeper().SetValidatorSigningInfo(ctx, consAddress, slashingtypes.ValidatorSigningInfo{ + Address: consAddress.String(), + StartHeight: 0, + JailedUntil: ctx.BlockHeader().Time.Add(1000000 * time.Second), + Tombstoned: false, + MissedBlocksCounter: 1, + }) + + accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress) + require.NoError(t, err) + + k.SetObserverSet(ctx, types.ObserverSet{ + ObserverList: []string{accAddressOfValidator.String(), "Observer2"}, + }) + blameInfo := sample.BlameRecord(t, "index") + vote := &types.MsgAddBlameVote{ + Creator: accAddressOfValidator.String(), + ChainId: chainId, + BlameInfo: blameInfo, + } + ballot := types.Ballot{ + Index: vote.Digest(), + BallotIdentifier: vote.Digest(), + VoterList: []string{accAddressOfValidator.String()}, + Votes: []types.VoteType{types.VoteType_NotYetVoted}, + BallotStatus: types.BallotStatus_BallotInProgress, + BallotThreshold: sdk.NewDec(2), + } + k.SetBallot(ctx, &ballot) + + res, err := srv.AddBlameVote(ctx, vote) + require.NoError(t, err) + require.Equal(t, &types.MsgAddBlameVoteResponse{}, res) + + _, found := k.GetBlame(ctx, blameInfo.Index) + require.False(t, found) + }) +} diff --git a/x/observer/keeper/msg_server_add_observer_test.go b/x/observer/keeper/msg_server_add_observer_test.go new file mode 100644 index 0000000000..0cbb9df514 --- /dev/null +++ b/x/observer/keeper/msg_server_add_observer_test.go @@ -0,0 +1,117 @@ +package keeper_test + +import ( + "math" + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + "github.com/tendermint/tendermint/crypto" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" + "github.com/zeta-chain/zetacore/testutil/sample" + authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" + "github.com/zeta-chain/zetacore/x/observer/keeper" + "github.com/zeta-chain/zetacore/x/observer/types" +) + +func TestMsgServer_AddObserver(t *testing.T) { + t.Run("should error if not authorized", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeperWithMocks(t, keepertest.ObserverMockOptions{ + UseAuthorityMock: true, + }) + authorityMock := keepertest.GetObserverAuthorityMock(t, k) + admin := sample.AccAddress() + keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, false) + wctx := sdk.WrapSDKContext(ctx) + + srv := keeper.NewMsgServerImpl(*k) + res, err := srv.AddObserver(wctx, &types.MsgAddObserver{ + Creator: admin, + }) + require.Error(t, err) + require.Equal(t, &types.MsgAddObserverResponse{}, res) + }) + + t.Run("should error if pub key not valid", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeperWithMocks(t, keepertest.ObserverMockOptions{ + UseAuthorityMock: true, + }) + authorityMock := keepertest.GetObserverAuthorityMock(t, k) + admin := sample.AccAddress() + keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) + wctx := sdk.WrapSDKContext(ctx) + + srv := keeper.NewMsgServerImpl(*k) + res, err := srv.AddObserver(wctx, &types.MsgAddObserver{ + Creator: admin, + ZetaclientGranteePubkey: "invalid", + }) + require.Error(t, err) + require.Equal(t, &types.MsgAddObserverResponse{}, res) + }) + + t.Run("should add if add node account only false", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeperWithMocks(t, keepertest.ObserverMockOptions{ + UseAuthorityMock: true, + }) + authorityMock := keepertest.GetObserverAuthorityMock(t, k) + admin := sample.AccAddress() + keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) + wctx := sdk.WrapSDKContext(ctx) + + _, found := k.GetLastObserverCount(ctx) + require.False(t, found) + srv := keeper.NewMsgServerImpl(*k) + observerAddress := sdk.AccAddress(crypto.AddressHash([]byte("ObserverAddress"))) + res, err := srv.AddObserver(wctx, &types.MsgAddObserver{ + Creator: admin, + ZetaclientGranteePubkey: sample.PubKeyString(), + AddNodeAccountOnly: false, + ObserverAddress: observerAddress.String(), + }) + require.NoError(t, err) + require.Equal(t, &types.MsgAddObserverResponse{}, res) + + loc, found := k.GetLastObserverCount(ctx) + require.True(t, found) + require.Equal(t, uint64(1), loc.Count) + }) + + t.Run("should add to node account if add node account only true", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeperWithMocks(t, keepertest.ObserverMockOptions{ + UseAuthorityMock: true, + }) + authorityMock := keepertest.GetObserverAuthorityMock(t, k) + admin := sample.AccAddress() + keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) + wctx := sdk.WrapSDKContext(ctx) + + _, found := k.GetLastObserverCount(ctx) + require.False(t, found) + srv := keeper.NewMsgServerImpl(*k) + observerAddress := sdk.AccAddress(crypto.AddressHash([]byte("ObserverAddress"))) + _, found = k.GetKeygen(ctx) + require.False(t, found) + _, found = k.GetNodeAccount(ctx, observerAddress.String()) + require.False(t, found) + + res, err := srv.AddObserver(wctx, &types.MsgAddObserver{ + Creator: admin, + ZetaclientGranteePubkey: sample.PubKeyString(), + AddNodeAccountOnly: true, + ObserverAddress: observerAddress.String(), + }) + require.NoError(t, err) + require.Equal(t, &types.MsgAddObserverResponse{}, res) + + _, found = k.GetLastObserverCount(ctx) + require.False(t, found) + + keygen, found := k.GetKeygen(ctx) + require.True(t, found) + require.Equal(t, types.Keygen{BlockNumber: math.MaxInt64}, keygen) + + _, found = k.GetNodeAccount(ctx, observerAddress.String()) + require.True(t, found) + }) +} diff --git a/x/observer/keeper/msg_server_update_keygen_test.go b/x/observer/keeper/msg_server_update_keygen_test.go new file mode 100644 index 0000000000..8b3201c5ed --- /dev/null +++ b/x/observer/keeper/msg_server_update_keygen_test.go @@ -0,0 +1,105 @@ +package keeper_test + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" + "github.com/zeta-chain/zetacore/testutil/sample" + authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" + "github.com/zeta-chain/zetacore/x/observer/keeper" + "github.com/zeta-chain/zetacore/x/observer/types" +) + +func TestMsgServer_UpdateKeygen(t *testing.T) { + t.Run("should error if not authorized", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeperWithMocks(t, keepertest.ObserverMockOptions{ + UseAuthorityMock: true, + }) + authorityMock := keepertest.GetObserverAuthorityMock(t, k) + admin := sample.AccAddress() + keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, false) + wctx := sdk.WrapSDKContext(ctx) + + srv := keeper.NewMsgServerImpl(*k) + res, err := srv.UpdateKeygen(wctx, &types.MsgUpdateKeygen{ + Creator: admin, + }) + require.Error(t, err) + require.Equal(t, &types.MsgUpdateKeygenResponse{}, res) + }) + + t.Run("should error if keygen not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeperWithMocks(t, keepertest.ObserverMockOptions{ + UseAuthorityMock: true, + }) + authorityMock := keepertest.GetObserverAuthorityMock(t, k) + admin := sample.AccAddress() + keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) + wctx := sdk.WrapSDKContext(ctx) + + srv := keeper.NewMsgServerImpl(*k) + res, err := srv.UpdateKeygen(wctx, &types.MsgUpdateKeygen{ + Creator: admin, + }) + require.Error(t, err) + require.Nil(t, res) + }) + + t.Run("should error if msg block too low", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeperWithMocks(t, keepertest.ObserverMockOptions{ + UseAuthorityMock: true, + }) + authorityMock := keepertest.GetObserverAuthorityMock(t, k) + admin := sample.AccAddress() + keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) + wctx := sdk.WrapSDKContext(ctx) + item := types.Keygen{ + BlockNumber: 10, + } + k.SetKeygen(ctx, item) + srv := keeper.NewMsgServerImpl(*k) + res, err := srv.UpdateKeygen(wctx, &types.MsgUpdateKeygen{ + Creator: admin, + Block: 2, + }) + require.Error(t, err) + require.Nil(t, res) + }) + + t.Run("should update", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeperWithMocks(t, keepertest.ObserverMockOptions{ + UseAuthorityMock: true, + }) + authorityMock := keepertest.GetObserverAuthorityMock(t, k) + admin := sample.AccAddress() + keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) + wctx := sdk.WrapSDKContext(ctx) + item := types.Keygen{ + BlockNumber: 10, + } + k.SetKeygen(ctx, item) + srv := keeper.NewMsgServerImpl(*k) + + granteePubKey := sample.PubKeySet() + k.SetNodeAccount(ctx, types.NodeAccount{ + Operator: "operator", + GranteePubkey: granteePubKey, + }) + + res, err := srv.UpdateKeygen(wctx, &types.MsgUpdateKeygen{ + Creator: admin, + Block: ctx.BlockHeight() + 30, + }) + require.NoError(t, err) + require.Equal(t, &types.MsgUpdateKeygenResponse{}, res) + + keygen, found := k.GetKeygen(ctx) + require.True(t, found) + require.Equal(t, 1, len(keygen.GranteePubkeys)) + require.Equal(t, granteePubKey.Secp256k1.String(), keygen.GranteePubkeys[0]) + require.Equal(t, ctx.BlockHeight()+30, keygen.BlockNumber) + require.Equal(t, types.KeygenStatus_PendingKeygen, keygen.Status) + }) +} diff --git a/x/observer/keeper/node_account_test.go b/x/observer/keeper/node_account_test.go index 26a5c47d1c..c20501f8b2 100644 --- a/x/observer/keeper/node_account_test.go +++ b/x/observer/keeper/node_account_test.go @@ -1,4 +1,4 @@ -package keeper +package keeper_test import ( "fmt" @@ -6,12 +6,14 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" "github.com/stretchr/testify/require" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" + "github.com/zeta-chain/zetacore/x/observer/keeper" "github.com/zeta-chain/zetacore/x/observer/types" ) // Keeper Tests -func createNNodeAccount(keeper *Keeper, ctx sdk.Context, n int) []types.NodeAccount { +func createNNodeAccount(keeper *keeper.Keeper, ctx sdk.Context, n int) []types.NodeAccount { items := make([]types.NodeAccount, n) for i := range items { items[i].Operator = fmt.Sprintf("%d", i) @@ -21,26 +23,28 @@ func createNNodeAccount(keeper *Keeper, ctx sdk.Context, n int) []types.NodeAcco } func TestNodeAccountGet(t *testing.T) { - keeper, ctx := SetupKeeper(t) - items := createNNodeAccount(keeper, ctx, 10) + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + items := createNNodeAccount(k, ctx, 10) for _, item := range items { - rst, found := keeper.GetNodeAccount(ctx, item.Operator) + rst, found := k.GetNodeAccount(ctx, item.Operator) require.True(t, found) require.Equal(t, item, rst) } } func TestNodeAccountRemove(t *testing.T) { - keeper, ctx := SetupKeeper(t) - items := createNNodeAccount(keeper, ctx, 10) + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + items := createNNodeAccount(k, ctx, 10) for _, item := range items { - keeper.RemoveNodeAccount(ctx, item.Operator) - _, found := keeper.GetNodeAccount(ctx, item.Operator) + k.RemoveNodeAccount(ctx, item.Operator) + _, found := k.GetNodeAccount(ctx, item.Operator) require.False(t, found) } } func TestNodeAccountGetAll(t *testing.T) { - keeper, ctx := SetupKeeper(t) - items := createNNodeAccount(keeper, ctx, 10) - require.Equal(t, items, keeper.GetAllNodeAccount(ctx)) + k, ctx, _, _ := keepertest.ObserverKeeper(t) + items := createNNodeAccount(k, ctx, 10) + require.Equal(t, items, k.GetAllNodeAccount(ctx)) } diff --git a/x/observer/keeper/nonce_to_cctx_test.go b/x/observer/keeper/nonce_to_cctx_test.go index 760e4f7111..a16a21f5ec 100644 --- a/x/observer/keeper/nonce_to_cctx_test.go +++ b/x/observer/keeper/nonce_to_cctx_test.go @@ -20,6 +20,14 @@ func TestKeeper_GetNonceToCctx(t *testing.T) { require.True(t, found) require.Equal(t, n, rst) } + + for _, n := range nonceToCctxList { + k.RemoveNonceToCctx(ctx, n) + } + for _, n := range nonceToCctxList { + _, found := k.GetNonceToCctx(ctx, n.Tss, n.ChainId, n.Nonce) + require.False(t, found) + } }) t.Run("Get nonce to cctx not found", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) diff --git a/x/observer/keeper/observer_set_test.go b/x/observer/keeper/observer_set_test.go index cd4a4239a7..d17e496fae 100644 --- a/x/observer/keeper/observer_set_test.go +++ b/x/observer/keeper/observer_set_test.go @@ -12,6 +12,8 @@ func TestKeeper_GetObserverSet(t *testing.T) { t.Run("get observer set", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) os := sample.ObserverSet(10) + _, found := k.GetObserverSet(ctx) + require.False(t, found) k.SetObserverSet(ctx, os) tfm, found := k.GetObserverSet(ctx) require.True(t, found) @@ -23,6 +25,7 @@ func TestKeeper_IsAddressPartOfObserverSet(t *testing.T) { t.Run("address is part of observer set", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) os := sample.ObserverSet(10) + require.False(t, k.IsAddressPartOfObserverSet(ctx, os.ObserverList[0])) k.SetObserverSet(ctx, os) require.True(t, k.IsAddressPartOfObserverSet(ctx, os.ObserverList[0])) require.False(t, k.IsAddressPartOfObserverSet(ctx, sample.AccAddress())) @@ -42,12 +45,30 @@ func TestKeeper_AddObserverToSet(t *testing.T) { require.True(t, found) require.Len(t, osNew.ObserverList, len(os.ObserverList)+1) }) + + t.Run("add observer to set if set doesn't exist", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + newObserver := sample.AccAddress() + k.AddObserverToSet(ctx, newObserver) + require.True(t, k.IsAddressPartOfObserverSet(ctx, newObserver)) + osNew, found := k.GetObserverSet(ctx) + require.True(t, found) + require.Len(t, osNew.ObserverList, 1) + + // add same address again, len doesn't change + k.AddObserverToSet(ctx, newObserver) + require.True(t, k.IsAddressPartOfObserverSet(ctx, newObserver)) + osNew, found = k.GetObserverSet(ctx) + require.True(t, found) + require.Len(t, osNew.ObserverList, 1) + }) } func TestKeeper_RemoveObserverFromSet(t *testing.T) { t.Run("remove observer from set", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) os := sample.ObserverSet(10) + k.RemoveObserverFromSet(ctx, os.ObserverList[0]) k.SetObserverSet(ctx, os) k.RemoveObserverFromSet(ctx, os.ObserverList[0]) require.False(t, k.IsAddressPartOfObserverSet(ctx, os.ObserverList[0])) @@ -64,13 +85,25 @@ func TestKeeper_UpdateObserverAddress(t *testing.T) { newObserverAddress := sample.AccAddress() observerSet := sample.ObserverSet(10) observerSet.ObserverList = append(observerSet.ObserverList, oldObserverAddress) - k.SetObserverSet(ctx, observerSet) err := k.UpdateObserverAddress(ctx, oldObserverAddress, newObserverAddress) + require.Error(t, err) + k.SetObserverSet(ctx, observerSet) + err = k.UpdateObserverAddress(ctx, oldObserverAddress, newObserverAddress) require.NoError(t, err) observerSet, found := k.GetObserverSet(ctx) require.True(t, found) require.Equal(t, newObserverAddress, observerSet.ObserverList[len(observerSet.ObserverList)-1]) }) + t.Run("should error if observer address not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + oldObserverAddress := sample.AccAddress() + newObserverAddress := sample.AccAddress() + observerSet := sample.ObserverSet(10) + observerSet.ObserverList = append(observerSet.ObserverList, oldObserverAddress) + k.SetObserverSet(ctx, observerSet) + err := k.UpdateObserverAddress(ctx, sample.AccAddress(), newObserverAddress) + require.Error(t, err) + }) t.Run("update observer address long observerList", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) oldObserverAddress := sample.AccAddress() diff --git a/x/observer/keeper/params_test.go b/x/observer/keeper/params_test.go index c487ea0f73..309d66df9b 100644 --- a/x/observer/keeper/params_test.go +++ b/x/observer/keeper/params_test.go @@ -1,4 +1,4 @@ -package keeper +package keeper_test import ( "fmt" @@ -9,11 +9,12 @@ import ( "github.com/tendermint/tendermint/crypto" "github.com/stretchr/testify/require" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" "github.com/zeta-chain/zetacore/x/observer/types" ) func TestGetParams(t *testing.T) { - k, ctx := SetupKeeper(t) + k, ctx, _, _ := keepertest.ObserverKeeper(t) params := types.DefaultParams() k.SetParams(ctx, params) @@ -22,7 +23,6 @@ func TestGetParams(t *testing.T) { } func TestGenerateAddress(t *testing.T) { - types.SetConfig(false) addr := sdk.AccAddress(crypto.AddressHash([]byte("Output1" + strconv.Itoa(1)))) addrString := addr.String() fmt.Println(addrString) diff --git a/x/observer/keeper/pending_nonces_test.go b/x/observer/keeper/pending_nonces_test.go index 6185e80bf1..37ff74fb22 100644 --- a/x/observer/keeper/pending_nonces_test.go +++ b/x/observer/keeper/pending_nonces_test.go @@ -62,5 +62,74 @@ func TestKeeper_PendingNoncesAll(t *testing.T) { return rst[i].ChainId < rst[j].ChainId }) require.Equal(t, nonces, rst) + + k.RemovePendingNonces(ctx, nonces[0]) + rst, err = k.GetAllPendingNonces(ctx) + require.NoError(t, err) + sort.SliceStable(rst, func(i, j int) bool { + return rst[i].ChainId < rst[j].ChainId + }) + require.Equal(t, nonces[1:], rst) + }) +} + +func TestKeeper_SetTssAndUpdateNonce(t *testing.T) { + t.Run("should set tss and update nonces", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + _, found := k.GetTSS(ctx) + require.False(t, found) + pendingNonces, err := k.GetAllPendingNonces(ctx) + require.NoError(t, err) + require.Empty(t, pendingNonces) + chainNonces := k.GetAllChainNonces(ctx) + require.NoError(t, err) + require.Empty(t, chainNonces) + + tss := sample.Tss() + // core params list but chain not in list + setSupportedChain(ctx, *k, getValidEthChainIDWithIndex(t, 0)) + k.SetTssAndUpdateNonce(ctx, tss) + + _, found = k.GetTSS(ctx) + require.True(t, found) + pendingNonces, err = k.GetAllPendingNonces(ctx) + require.NoError(t, err) + require.Equal(t, 1, len(pendingNonces)) + chainNonces = k.GetAllChainNonces(ctx) + require.NoError(t, err) + require.Equal(t, 1, len(chainNonces)) + }) +} + +func TestKeeper_RemoveFromPendingNonces(t *testing.T) { + t.Run("should remove from pending nonces", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + nonces := sample.PendingNoncesList(t, "sample", 10) + tss := sample.Tss() + // make nonces and pubkey deterministic for test + for i := range nonces { + nonces[i].NonceLow = int64(i) + nonces[i].NonceHigh = nonces[i].NonceLow + 3 + nonces[i].Tss = tss.TssPubkey + } + sort.SliceStable(nonces, func(i, j int) bool { + return nonces[i].ChainId < nonces[j].ChainId + }) + for _, nonce := range nonces { + k.SetPendingNonces(ctx, nonce) + } + + k.RemoveFromPendingNonces(ctx, tss.TssPubkey, 1, 1) + pendingNonces, err := k.GetAllPendingNonces(ctx) + require.NoError(t, err) + nonceUpdated := false + for _, pn := range pendingNonces { + if pn.ChainId == 1 { + require.Equal(t, int64(2), pn.NonceLow) + nonceUpdated = true + } + } + require.True(t, nonceUpdated) }) } diff --git a/x/observer/keeper/tss_funds_migrator_test.go b/x/observer/keeper/tss_funds_migrator_test.go index 1d32f19f92..620fa2ca06 100644 --- a/x/observer/keeper/tss_funds_migrator_test.go +++ b/x/observer/keeper/tss_funds_migrator_test.go @@ -12,10 +12,16 @@ func TestKeeper_GetTssFundMigrator(t *testing.T) { t.Run("Successfully set funds migrator for chain", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) chain := sample.TssFundsMigrator(1) + _, found := k.GetFundMigrator(ctx, chain.ChainId) + require.False(t, found) k.SetFundMigrator(ctx, chain) tfm, found := k.GetFundMigrator(ctx, chain.ChainId) require.True(t, found) require.Equal(t, chain, tfm) + + k.RemoveAllExistingMigrators(ctx) + _, found = k.GetFundMigrator(ctx, chain.ChainId) + require.False(t, found) }) t.Run("Verify only one migrator can be created for a chain", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) @@ -28,5 +34,4 @@ func TestKeeper_GetTssFundMigrator(t *testing.T) { require.Equal(t, 1, len(migratorList)) require.Equal(t, tfm2, migratorList[0]) }) - } diff --git a/x/observer/keeper/tss_test.go b/x/observer/keeper/tss_test.go index 6408f4f49a..a1f1ca6ddc 100644 --- a/x/observer/keeper/tss_test.go +++ b/x/observer/keeper/tss_test.go @@ -9,14 +9,10 @@ import ( "github.com/stretchr/testify/require" keepertest "github.com/zeta-chain/zetacore/testutil/keeper" "github.com/zeta-chain/zetacore/testutil/sample" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - sdk "github.com/cosmos/cosmos-sdk/types" "github.com/zeta-chain/zetacore/x/observer/types" ) -func TestTSSGet(t *testing.T) { +func TestKeeper_GetTSS(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) tss := sample.Tss() k.SetTSS(ctx, tss) @@ -25,7 +21,7 @@ func TestTSSGet(t *testing.T) { require.Equal(t, tss, tssQueried) } -func TestTSSRemove(t *testing.T) { +func TestKeeper_RemoveTSS(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) tss := sample.Tss() k.SetTSS(ctx, tss) @@ -34,83 +30,34 @@ func TestTSSRemove(t *testing.T) { require.False(t, found) } -func TestTSSQuerySingle(t *testing.T) { +func TestKeeper_CheckIfTssPubkeyHasBeenGenerated(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) - wctx := sdk.WrapSDKContext(ctx) - //msgs := createTSS(keeper, ctx, 1) tss := sample.Tss() - k.SetTSS(ctx, tss) - for _, tc := range []struct { - desc string - request *types.QueryGetTSSRequest - response *types.QueryGetTSSResponse - err error - }{ - { - desc: "First", - request: &types.QueryGetTSSRequest{}, - response: &types.QueryGetTSSResponse{TSS: tss}, - }, - { - desc: "InvalidRequest", - err: status.Error(codes.InvalidArgument, "invalid request"), - }, - } { - tc := tc - t.Run(tc.desc, func(t *testing.T) { - response, err := k.TSS(wctx, tc.request) - if tc.err != nil { - require.ErrorIs(t, err, tc.err) - } else { - require.Equal(t, tc.response, response) - } - }) - } + + generated, found := k.CheckIfTssPubkeyHasBeenGenerated(ctx, tss.TssPubkey) + require.False(t, found) + require.Equal(t, types.TSS{}, generated) + + k.AppendTss(ctx, tss) + + generated, found = k.CheckIfTssPubkeyHasBeenGenerated(ctx, tss.TssPubkey) + require.True(t, found) + require.Equal(t, tss, generated) } -func TestTSSQueryHistory(t *testing.T) { +func TestKeeper_GetHistoricalTssByFinalizedHeight(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) - wctx := sdk.WrapSDKContext(ctx) - for _, tc := range []struct { - desc string - tssCount int - foundPrevious bool - err error - }{ - { - desc: "1 Tss addresses", - tssCount: 1, - foundPrevious: false, - err: nil, - }, - { - desc: "10 Tss addresses", - tssCount: 10, - foundPrevious: true, - err: nil, - }, - } { - tc := tc - t.Run(tc.desc, func(t *testing.T) { - tssList := sample.TssList(tc.tssCount) - for _, tss := range tssList { - k.SetTSS(ctx, tss) - k.SetTSSHistory(ctx, tss) - } - request := &types.QueryTssHistoryRequest{} - response, err := k.TssHistory(wctx, request) - if tc.err != nil { - require.ErrorIs(t, err, tc.err) - } else { - require.Equal(t, len(tssList), len(response.TssList)) - prevTss, found := k.GetPreviousTSS(ctx) - require.Equal(t, tc.foundPrevious, found) - if found { - require.Equal(t, tssList[len(tssList)-2], prevTss) - } - } - }) + tssList := sample.TssList(100) + r := rand.Intn((len(tssList)-1)-0) + 0 + _, found := k.GetHistoricalTssByFinalizedHeight(ctx, tssList[r].FinalizedZetaHeight) + require.False(t, found) + + for _, tss := range tssList { + k.SetTSSHistory(ctx, tss) } + tss, found := k.GetHistoricalTssByFinalizedHeight(ctx, tssList[r].FinalizedZetaHeight) + require.True(t, found) + require.Equal(t, tssList[r], tss) } func TestKeeper_TssHistory(t *testing.T) { @@ -165,15 +112,4 @@ func TestKeeper_TssHistory(t *testing.T) { }) require.Equal(t, tssList, rst) }) - t.Run("Get historical TSS", func(t *testing.T) { - k, ctx, _, _ := keepertest.ObserverKeeper(t) - tssList := sample.TssList(100) - for _, tss := range tssList { - k.SetTSSHistory(ctx, tss) - } - r := rand.Intn((len(tssList)-1)-0) + 0 - tss, found := k.GetHistoricalTssByFinalizedHeight(ctx, tssList[r].FinalizedZetaHeight) - require.True(t, found) - require.Equal(t, tssList[r], tss) - }) } diff --git a/x/observer/keeper/utils.go b/x/observer/keeper/utils.go index 322d45387f..566b559630 100644 --- a/x/observer/keeper/utils.go +++ b/x/observer/keeper/utils.go @@ -88,7 +88,7 @@ func (k Keeper) IsValidator(ctx sdk.Context, creator string) error { return types.ErrNotValidator } - if validator.Jailed == true || validator.IsBonded() == false { + if validator.Jailed || !validator.IsBonded() { return types.ErrValidatorStatus } return nil @@ -135,7 +135,6 @@ func (k Keeper) CheckObserverSelfDelegation(ctx sdk.Context, accAddress string) return err } tokens := validator.TokensFromShares(delegation.Shares) - if tokens.LT(minDelegation) { k.RemoveObserverFromSet(ctx, accAddress) } diff --git a/x/observer/keeper/utils_test.go b/x/observer/keeper/utils_test.go index 9d4b8c54e2..4aedd8ad13 100644 --- a/x/observer/keeper/utils_test.go +++ b/x/observer/keeper/utils_test.go @@ -7,6 +7,7 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" slashingtypes "github.com/cosmos/cosmos-sdk/x/slashing/types" + stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" "github.com/stretchr/testify/require" "github.com/zeta-chain/zetacore/pkg/chains" keepertest "github.com/zeta-chain/zetacore/testutil/keeper" @@ -61,13 +62,14 @@ func TestKeeper_IsAuthorized(t *testing.T) { }) accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress) + require.NoError(t, err) k.SetObserverSet(ctx, types.ObserverSet{ ObserverList: []string{accAddressOfValidator.String()}, }) require.True(t, k.IsNonTombstonedObserver(ctx, accAddressOfValidator.String())) - }) + t.Run("not authorized for tombstoned observer", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) @@ -87,13 +89,15 @@ func TestKeeper_IsAuthorized(t *testing.T) { }) accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress) + require.NoError(t, err) + k.SetObserverSet(ctx, types.ObserverSet{ ObserverList: []string{accAddressOfValidator.String()}, }) require.False(t, k.IsNonTombstonedObserver(ctx, accAddressOfValidator.String())) - }) + t.Run("not authorized for non-validator observer", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) @@ -113,11 +117,205 @@ func TestKeeper_IsAuthorized(t *testing.T) { }) accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress) + require.NoError(t, err) + k.SetObserverSet(ctx, types.ObserverSet{ ObserverList: []string{accAddressOfValidator.String()}, }) require.False(t, k.IsNonTombstonedObserver(ctx, accAddressOfValidator.String())) + }) +} + +func TestKeeper_CheckObserverSelfDelegation(t *testing.T) { + t.Run("should error if accAddress invalid", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + err := k.CheckObserverSelfDelegation(ctx, "invalid") + require.Error(t, err) + }) + + t.Run("should error if validator not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + accAddress := sample.AccAddress() + err := k.CheckObserverSelfDelegation(ctx, accAddress) + require.Error(t, err) + }) + + t.Run("should error if delegation not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + r := rand.New(rand.NewSource(9)) + validator := sample.Validator(t, r) + k.GetStakingKeeper().SetValidator(ctx, validator) + accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress) + require.NoError(t, err) + + err = k.CheckObserverSelfDelegation(ctx, accAddressOfValidator.String()) + require.Error(t, err) + }) + + t.Run("should remove from observer list if tokens less than min del", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + r := rand.New(rand.NewSource(9)) + validator := sample.Validator(t, r) + validator.DelegatorShares = sdk.NewDec(100) + k.GetStakingKeeper().SetValidator(ctx, validator) + accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress) + require.NoError(t, err) + + k.GetStakingKeeper().SetDelegation(ctx, stakingtypes.Delegation{ + DelegatorAddress: accAddressOfValidator.String(), + ValidatorAddress: validator.GetOperator().String(), + Shares: sdk.NewDec(10), + }) + + k.SetObserverSet(ctx, types.ObserverSet{ + ObserverList: []string{accAddressOfValidator.String()}, + }) + err = k.CheckObserverSelfDelegation(ctx, accAddressOfValidator.String()) + require.NoError(t, err) + + os, found := k.GetObserverSet(ctx) + require.True(t, found) + require.Empty(t, os.ObserverList) + }) + + t.Run("should not remove from observer list if tokens gte than min del", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + r := rand.New(rand.NewSource(9)) + validator := sample.Validator(t, r) + + validator.DelegatorShares = sdk.NewDec(1) + validator.Tokens = sdk.NewInt(1) + k.GetStakingKeeper().SetValidator(ctx, validator) + accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress) + require.NoError(t, err) + + minDelegation, err := types.GetMinObserverDelegationDec() + require.NoError(t, err) + k.GetStakingKeeper().SetDelegation(ctx, stakingtypes.Delegation{ + DelegatorAddress: accAddressOfValidator.String(), + ValidatorAddress: validator.GetOperator().String(), + Shares: minDelegation, + }) + + k.SetObserverSet(ctx, types.ObserverSet{ + ObserverList: []string{accAddressOfValidator.String()}, + }) + err = k.CheckObserverSelfDelegation(ctx, accAddressOfValidator.String()) + require.NoError(t, err) + + os, found := k.GetObserverSet(ctx) + require.True(t, found) + require.Equal(t, 1, len(os.ObserverList)) + require.Equal(t, accAddressOfValidator.String(), os.ObserverList[0]) + }) +} + +func TestKeeper_IsOpeartorTombstoned(t *testing.T) { + t.Run("should err if invalid addr", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + res, err := k.IsOperatorTombstoned(ctx, "invalid") + require.Error(t, err) + require.False(t, res) + }) + + t.Run("should error if validator not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + accAddress := sample.AccAddress() + res, err := k.IsOperatorTombstoned(ctx, accAddress) + require.Error(t, err) + require.False(t, res) + }) + + t.Run("should not error if validator found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + r := rand.New(rand.NewSource(9)) + validator := sample.Validator(t, r) + k.GetStakingKeeper().SetValidator(ctx, validator) + accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress) + require.NoError(t, err) + + res, err := k.IsOperatorTombstoned(ctx, accAddressOfValidator.String()) + require.NoError(t, err) + require.False(t, res) + }) +} + +func TestKeeper_IsValidator(t *testing.T) { + t.Run("should err if invalid addr", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + err := k.IsValidator(ctx, "invalid") + require.Error(t, err) + }) + + t.Run("should error if validator not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + accAddress := sample.AccAddress() + err := k.IsValidator(ctx, accAddress) + require.Error(t, err) + }) + + t.Run("should err if validator not bonded", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + r := rand.New(rand.NewSource(9)) + validator := sample.Validator(t, r) + k.GetStakingKeeper().SetValidator(ctx, validator) + accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress) + require.NoError(t, err) + + err = k.IsValidator(ctx, accAddressOfValidator.String()) + require.Error(t, err) + }) + + t.Run("should err if validator jailed", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + r := rand.New(rand.NewSource(9)) + validator := sample.Validator(t, r) + validator.Status = stakingtypes.Bonded + validator.Jailed = true + k.GetStakingKeeper().SetValidator(ctx, validator) + accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress) + require.NoError(t, err) + + err = k.IsValidator(ctx, accAddressOfValidator.String()) + require.Error(t, err) + }) + + t.Run("should not err if validator not jailed and bonded", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + r := rand.New(rand.NewSource(9)) + validator := sample.Validator(t, r) + validator.Status = stakingtypes.Bonded + validator.Jailed = false + k.GetStakingKeeper().SetValidator(ctx, validator) + accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress) + require.NoError(t, err) + + err = k.IsValidator(ctx, accAddressOfValidator.String()) + require.NoError(t, err) + }) +} + +func TestKeeper_FindBallot(t *testing.T) { + t.Run("should err if chain params not found", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + _, _, err := k.FindBallot(ctx, "index", &chains.Chain{ + ChainId: 987, + }, types.ObservationType_InBoundTx) + require.Error(t, err) }) } diff --git a/x/observer/keeper/vote_inbound_test.go b/x/observer/keeper/vote_inbound_test.go index 41a3999b75..5b1c09e81b 100644 --- a/x/observer/keeper/vote_inbound_test.go +++ b/x/observer/keeper/vote_inbound_test.go @@ -258,6 +258,57 @@ func TestKeeper_VoteOnInboundBallot(t *testing.T) { require.True(t, isNew) }) + t.Run("fail if can not add vote", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeperWithMocks(t, keepertest.ObserverMocksAll) + + observer := sample.AccAddress() + stakingMock := keepertest.GetObserverStakingMock(t, k) + slashingMock := keepertest.GetObserverSlashingMock(t, k) + + k.SetCrosschainFlags(ctx, types.CrosschainFlags{ + IsInboundEnabled: true, + }) + k.SetChainParamsList(ctx, types.ChainParamsList{ + ChainParams: []*types.ChainParams{ + { + ChainId: getValidEthChainIDWithIndex(t, 0), + IsSupported: true, + }, + { + ChainId: getValidEthChainIDWithIndex(t, 1), + IsSupported: true, + }, + }, + }) + k.SetObserverSet(ctx, types.ObserverSet{ + ObserverList: []string{observer}, + }) + stakingMock.MockGetValidator(sample.Validator(t, sample.Rand())) + slashingMock.MockIsTombstoned(false) + ballot := types.Ballot{ + Index: "index", + BallotIdentifier: "index", + VoterList: []string{observer}, + // already voted + Votes: []types.VoteType{types.VoteType_SuccessObservation}, + BallotStatus: types.BallotStatus_BallotInProgress, + BallotThreshold: sdk.NewDec(2), + } + k.SetBallot(ctx, &ballot) + isFinalized, isNew, err := k.VoteOnInboundBallot( + ctx, + getValidEthChainIDWithIndex(t, 0), + getValidEthChainIDWithIndex(t, 1), + coin.CoinType_ERC20, + observer, + "index", + "inTxHash", + ) + require.Error(t, err) + require.False(t, isFinalized) + require.False(t, isNew) + }) + t.Run("can add vote and create ballot without finalizing ballot", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeperWithMocks(t, keepertest.ObserverMocksAll) diff --git a/x/observer/keeper/vote_outbound_test.go b/x/observer/keeper/vote_outbound_test.go index fa34a11ca6..c92fec5927 100644 --- a/x/observer/keeper/vote_outbound_test.go +++ b/x/observer/keeper/vote_outbound_test.go @@ -133,6 +133,48 @@ func TestKeeper_VoteOnOutboundBallot(t *testing.T) { require.Equal(t, expectedBallot, ballot) }) + t.Run("fail if can not add vote", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeperWithMocks(t, keepertest.ObserverMocksAll) + + observer := sample.AccAddress() + stakingMock := keepertest.GetObserverStakingMock(t, k) + slashingMock := keepertest.GetObserverSlashingMock(t, k) + + k.SetChainParamsList(ctx, types.ChainParamsList{ + ChainParams: []*types.ChainParams{ + { + ChainId: getValidEthChainIDWithIndex(t, 0), + IsSupported: true, + }, + }, + }) + k.SetObserverSet(ctx, types.ObserverSet{ + ObserverList: []string{observer}, + }) + stakingMock.MockGetValidator(sample.Validator(t, sample.Rand())) + slashingMock.MockIsTombstoned(false) + ballot := types.Ballot{ + Index: "index", + BallotIdentifier: "index", + VoterList: []string{observer}, + // already voted + Votes: []types.VoteType{types.VoteType_SuccessObservation}, + BallotStatus: types.BallotStatus_BallotInProgress, + BallotThreshold: sdk.NewDec(2), + } + k.SetBallot(ctx, &ballot) + isFinalized, isNew, ballot, _, err := k.VoteOnOutboundBallot( + ctx, + "index", + getValidEthChainIDWithIndex(t, 0), + chains.ReceiveStatus_Success, + observer, + ) + require.Error(t, err) + require.False(t, isFinalized) + require.False(t, isNew) + }) + t.Run("can add vote and create ballot without finalizing ballot", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeperWithMocks(t, keepertest.ObserverMocksAll) diff --git a/x/observer/types/ballot.go b/x/observer/types/ballot.go index 811a1b7688..71512c9f11 100644 --- a/x/observer/types/ballot.go +++ b/x/observer/types/ballot.go @@ -14,12 +14,18 @@ func (m Ballot) AddVote(address string, vote VoteType) (Ballot, error) { // `index` is the index of the `address` in the `VoterList` // `index` is used to set the vote in the `Votes` array index := m.GetVoterIndex(address) + if index == -1 { + return m, cosmoserrors.Wrap(ErrUnableToAddVote, fmt.Sprintf("Voter %s not in voter list", address)) + } m.Votes[index] = vote return m, nil } func (m Ballot) HasVoted(address string) bool { index := m.GetVoterIndex(address) + if index == -1 { + return false + } return m.Votes[index] != VoteType_NotYetVoted } diff --git a/x/observer/types/ballot.pb.go b/x/observer/types/ballot.pb.go index 7e8c40afeb..8e9313d1dd 100644 --- a/x/observer/types/ballot.pb.go +++ b/x/observer/types/ballot.pb.go @@ -81,6 +81,7 @@ func (BallotStatus) EnumDescriptor() ([]byte, []int) { return fileDescriptor_9eac86b249c97b5b, []int{1} } +// https://github.com/zeta-chain/node/issues/939 type Ballot struct { Index string `protobuf:"bytes,1,opt,name=index,proto3" json:"index,omitempty"` BallotIdentifier string `protobuf:"bytes,2,opt,name=ballot_identifier,json=ballotIdentifier,proto3" json:"ballot_identifier,omitempty"` diff --git a/x/observer/types/ballot_test.go b/x/observer/types/ballot_test.go index 2bd6441c48..322affd728 100644 --- a/x/observer/types/ballot_test.go +++ b/x/observer/types/ballot_test.go @@ -21,6 +21,7 @@ func TestBallot_AddVote(t *testing.T) { finalVotes []VoteType finalStatus BallotStatus isFinalized bool + wantErr bool }{ { name: "All success", @@ -188,6 +189,18 @@ func TestBallot_AddVote(t *testing.T) { finalStatus: BallotStatus_BallotInProgress, isFinalized: false, }, + { + name: "Voter not in voter list", + threshold: sdk.MustNewDecFromStr("0.66"), + voterList: []string{}, + votes: []votes{ + {"Observer5", VoteType_SuccessObservation}, + }, + wantErr: true, + finalVotes: []VoteType{}, + finalStatus: BallotStatus_BallotInProgress, + isFinalized: false, + }, } for _, test := range tt { test := test @@ -202,7 +215,11 @@ func TestBallot_AddVote(t *testing.T) { BallotStatus: BallotStatus_BallotInProgress, } for _, vote := range test.votes { - ballot, _ = ballot.AddVote(vote.address, vote.vote) + b, err := ballot.AddVote(vote.address, vote.vote) + if test.wantErr { + require.Error(t, err) + } + ballot = b } finalBallot, isFinalized := ballot.IsFinalizingVote() diff --git a/x/observer/types/chain_params_test.go b/x/observer/types/chain_params_test.go index 0bdc9d814d..2100fe2071 100644 --- a/x/observer/types/chain_params_test.go +++ b/x/observer/types/chain_params_test.go @@ -3,6 +3,7 @@ package types_test import ( "testing" + sdk "github.com/cosmos/cosmos-sdk/types" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/zeta-chain/zetacore/x/observer/types" @@ -196,4 +197,23 @@ func (s *UpdateChainParamsSuite) Validate(params *types.ChainParams) { copy.OutboundTxScheduleLookahead = 501 err = types.ValidateChainParams(©) require.NotNil(s.T(), err) + + copy = *params + copy.BallotThreshold = sdk.Dec{} + err = types.ValidateChainParams(©) + require.NotNil(s.T(), err) + copy.BallotThreshold = sdk.MustNewDecFromStr("1.2") + err = types.ValidateChainParams(©) + require.NotNil(s.T(), err) + copy.BallotThreshold = sdk.MustNewDecFromStr("0.9") + err = types.ValidateChainParams(©) + require.Nil(s.T(), err) + + copy = *params + copy.MinObserverDelegation = sdk.Dec{} + err = types.ValidateChainParams(©) + require.NotNil(s.T(), err) + copy.MinObserverDelegation = sdk.MustNewDecFromStr("0.9") + err = types.ValidateChainParams(©) + require.Nil(s.T(), err) } diff --git a/x/observer/types/crosschain_flags_test.go b/x/observer/types/crosschain_flags_test.go new file mode 100644 index 0000000000..9939f8b909 --- /dev/null +++ b/x/observer/types/crosschain_flags_test.go @@ -0,0 +1,19 @@ +package types_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/x/observer/types" +) + +func TestDefaultDefaultCrosschainFlags(t *testing.T) { + defaultCrosschainFlags := types.DefaultCrosschainFlags() + + require.Equal(t, &types.CrosschainFlags{ + IsInboundEnabled: true, + IsOutboundEnabled: true, + GasPriceIncreaseFlags: &types.DefaultGasPriceIncreaseFlags, + BlockHeaderVerificationFlags: &types.DefaultBlockHeaderVerificationFlags, + }, defaultCrosschainFlags) +} diff --git a/x/observer/types/expected_keepers.go b/x/observer/types/expected_keepers.go index ba1c9ab9c2..5f402c2689 100644 --- a/x/observer/types/expected_keepers.go +++ b/x/observer/types/expected_keepers.go @@ -11,6 +11,7 @@ type StakingKeeper interface { GetValidator(ctx sdk.Context, addr sdk.ValAddress) (validator stakingtypes.Validator, found bool) GetDelegation(ctx sdk.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress) (delegation stakingtypes.Delegation, found bool) SetValidator(ctx sdk.Context, validator stakingtypes.Validator) + SetDelegation(ctx sdk.Context, delegation stakingtypes.Delegation) } type SlashingKeeper interface { diff --git a/x/observer/types/genesis_test.go b/x/observer/types/genesis_test.go index 1ba8212a35..676bca7d62 100644 --- a/x/observer/types/genesis_test.go +++ b/x/observer/types/genesis_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/testutil/sample" "github.com/zeta-chain/zetacore/x/observer/types" ) @@ -12,6 +13,14 @@ func TestGenesisState_Validate(t *testing.T) { chainParams := types.GetDefaultChainParams().ChainParams invalidChainParamsGen.ChainParamsList.ChainParams = append(chainParams, chainParams[0]) + gsWithDuplicateNodeAccountList := types.DefaultGenesis() + nodeAccount := sample.NodeAccount() + gsWithDuplicateNodeAccountList.NodeAccountList = []*types.NodeAccount{nodeAccount, nodeAccount} + + gsWithDuplicateChainNonces := types.DefaultGenesis() + chainNonce := sample.ChainNonces(t, "0") + gsWithDuplicateChainNonces.ChainNonces = []types.ChainNonces{chainNonce, chainNonce} + for _, tc := range []struct { desc string genState *types.GenesisState @@ -32,6 +41,16 @@ func TestGenesisState_Validate(t *testing.T) { genState: invalidChainParamsGen, valid: false, }, + { + desc: "invalid genesis state duplicate node account list", + genState: gsWithDuplicateNodeAccountList, + valid: false, + }, + { + desc: "invalid genesis state duplicate chain nonces", + genState: gsWithDuplicateChainNonces, + valid: false, + }, } { t.Run(tc.desc, func(t *testing.T) { err := tc.genState.Validate() diff --git a/x/observer/types/observer_set_test.go b/x/observer/types/observer_set_test.go new file mode 100644 index 0000000000..5efce2cd81 --- /dev/null +++ b/x/observer/types/observer_set_test.go @@ -0,0 +1,32 @@ +package types_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/pkg/chains" + "github.com/zeta-chain/zetacore/testutil/sample" + "github.com/zeta-chain/zetacore/x/observer/types" +) + +func TestObserverSet(t *testing.T) { + observerSet := sample.ObserverSet(4) + + require.Equal(t, int(4), observerSet.Len()) + require.Equal(t, uint64(4), observerSet.LenUint()) + err := observerSet.Validate() + require.NoError(t, err) + + observerSet.ObserverList[0] = "invalid" + err = observerSet.Validate() + require.Error(t, err) +} + +func TestCheckReceiveStatus(t *testing.T) { + err := types.CheckReceiveStatus(chains.ReceiveStatus_Success) + require.NoError(t, err) + err = types.CheckReceiveStatus(chains.ReceiveStatus_Failed) + require.NoError(t, err) + err = types.CheckReceiveStatus(chains.ReceiveStatus_Created) + require.Error(t, err) +} diff --git a/x/observer/types/params.go b/x/observer/types/params.go index 7fc72c1e89..ea1db66657 100644 --- a/x/observer/types/params.go +++ b/x/observer/types/params.go @@ -100,6 +100,7 @@ func validateAdminPolicy(i interface{}) error { return nil } +// https://github.com/zeta-chain/node/issues/1983 func validateBallotMaturityBlocks(i interface{}) error { _, ok := i.(int64) if !ok { @@ -108,12 +109,3 @@ func validateBallotMaturityBlocks(i interface{}) error { return nil } - -func (p Params) GetAdminPolicyAccount(policyType Policy_Type) string { - for _, admin := range p.AdminPolicy { - if admin.PolicyType == policyType { - return admin.Address - } - } - return "" -} diff --git a/x/observer/types/params_test.go b/x/observer/types/params_test.go new file mode 100644 index 0000000000..3b7b177565 --- /dev/null +++ b/x/observer/types/params_test.go @@ -0,0 +1,74 @@ +package types + +import ( + "reflect" + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + paramtypes "github.com/cosmos/cosmos-sdk/x/params/types" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v2" +) + +func TestParamKeyTable(t *testing.T) { + kt := ParamKeyTable() + + ps := Params{} + for _, psp := range ps.ParamSetPairs() { + require.PanicsWithValue(t, "duplicate parameter key", func() { + kt.RegisterType(psp) + }) + } +} + +func TestParamSetPairs(t *testing.T) { + params := DefaultParams() + pairs := params.ParamSetPairs() + + require.Equal(t, 3, len(pairs), "The number of param set pairs should match the expected count") + + assertParamSetPair(t, pairs, KeyPrefix(ObserverParamsKey), ¶ms.ObserverParams, validateVotingThresholds) + assertParamSetPair(t, pairs, KeyPrefix(AdminPolicyParamsKey), ¶ms.AdminPolicy, validateAdminPolicy) + assertParamSetPair(t, pairs, KeyPrefix(BallotMaturityBlocksParamsKey), ¶ms.BallotMaturityBlocks, validateBallotMaturityBlocks) +} + +func assertParamSetPair(t *testing.T, pairs paramtypes.ParamSetPairs, key []byte, expectedValue interface{}, valFunc paramtypes.ValueValidatorFn) { + for _, pair := range pairs { + if string(pair.Key) == string(key) { + require.Equal(t, expectedValue, pair.Value, "Value does not match for key %s", string(key)) + + actualValFunc := pair.ValidatorFn + require.Equal(t, reflect.ValueOf(valFunc).Pointer(), reflect.ValueOf(actualValFunc).Pointer(), "Val func doesnt match for key %s", string(key)) + return + } + } + + t.Errorf("Key %s not found in ParamSetPairs", string(key)) +} + +func TestParamsString(t *testing.T) { + params := DefaultParams() + out, err := yaml.Marshal(params) + require.NoError(t, err) + require.Equal(t, string(out), params.String()) +} + +func TestValidateVotingThresholds(t *testing.T) { + require.Error(t, validateVotingThresholds("invalid")) + + params := DefaultParams() + require.NoError(t, validateVotingThresholds(params.ObserverParams)) + + params.ObserverParams[0].BallotThreshold = sdk.MustNewDecFromStr("1.1") + require.Error(t, validateVotingThresholds(params.ObserverParams)) +} + +func TestValidateAdminPolicy(t *testing.T) { + require.Error(t, validateAdminPolicy("invalid")) + require.NoError(t, validateAdminPolicy([]*Admin_Policy{})) +} + +func TestValidateBallotMaturityBlocks(t *testing.T) { + require.Error(t, validateBallotMaturityBlocks("invalid")) + require.NoError(t, validateBallotMaturityBlocks(int64(1))) +} diff --git a/x/observer/types/parsers_test.go b/x/observer/types/parsers_test.go new file mode 100644 index 0000000000..1c108fcee8 --- /dev/null +++ b/x/observer/types/parsers_test.go @@ -0,0 +1,98 @@ +package types_test + +import ( + "math/rand" + "testing" + + "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/pkg/chains" + "github.com/zeta-chain/zetacore/testutil/sample" + "github.com/zeta-chain/zetacore/x/observer/types" +) + +func TestConvertReceiveStatusToVoteType(t *testing.T) { + tests := []struct { + name string + status chains.ReceiveStatus + expected types.VoteType + }{ + {"TestSuccessStatus", chains.ReceiveStatus_Success, types.VoteType_SuccessObservation}, + {"TestFailedStatus", chains.ReceiveStatus_Failed, types.VoteType_FailureObservation}, + {"TestDefaultStatus", chains.ReceiveStatus_Created, types.VoteType_NotYetVoted}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := types.ConvertReceiveStatusToVoteType(tt.status) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestParseStringToObservationType(t *testing.T) { + tests := []struct { + name string + observationType string + expected types.ObservationType + }{ + {"TestValidObservationType1", "EmptyObserverType", types.ObservationType(0)}, + {"TestValidObservationType1", "InBoundTx", types.ObservationType(1)}, + {"TestValidObservationType1", "OutBoundTx", types.ObservationType(2)}, + {"TestValidObservationType1", "TSSKeyGen", types.ObservationType(3)}, + {"TestValidObservationType1", "TSSKeySign", types.ObservationType(4)}, + {"TestInvalidObservationType", "InvalidType", types.ObservationType(0)}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := types.ParseStringToObservationType(tt.observationType) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestGetOperatorAddressFromAccAddress(t *testing.T) { + tests := []struct { + name string + accAddr string + wantErr bool + }{ + {"TestValidAccAddress", sample.AccAddress(), false}, + {"TestInvalidAccAddress", "invalid", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := types.GetOperatorAddressFromAccAddress(tt.accAddr) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestGetAccAddressFromOperatorAddress(t *testing.T) { + // #nosec G404 test purpose - weak randomness is not an issue here + r := rand.New(rand.NewSource(1)) + tests := []struct { + name string + valAddress string + wantErr bool + }{ + {"TestValidValAddress", sample.ValAddress(r).String(), false}, + {"TestInvalidValAddress", "invalid", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := types.GetAccAddressFromOperatorAddress(tt.valAddress) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/x/observer/types/test_data.go b/x/observer/types/test_data.go deleted file mode 100644 index 3a43ca4306..0000000000 --- a/x/observer/types/test_data.go +++ /dev/null @@ -1,27 +0,0 @@ -package types - -import ( - sdk "github.com/cosmos/cosmos-sdk/types" -) - -const ( - AccountAddressPrefix = "zeta" -) - -var ( - AccountPubKeyPrefix = AccountAddressPrefix + "pub" - ValidatorAddressPrefix = AccountAddressPrefix + "valoper" - ValidatorPubKeyPrefix = AccountAddressPrefix + "valoperpub" - ConsNodeAddressPrefix = AccountAddressPrefix + "valcons" - ConsNodePubKeyPrefix = AccountAddressPrefix + "valconspub" -) - -func SetConfig(seal bool) { - config := sdk.GetConfig() - config.SetBech32PrefixForAccount(AccountAddressPrefix, AccountPubKeyPrefix) - config.SetBech32PrefixForValidator(ValidatorAddressPrefix, ValidatorPubKeyPrefix) - config.SetBech32PrefixForConsensusNode(ConsNodeAddressPrefix, ConsNodePubKeyPrefix) - if seal { - config.Seal() - } -} diff --git a/x/observer/types/utils_test.go b/x/observer/types/utils_test.go new file mode 100644 index 0000000000..ac42ea127c --- /dev/null +++ b/x/observer/types/utils_test.go @@ -0,0 +1,55 @@ +package types_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/zeta-chain/go-tss/blame" + "github.com/zeta-chain/zetacore/x/observer/types" +) + +func TestConvertNodes(t *testing.T) { + tests := []struct { + name string + input []blame.Node + expected []*types.Node + }{ + { + name: "TestEmptyInput", + input: []blame.Node{}, + expected: nil, + }, + { + name: "TestNilInput", + input: nil, + expected: nil, + }, + { + name: "TestSingleInput", + input: []blame.Node{ + {Pubkey: "pubkey1", BlameSignature: []byte("signature1"), BlameData: []byte("data1")}, + }, + expected: []*types.Node{ + {PubKey: "pubkey1", BlameSignature: []byte("signature1"), BlameData: []byte("data1")}, + }, + }, + { + name: "TestMultipleInputs", + input: []blame.Node{ + {Pubkey: "pubkey1", BlameSignature: []byte("signature1"), BlameData: []byte("data1")}, + {Pubkey: "pubkey2", BlameSignature: []byte("signature2"), BlameData: []byte("data2")}, + }, + expected: []*types.Node{ + {PubKey: "pubkey1", BlameSignature: []byte("signature1"), BlameData: []byte("data1")}, + {PubKey: "pubkey2", BlameSignature: []byte("signature2"), BlameData: []byte("data2")}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := types.ConvertNodes(tt.input) + require.Equal(t, tt.expected, result) + }) + } +}