diff --git a/x/observer/keeper/ballot_test.go b/x/observer/keeper/ballot_test.go index cea8fd47ad..055a49b755 100644 --- a/x/observer/keeper/ballot_test.go +++ b/x/observer/keeper/ballot_test.go @@ -52,6 +52,27 @@ func TestKeeper_GetBallotList(t *testing.T) { require.Equal(t, identifier, list.BallotsIndexList[0]) } +func TestKeeper_GetAllBallots(t *testing.T) { + k, ctx := SetupKeeper(t) + identifier := "0x9ea007f0f60e32d58577a8cf25678942d2b10791c2a34f48e237b76a7e998e4d" + b := &types.Ballot{ + Index: "", + BallotIdentifier: identifier, + VoterList: nil, + ObservationType: 0, + BallotThreshold: sdk.Dec{}, + BallotStatus: 0, + BallotCreationHeight: 1, + } + ballots := k.GetAllBallots(ctx) + require.Empty(t, ballots) + + k.SetBallot(ctx, b) + ballots = k.GetAllBallots(ctx) + require.Equal(t, 1, len(ballots)) + require.Equal(t, b, ballots[0]) +} + func TestKeeper_GetMaturedBallotList(t *testing.T) { k, ctx := SetupKeeper(t) identifier := "0x9ea007f0f60e32d58577a8cf25678942d2b10791c2a34f48e237b76a7e998e4d" diff --git a/x/observer/keeper/grpc_query_tss_test.go b/x/observer/keeper/grpc_query_tss_test.go new file mode 100644 index 0000000000..e0df837258 --- /dev/null +++ b/x/observer/keeper/grpc_query_tss_test.go @@ -0,0 +1,101 @@ +package keeper_test + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + keepertest "github.com/zeta-chain/zetacore/testutil/keeper" + "github.com/zeta-chain/zetacore/testutil/sample" + "github.com/zeta-chain/zetacore/x/observer/types" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestTSSQuerySingle(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + tss := sample.Tss() + wctx := sdk.WrapSDKContext(ctx) + + for _, tc := range []struct { + desc string + request *types.QueryGetTSSRequest + response *types.QueryGetTSSResponse + skipSettingTss bool + err error + }{ + { + desc: "Skip setting tss", + request: &types.QueryGetTSSRequest{}, + skipSettingTss: true, + err: status.Error(codes.InvalidArgument, "not found"), + }, + { + desc: "InvalidRequest", + err: status.Error(codes.InvalidArgument, "invalid request"), + }, + { + desc: "Should return tss", + request: &types.QueryGetTSSRequest{}, + response: &types.QueryGetTSSResponse{TSS: tss}, + }, + } { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + if !tc.skipSettingTss { + k.SetTSS(ctx, tss) + } + response, err := k.TSS(wctx, tc.request) + if tc.err != nil { + require.ErrorIs(t, err, tc.err) + } else { + require.Equal(t, tc.response, response) + } + }) + } +} + +func TestTSSQueryHistory(t *testing.T) { + keeper, ctx, _, _ := keepertest.ObserverKeeper(t) + wctx := sdk.WrapSDKContext(ctx) + for _, tc := range []struct { + desc string + tssCount int + foundPrevious bool + err error + }{ + { + desc: "1 Tss addresses", + tssCount: 1, + foundPrevious: false, + err: nil, + }, + { + desc: "10 Tss addresses", + tssCount: 10, + foundPrevious: true, + err: nil, + }, + } { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + tssList := sample.TssList(tc.tssCount) + for _, tss := range tssList { + keeper.SetTSS(ctx, tss) + keeper.SetTSSHistory(ctx, tss) + } + request := &types.QueryTssHistoryRequest{} + response, err := keeper.TssHistory(wctx, request) + if tc.err != nil { + require.ErrorIs(t, err, tc.err) + } else { + require.Equal(t, len(tssList), len(response.TssList)) + prevTss, found := keeper.GetPreviousTSS(ctx) + require.Equal(t, tc.foundPrevious, found) + if found { + require.Equal(t, tssList[len(tssList)-2], prevTss) + } + } + }) + } +} diff --git a/x/observer/keeper/tss_test.go b/x/observer/keeper/tss_test.go index 7574e05726..a1f1ca6ddc 100644 --- a/x/observer/keeper/tss_test.go +++ b/x/observer/keeper/tss_test.go @@ -9,14 +9,10 @@ import ( "github.com/stretchr/testify/require" keepertest "github.com/zeta-chain/zetacore/testutil/keeper" "github.com/zeta-chain/zetacore/testutil/sample" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - sdk "github.com/cosmos/cosmos-sdk/types" "github.com/zeta-chain/zetacore/x/observer/types" ) -func TestTSSGet(t *testing.T) { +func TestKeeper_GetTSS(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) tss := sample.Tss() k.SetTSS(ctx, tss) @@ -25,7 +21,7 @@ func TestTSSGet(t *testing.T) { require.Equal(t, tss, tssQueried) } -func TestTSSRemove(t *testing.T) { +func TestKeeper_RemoveTSS(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) tss := sample.Tss() k.SetTSS(ctx, tss) @@ -34,83 +30,34 @@ func TestTSSRemove(t *testing.T) { require.False(t, found) } -func TestTSSQuerySingle(t *testing.T) { +func TestKeeper_CheckIfTssPubkeyHasBeenGenerated(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) - wctx := sdk.WrapSDKContext(ctx) - //msgs := createTSS(keeper, ctx, 1) tss := sample.Tss() - k.SetTSS(ctx, tss) - for _, tc := range []struct { - desc string - request *types.QueryGetTSSRequest - response *types.QueryGetTSSResponse - err error - }{ - { - desc: "First", - request: &types.QueryGetTSSRequest{}, - response: &types.QueryGetTSSResponse{TSS: tss}, - }, - { - desc: "InvalidRequest", - err: status.Error(codes.InvalidArgument, "invalid request"), - }, - } { - tc := tc - t.Run(tc.desc, func(t *testing.T) { - response, err := k.TSS(wctx, tc.request) - if tc.err != nil { - require.ErrorIs(t, err, tc.err) - } else { - require.Equal(t, tc.response, response) - } - }) - } + + generated, found := k.CheckIfTssPubkeyHasBeenGenerated(ctx, tss.TssPubkey) + require.False(t, found) + require.Equal(t, types.TSS{}, generated) + + k.AppendTss(ctx, tss) + + generated, found = k.CheckIfTssPubkeyHasBeenGenerated(ctx, tss.TssPubkey) + require.True(t, found) + require.Equal(t, tss, generated) } -func TestTSSQueryHistory(t *testing.T) { - keeper, ctx, _, _ := keepertest.ObserverKeeper(t) - wctx := sdk.WrapSDKContext(ctx) - for _, tc := range []struct { - desc string - tssCount int - foundPrevious bool - err error - }{ - { - desc: "1 Tss addresses", - tssCount: 1, - foundPrevious: false, - err: nil, - }, - { - desc: "10 Tss addresses", - tssCount: 10, - foundPrevious: true, - err: nil, - }, - } { - tc := tc - t.Run(tc.desc, func(t *testing.T) { - tssList := sample.TssList(tc.tssCount) - for _, tss := range tssList { - keeper.SetTSS(ctx, tss) - keeper.SetTSSHistory(ctx, tss) - } - request := &types.QueryTssHistoryRequest{} - response, err := keeper.TssHistory(wctx, request) - if tc.err != nil { - require.ErrorIs(t, err, tc.err) - } else { - require.Equal(t, len(tssList), len(response.TssList)) - prevTss, found := keeper.GetPreviousTSS(ctx) - require.Equal(t, tc.foundPrevious, found) - if found { - require.Equal(t, tssList[len(tssList)-2], prevTss) - } - } - }) +func TestKeeper_GetHistoricalTssByFinalizedHeight(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + tssList := sample.TssList(100) + r := rand.Intn((len(tssList)-1)-0) + 0 + _, found := k.GetHistoricalTssByFinalizedHeight(ctx, tssList[r].FinalizedZetaHeight) + require.False(t, found) + + for _, tss := range tssList { + k.SetTSSHistory(ctx, tss) } + tss, found := k.GetHistoricalTssByFinalizedHeight(ctx, tssList[r].FinalizedZetaHeight) + require.True(t, found) + require.Equal(t, tssList[r], tss) } func TestKeeper_TssHistory(t *testing.T) { @@ -165,15 +112,4 @@ func TestKeeper_TssHistory(t *testing.T) { }) require.Equal(t, tssList, rst) }) - t.Run("Get historical TSS", func(t *testing.T) { - k, ctx, _, _ := keepertest.ObserverKeeper(t) - tssList := sample.TssList(100) - for _, tss := range tssList { - k.SetTSSHistory(ctx, tss) - } - r := rand.Intn((len(tssList)-1)-0) + 0 - tss, found := k.GetHistoricalTssByFinalizedHeight(ctx, tssList[r].FinalizedZetaHeight) - require.True(t, found) - require.Equal(t, tssList[r], tss) - }) }