diff --git a/x/observer/keeper/msg_server_add_observer.go b/x/observer/keeper/msg_server_add_observer.go index 11f7701dde..411a461197 100644 --- a/x/observer/keeper/msg_server_add_observer.go +++ b/x/observer/keeper/msg_server_add_observer.go @@ -52,7 +52,10 @@ func (k msgServer) AddObserver( return &types.MsgAddObserverResponse{}, nil } - k.AddObserverToSet(ctx, msg.ObserverAddress) + err = k.AddObserverToSet(ctx, msg.ObserverAddress) + if err != nil { + return &types.MsgAddObserverResponse{}, err + } observerSet, _ := k.GetObserverSet(ctx) k.SetLastObserverCount(ctx, &types.LastObserverCount{Count: observerSet.LenUint()}) diff --git a/x/observer/keeper/observer_set.go b/x/observer/keeper/observer_set.go index c8a22e0e0f..e0b28142b9 100644 --- a/x/observer/keeper/observer_set.go +++ b/x/observer/keeper/observer_set.go @@ -1,6 +1,7 @@ package keeper import ( + "cosmossdk.io/errors" "github.com/cosmos/cosmos-sdk/store/prefix" sdk "github.com/cosmos/cosmos-sdk/types" @@ -36,21 +37,21 @@ func (k Keeper) IsAddressPartOfObserverSet(ctx sdk.Context, address string) bool return false } -func (k Keeper) AddObserverToSet(ctx sdk.Context, address string) { +func (k Keeper) AddObserverToSet(ctx sdk.Context, address string) error { observerSet, found := k.GetObserverSet(ctx) if !found { k.SetObserverSet(ctx, types.ObserverSet{ ObserverList: []string{address}, }) - return - } - for _, addr := range observerSet.ObserverList { - if addr == address { - return - } + return nil } observerSet.ObserverList = append(observerSet.ObserverList, address) + err := observerSet.Validate() + if err != nil { + return err + } k.SetObserverSet(ctx, observerSet) + return nil } func (k Keeper) RemoveObserverFromSet(ctx sdk.Context, address string) { @@ -72,12 +73,22 @@ func (k Keeper) UpdateObserverAddress(ctx sdk.Context, oldObserverAddress, newOb if !found { return types.ErrObserverSetNotFound } + found = false for i, addr := range observerSet.ObserverList { if addr == oldObserverAddress { observerSet.ObserverList[i] = newObserverAddress - k.SetObserverSet(ctx, observerSet) - return nil + found = true + break } } - return types.ErrUpdateObserver + if !found { + return types.ErrObserverNotFound + } + + err := observerSet.Validate() + if err != nil { + return errors.Wrap(types.ErrUpdateObserver, err.Error()) + } + k.SetObserverSet(ctx, observerSet) + return nil } diff --git a/x/observer/keeper/observer_set_test.go b/x/observer/keeper/observer_set_test.go index 3ae4a99cf2..0718aea2b4 100644 --- a/x/observer/keeper/observer_set_test.go +++ b/x/observer/keeper/observer_set_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/x/observer/types" keepertest "github.com/zeta-chain/zetacore/testutil/keeper" "github.com/zeta-chain/zetacore/testutil/sample" @@ -39,7 +40,8 @@ func TestKeeper_AddObserverToSet(t *testing.T) { os := sample.ObserverSet(10) k.SetObserverSet(ctx, os) newObserver := sample.AccAddress() - k.AddObserverToSet(ctx, newObserver) + err := k.AddObserverToSet(ctx, newObserver) + require.NoError(t, err) require.True(t, k.IsAddressPartOfObserverSet(ctx, newObserver)) require.False(t, k.IsAddressPartOfObserverSet(ctx, sample.AccAddress())) osNew, found := k.GetObserverSet(ctx) @@ -95,6 +97,20 @@ func TestKeeper_UpdateObserverAddress(t *testing.T) { require.True(t, found) require.Equal(t, newObserverAddress, observerSet.ObserverList[len(observerSet.ObserverList)-1]) }) + t.Run("unable to update observer list if the new list is not valid", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + oldObserverAddress := sample.AccAddress() + newObserverAddress := sample.AccAddress() + observerSet := sample.ObserverSet(10) + observerSet.ObserverList = append(observerSet.ObserverList, []string{oldObserverAddress, newObserverAddress}...) + + err := k.UpdateObserverAddress(ctx, oldObserverAddress, newObserverAddress) + require.ErrorIs(t, err, types.ErrObserverSetNotFound) + k.SetObserverSet(ctx, observerSet) + + err = k.UpdateObserverAddress(ctx, oldObserverAddress, newObserverAddress) + require.ErrorContains(t, err, types.ErrDuplicateObserver.Error()) + }) t.Run("should error if observer address not found", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) oldObserverAddress := sample.AccAddress() @@ -103,7 +119,7 @@ func TestKeeper_UpdateObserverAddress(t *testing.T) { observerSet.ObserverList = append(observerSet.ObserverList, oldObserverAddress) k.SetObserverSet(ctx, observerSet) err := k.UpdateObserverAddress(ctx, sample.AccAddress(), newObserverAddress) - require.Error(t, err) + require.ErrorIs(t, err, types.ErrObserverNotFound) }) t.Run("update observer address long observerList", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) diff --git a/x/observer/types/errors.go b/x/observer/types/errors.go index 6485e613ed..a997267657 100644 --- a/x/observer/types/errors.go +++ b/x/observer/types/errors.go @@ -49,4 +49,6 @@ var ( ErrInboundDisabled = errorsmod.Register(ModuleName, 1132, "inbound tx processing is disabled") ErrInvalidZetaCoinTypes = errorsmod.Register(ModuleName, 1133, "invalid zeta coin types") ErrNotObserver = errorsmod.Register(ModuleName, 1134, "sender is not an observer") + ErrDuplicateObserver = errorsmod.Register(ModuleName, 1135, "observer already exists") + ErrObserverNotFound = errorsmod.Register(ModuleName, 1136, "observer not found") ) diff --git a/x/observer/types/observer_set.go b/x/observer/types/observer_set.go index ffa2d1c05a..0e68cb5ec0 100644 --- a/x/observer/types/observer_set.go +++ b/x/observer/types/observer_set.go @@ -1,6 +1,7 @@ package types import ( + "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/zeta-chain/zetacore/pkg/chains" @@ -22,6 +23,14 @@ func (m *ObserverSet) Validate() error { return err } } + // Check for duplicates + observers := make(map[string]bool) + for _, observerAddress := range m.ObserverList { + if _, ok := observers[observerAddress]; ok { + return errors.Wrapf(ErrDuplicateObserver, "observer %s", observerAddress) + } + observers[observerAddress] = true + } return nil } diff --git a/x/observer/types/observer_set_test.go b/x/observer/types/observer_set_test.go index 69a9a19f96..8344757b18 100644 --- a/x/observer/types/observer_set_test.go +++ b/x/observer/types/observer_set_test.go @@ -10,17 +10,39 @@ import ( "github.com/zeta-chain/zetacore/x/observer/types" ) -func TestObserverSet(t *testing.T) { - observerSet := sample.ObserverSet(4) +func TestObserverSet_Validate(t *testing.T) { + observer1Address := sample.AccAddress() + tt := []struct { + name string + observer types.ObserverSet + wantErr require.ErrorAssertionFunc + }{ + { + name: "observer set with duplicate observer", + observer: types.ObserverSet{ObserverList: []string{observer1Address, observer1Address}}, + wantErr: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorIs(t, err, types.ErrDuplicateObserver) + }, + }, + { + name: "observer set with invalid observer", + observer: types.ObserverSet{ObserverList: []string{"invalid"}}, + wantErr: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorContains(t, err, "decoding bech32 failed") + }, + }, + { + name: "observer set with valid observer", + observer: types.ObserverSet{ObserverList: []string{observer1Address}}, + wantErr: require.NoError, + }, + } + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + tc.wantErr(t, tc.observer.Validate()) + }) - require.Equal(t, int(4), observerSet.Len()) - require.Equal(t, uint64(4), observerSet.LenUint()) - err := observerSet.Validate() - require.NoError(t, err) - - observerSet.ObserverList[0] = "invalid" - err = observerSet.Validate() - require.Error(t, err) + } } func TestCheckReceiveStatus(t *testing.T) {