Skip to content

Commit

Permalink
Add hooks tests
Browse files Browse the repository at this point in the history
  • Loading branch information
skosito committed Mar 29, 2024
1 parent 7721694 commit 6ef3710
Showing 1 changed file with 168 additions and 1 deletion.
169 changes: 168 additions & 1 deletion x/observer/keeper/hooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"math/rand"
"testing"

sdk "github.com/cosmos/cosmos-sdk/types"
stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types"
"github.com/stretchr/testify/require"
keepertest "github.com/zeta-chain/zetacore/testutil/keeper"
"github.com/zeta-chain/zetacore/testutil/sample"
Expand Down Expand Up @@ -34,11 +36,176 @@ func TestKeeper_AfterValidatorRemoved(t *testing.T) {
ObserverList: []string{accAddress.String()},
}
k.SetObserverSet(ctx, os)
hooks := k.Hooks()

hooks := k.Hooks()
hooks.AfterValidatorRemoved(ctx, nil, valAddr)

os, found := k.GetObserverSet(ctx)
require.True(t, found)
require.Empty(t, os.ObserverList)
}

func TestKeeper_AfterValidatorBeginUnbonding(t *testing.T) {
k, ctx, _, _ := keepertest.ObserverKeeper(t)

r := rand.New(rand.NewSource(9))
validator := sample.Validator(t, r)
validator.DelegatorShares = sdk.NewDec(100)
k.GetStakingKeeper().SetValidator(ctx, validator)
accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress)
require.NoError(t, err)

k.GetStakingKeeper().SetDelegation(ctx, stakingtypes.Delegation{
DelegatorAddress: accAddressOfValidator.String(),
ValidatorAddress: validator.GetOperator().String(),
Shares: sdk.NewDec(10),
})

k.SetObserverSet(ctx, types.ObserverSet{
ObserverList: []string{accAddressOfValidator.String()},
})

hooks := k.Hooks()
hooks.AfterValidatorBeginUnbonding(ctx, nil, validator.GetOperator())
require.NoError(t, err)

os, found := k.GetObserverSet(ctx)
require.True(t, found)
require.Empty(t, os.ObserverList)
}

func TestKeeper_AfterDelegationModified(t *testing.T) {
t.Run("should not clean observer if not self delegation", func(t *testing.T) {
k, ctx, _, _ := keepertest.ObserverKeeper(t)

r := rand.New(rand.NewSource(9))
validator := sample.Validator(t, r)
validator.DelegatorShares = sdk.NewDec(100)
k.GetStakingKeeper().SetValidator(ctx, validator)
accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress)
require.NoError(t, err)

k.GetStakingKeeper().SetDelegation(ctx, stakingtypes.Delegation{
DelegatorAddress: accAddressOfValidator.String(),
ValidatorAddress: validator.GetOperator().String(),
Shares: sdk.NewDec(10),
})

k.SetObserverSet(ctx, types.ObserverSet{
ObserverList: []string{accAddressOfValidator.String()},
})

hooks := k.Hooks()
err = hooks.AfterDelegationModified(ctx, sdk.AccAddress(sample.AccAddress()), validator.GetOperator())
require.NoError(t, err)

os, found := k.GetObserverSet(ctx)
require.True(t, found)
require.Equal(t, 1, len(os.ObserverList))
require.Equal(t, accAddressOfValidator.String(), os.ObserverList[0])
})

t.Run("should clean observer if self delegation", func(t *testing.T) {
k, ctx, _, _ := keepertest.ObserverKeeper(t)

r := rand.New(rand.NewSource(9))
validator := sample.Validator(t, r)
validator.DelegatorShares = sdk.NewDec(100)
k.GetStakingKeeper().SetValidator(ctx, validator)
accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress)
require.NoError(t, err)

k.GetStakingKeeper().SetDelegation(ctx, stakingtypes.Delegation{
DelegatorAddress: accAddressOfValidator.String(),
ValidatorAddress: validator.GetOperator().String(),
Shares: sdk.NewDec(10),
})

k.SetObserverSet(ctx, types.ObserverSet{
ObserverList: []string{accAddressOfValidator.String()},
})

hooks := k.Hooks()
hooks.AfterDelegationModified(ctx, accAddressOfValidator, validator.GetOperator())
require.NoError(t, err)

os, found := k.GetObserverSet(ctx)
require.True(t, found)
require.Empty(t, os.ObserverList)
})
}

func TestKeeper_BeforeValidatorSlashed(t *testing.T) {
t.Run("should error if validator not found", func(t *testing.T) {
k, ctx, _, _ := keepertest.ObserverKeeper(t)

r := rand.New(rand.NewSource(9))
validator := sample.Validator(t, r)

hooks := k.Hooks()
err := hooks.BeforeValidatorSlashed(ctx, validator.GetOperator(), sdk.NewDec(1))
require.Error(t, err)
})

t.Run("should not error if observer set not found", func(t *testing.T) {
k, ctx, _, _ := keepertest.ObserverKeeper(t)

r := rand.New(rand.NewSource(9))
validator := sample.Validator(t, r)
validator.DelegatorShares = sdk.NewDec(100)
k.GetStakingKeeper().SetValidator(ctx, validator)

hooks := k.Hooks()
err := hooks.BeforeValidatorSlashed(ctx, validator.GetOperator(), sdk.NewDec(1))
require.NoError(t, err)
})

t.Run("should remove from observer set", func(t *testing.T) {
k, ctx, _, _ := keepertest.ObserverKeeper(t)

r := rand.New(rand.NewSource(9))
validator := sample.Validator(t, r)
validator.DelegatorShares = sdk.NewDec(100)
validator.Tokens = sdk.NewInt(100)
k.GetStakingKeeper().SetValidator(ctx, validator)
accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress)
require.NoError(t, err)

k.SetObserverSet(ctx, types.ObserverSet{
ObserverList: []string{accAddressOfValidator.String()},
})

hooks := k.Hooks()
err = hooks.BeforeValidatorSlashed(ctx, validator.GetOperator(), sdk.MustNewDecFromStr("0.1"))
require.NoError(t, err)

os, found := k.GetObserverSet(ctx)
require.True(t, found)
require.Empty(t, os.ObserverList)
})

t.Run("should not remove from observer set", func(t *testing.T) {
k, ctx, _, _ := keepertest.ObserverKeeper(t)

r := rand.New(rand.NewSource(9))
validator := sample.Validator(t, r)
validator.DelegatorShares = sdk.NewDec(100)
validator.Tokens = sdk.NewInt(1000000000000000000)
k.GetStakingKeeper().SetValidator(ctx, validator)
accAddressOfValidator, err := types.GetAccAddressFromOperatorAddress(validator.OperatorAddress)
require.NoError(t, err)

k.SetObserverSet(ctx, types.ObserverSet{
ObserverList: []string{accAddressOfValidator.String()},
})

hooks := k.Hooks()
err = hooks.BeforeValidatorSlashed(ctx, validator.GetOperator(), sdk.MustNewDecFromStr("0"))
require.NoError(t, err)

os, found := k.GetObserverSet(ctx)
require.True(t, found)
require.Equal(t, 1, len(os.ObserverList))
require.Equal(t, accAddressOfValidator.String(), os.ObserverList[0])
})
}

0 comments on commit 6ef3710

Please sign in to comment.