From c2c0409c34b66c38c304b64fcc0680f9f97353ef Mon Sep 17 00:00:00 2001 From: lumtis Date: Wed, 17 Apr 2024 13:14:22 +0200 Subject: [PATCH] use sub function for get required group --- .../msg_server_update_verification_flags.go | 7 +-- .../message_update_verification_flags.go | 11 ++++ .../message_update_verification_flags_test.go | 56 +++++++++++++++++++ 3 files changed, 68 insertions(+), 6 deletions(-) diff --git a/x/lightclient/keeper/msg_server_update_verification_flags.go b/x/lightclient/keeper/msg_server_update_verification_flags.go index 5bdc3794f2..3330522e5e 100644 --- a/x/lightclient/keeper/msg_server_update_verification_flags.go +++ b/x/lightclient/keeper/msg_server_update_verification_flags.go @@ -17,13 +17,8 @@ func (k msgServer) UpdateVerificationFlags(goCtx context.Context, msg *types.Msg ) { ctx := sdk.UnwrapSDKContext(goCtx) - requiredGroup := authoritytypes.PolicyType_groupEmergency - if msg.VerificationFlags.EthTypeChainEnabled || msg.VerificationFlags.BtcTypeChainEnabled { - requiredGroup = authoritytypes.PolicyType_groupOperational - } - // check permission - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, requiredGroup) { + if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, msg.GetRequireGroup()) { return &types.MsgUpdateVerificationFlagsResponse{}, authoritytypes.ErrUnauthorized } diff --git a/x/lightclient/types/message_update_verification_flags.go b/x/lightclient/types/message_update_verification_flags.go index 82ab8018c7..a180596e40 100644 --- a/x/lightclient/types/message_update_verification_flags.go +++ b/x/lightclient/types/message_update_verification_flags.go @@ -4,6 +4,7 @@ import ( cosmoserrors "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" ) const ( @@ -49,3 +50,13 @@ func (msg *MsgUpdateVerificationFlags) ValidateBasic() error { } return nil } + +// GetRequireGroup returns the required group to execute the message +func (msg *MsgUpdateVerificationFlags) GetRequireGroup() authoritytypes.PolicyType { + requiredGroup := authoritytypes.PolicyType_groupEmergency + if msg.VerificationFlags.EthTypeChainEnabled || msg.VerificationFlags.BtcTypeChainEnabled { + requiredGroup = authoritytypes.PolicyType_groupOperational + } + + return requiredGroup +} diff --git a/x/lightclient/types/message_update_verification_flags_test.go b/x/lightclient/types/message_update_verification_flags_test.go index dfaa15c0ef..0b667ee08d 100644 --- a/x/lightclient/types/message_update_verification_flags_test.go +++ b/x/lightclient/types/message_update_verification_flags_test.go @@ -7,6 +7,7 @@ import ( sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" "github.com/stretchr/testify/require" "github.com/zeta-chain/zetacore/testutil/sample" + authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" "github.com/zeta-chain/zetacore/x/lightclient/types" ) @@ -115,3 +116,58 @@ func TestMsgUpdateVerificationFlags_GetSignBytes(t *testing.T) { msg.GetSignBytes() }) } + +func TestMsgUpdateVerificationFlags_GetRequireGroup(t *testing.T) { + tests := []struct { + name string + msg types.MsgUpdateVerificationFlags + want authoritytypes.PolicyType + }{ + { + name: "groupEmergency", + msg: types.MsgUpdateVerificationFlags{ + VerificationFlags: types.VerificationFlags{ + EthTypeChainEnabled: false, + BtcTypeChainEnabled: false, + }, + }, + want: authoritytypes.PolicyType_groupEmergency, + }, + { + name: "groupOperational", + msg: types.MsgUpdateVerificationFlags{ + VerificationFlags: types.VerificationFlags{ + EthTypeChainEnabled: true, + BtcTypeChainEnabled: false, + }, + }, + want: authoritytypes.PolicyType_groupOperational, + }, + { + name: "groupOperational", + msg: types.MsgUpdateVerificationFlags{ + VerificationFlags: types.VerificationFlags{ + EthTypeChainEnabled: false, + BtcTypeChainEnabled: true, + }, + }, + want: authoritytypes.PolicyType_groupOperational, + }, + { + name: "groupOperational", + msg: types.MsgUpdateVerificationFlags{ + VerificationFlags: types.VerificationFlags{ + EthTypeChainEnabled: true, + BtcTypeChainEnabled: true, + }, + }, + want: authoritytypes.PolicyType_groupOperational, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.msg.GetRequireGroup() + require.Equal(t, tt.want, got) + }) + } +}