Skip to content

Commit

Permalink
add duplicate check to validate observer set
Browse files Browse the repository at this point in the history
  • Loading branch information
kingpinXD committed Aug 7, 2024
1 parent d8a23e3 commit 415d354
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 23 deletions.
5 changes: 4 additions & 1 deletion x/observer/keeper/msg_server_add_observer.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ func (k msgServer) AddObserver(
return &types.MsgAddObserverResponse{}, nil
}

k.AddObserverToSet(ctx, msg.ObserverAddress)
err = k.AddObserverToSet(ctx, msg.ObserverAddress)
if err != nil {
return &types.MsgAddObserverResponse{}, err
}
observerSet, _ := k.GetObserverSet(ctx)

k.SetLastObserverCount(ctx, &types.LastObserverCount{Count: observerSet.LenUint()})
Expand Down
31 changes: 21 additions & 10 deletions x/observer/keeper/observer_set.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package keeper

import (
"cosmossdk.io/errors"
"github.com/cosmos/cosmos-sdk/store/prefix"
sdk "github.com/cosmos/cosmos-sdk/types"

Expand Down Expand Up @@ -36,21 +37,21 @@ func (k Keeper) IsAddressPartOfObserverSet(ctx sdk.Context, address string) bool
return false
}

func (k Keeper) AddObserverToSet(ctx sdk.Context, address string) {
func (k Keeper) AddObserverToSet(ctx sdk.Context, address string) error {
observerSet, found := k.GetObserverSet(ctx)
if !found {
k.SetObserverSet(ctx, types.ObserverSet{
ObserverList: []string{address},
})
return
}
for _, addr := range observerSet.ObserverList {
if addr == address {
return
}
return nil
}
observerSet.ObserverList = append(observerSet.ObserverList, address)
err := observerSet.Validate()
if err != nil {
return err
}
k.SetObserverSet(ctx, observerSet)
return nil
}

func (k Keeper) RemoveObserverFromSet(ctx sdk.Context, address string) {
Expand All @@ -72,12 +73,22 @@ func (k Keeper) UpdateObserverAddress(ctx sdk.Context, oldObserverAddress, newOb
if !found {
return types.ErrObserverSetNotFound
}
found = false
for i, addr := range observerSet.ObserverList {
if addr == oldObserverAddress {
observerSet.ObserverList[i] = newObserverAddress
k.SetObserverSet(ctx, observerSet)
return nil
found = true
break
}
}
return types.ErrUpdateObserver
if !found {
return types.ErrObserverNotFound
}

err := observerSet.Validate()
if err != nil {
return errors.Wrap(types.ErrUpdateObserver, err.Error())
}
k.SetObserverSet(ctx, observerSet)
return nil
}
20 changes: 18 additions & 2 deletions x/observer/keeper/observer_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"

"github.com/stretchr/testify/require"
"github.com/zeta-chain/zetacore/x/observer/types"

keepertest "github.com/zeta-chain/zetacore/testutil/keeper"
"github.com/zeta-chain/zetacore/testutil/sample"
Expand Down Expand Up @@ -39,7 +40,8 @@ func TestKeeper_AddObserverToSet(t *testing.T) {
os := sample.ObserverSet(10)
k.SetObserverSet(ctx, os)
newObserver := sample.AccAddress()
k.AddObserverToSet(ctx, newObserver)
err := k.AddObserverToSet(ctx, newObserver)
require.NoError(t, err)
require.True(t, k.IsAddressPartOfObserverSet(ctx, newObserver))
require.False(t, k.IsAddressPartOfObserverSet(ctx, sample.AccAddress()))
osNew, found := k.GetObserverSet(ctx)
Expand Down Expand Up @@ -95,6 +97,20 @@ func TestKeeper_UpdateObserverAddress(t *testing.T) {
require.True(t, found)
require.Equal(t, newObserverAddress, observerSet.ObserverList[len(observerSet.ObserverList)-1])
})
t.Run("unable to update observer list if the new list is not valid", func(t *testing.T) {
k, ctx, _, _ := keepertest.ObserverKeeper(t)
oldObserverAddress := sample.AccAddress()
newObserverAddress := sample.AccAddress()
observerSet := sample.ObserverSet(10)
observerSet.ObserverList = append(observerSet.ObserverList, []string{oldObserverAddress, newObserverAddress}...)

err := k.UpdateObserverAddress(ctx, oldObserverAddress, newObserverAddress)
require.ErrorIs(t, err, types.ErrObserverSetNotFound)
k.SetObserverSet(ctx, observerSet)

err = k.UpdateObserverAddress(ctx, oldObserverAddress, newObserverAddress)
require.ErrorContains(t, err, types.ErrDuplicateObserver.Error())
})
t.Run("should error if observer address not found", func(t *testing.T) {
k, ctx, _, _ := keepertest.ObserverKeeper(t)
oldObserverAddress := sample.AccAddress()
Expand All @@ -103,7 +119,7 @@ func TestKeeper_UpdateObserverAddress(t *testing.T) {
observerSet.ObserverList = append(observerSet.ObserverList, oldObserverAddress)
k.SetObserverSet(ctx, observerSet)
err := k.UpdateObserverAddress(ctx, sample.AccAddress(), newObserverAddress)
require.Error(t, err)
require.ErrorIs(t, err, types.ErrObserverNotFound)
})
t.Run("update observer address long observerList", func(t *testing.T) {
k, ctx, _, _ := keepertest.ObserverKeeper(t)
Expand Down
2 changes: 2 additions & 0 deletions x/observer/types/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,6 @@ var (
ErrInboundDisabled = errorsmod.Register(ModuleName, 1132, "inbound tx processing is disabled")
ErrInvalidZetaCoinTypes = errorsmod.Register(ModuleName, 1133, "invalid zeta coin types")
ErrNotObserver = errorsmod.Register(ModuleName, 1134, "sender is not an observer")
ErrDuplicateObserver = errorsmod.Register(ModuleName, 1135, "observer already exists")
ErrObserverNotFound = errorsmod.Register(ModuleName, 1136, "observer not found")
)
9 changes: 9 additions & 0 deletions x/observer/types/observer_set.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package types

import (
"cosmossdk.io/errors"
sdk "github.com/cosmos/cosmos-sdk/types"

"github.com/zeta-chain/zetacore/pkg/chains"
Expand All @@ -22,6 +23,14 @@ func (m *ObserverSet) Validate() error {
return err
}
}
// Check for duplicates
observers := make(map[string]bool)
for _, observerAddress := range m.ObserverList {
if _, ok := observers[observerAddress]; ok {
return errors.Wrapf(ErrDuplicateObserver, "observer %s", observerAddress)
}
observers[observerAddress] = true
}
return nil
}

Expand Down
42 changes: 32 additions & 10 deletions x/observer/types/observer_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,39 @@ import (
"github.com/zeta-chain/zetacore/x/observer/types"
)

func TestObserverSet(t *testing.T) {
observerSet := sample.ObserverSet(4)
func TestObserverSet_Validate(t *testing.T) {
observer1Address := sample.AccAddress()
tt := []struct {
name string
observer types.ObserverSet
wantErr require.ErrorAssertionFunc
}{
{
name: "observer set with duplicate observer",
observer: types.ObserverSet{ObserverList: []string{observer1Address, observer1Address}},
wantErr: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorIs(t, err, types.ErrDuplicateObserver)
},
},
{
name: "observer set with invalid observer",
observer: types.ObserverSet{ObserverList: []string{"invalid"}},
wantErr: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "decoding bech32 failed")
},
},
{
name: "observer set with valid observer",
observer: types.ObserverSet{ObserverList: []string{observer1Address}},
wantErr: require.NoError,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
tc.wantErr(t, tc.observer.Validate())
})

require.Equal(t, int(4), observerSet.Len())
require.Equal(t, uint64(4), observerSet.LenUint())
err := observerSet.Validate()
require.NoError(t, err)

observerSet.ObserverList[0] = "invalid"
err = observerSet.Validate()
require.Error(t, err)
}
}

func TestCheckReceiveStatus(t *testing.T) {
Expand Down

0 comments on commit 415d354

Please sign in to comment.