diff --git a/changelog.md b/changelog.md index 47ab793048..0e08931d37 100644 --- a/changelog.md +++ b/changelog.md @@ -23,6 +23,7 @@ * [2291](https://github.com/zeta-chain/node/pull/2291) - initialize cctx gateway interface * [2289](https://github.com/zeta-chain/node/pull/2289) - add an authorization list to keep track of all authorizations on the chain * [2305](https://github.com/zeta-chain/node/pull/2305) - add new messages `MsgAddAuthorization` and `MsgRemoveAuthorization` that can be used to update the authorization list +* [2313](https://github.com/zeta-chain/node/pull/2313) - add `CheckAuthorization` function to replace the `IsAuthorized` function. The new function uses the authorization list to verify the signer's authorization. ### Refactor diff --git a/testutil/sample/authority.go b/testutil/sample/authority.go index c7e3b7e6b8..be9ca2c555 100644 --- a/testutil/sample/authority.go +++ b/testutil/sample/authority.go @@ -3,6 +3,8 @@ package sample import ( "fmt" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/zeta-chain/zetacore/pkg/chains" authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" ) @@ -65,3 +67,19 @@ func Authorization() authoritytypes.Authorization { AuthorizedPolicy: authoritytypes.PolicyType_groupOperational, } } + +// MultipleSignerMessage is a sample message which has two signers instead of one. This is used to test cases when we have checks for number of signers such as authorized transactions. +type MultipleSignerMessage struct{} + +var _ sdk.Msg = &MultipleSignerMessage{} + +func (m *MultipleSignerMessage) Reset() {} +func (m *MultipleSignerMessage) String() string { return "MultipleSignerMessage" } +func (m *MultipleSignerMessage) ProtoMessage() {} +func (m *MultipleSignerMessage) ValidateBasic() error { return nil } +func (m *MultipleSignerMessage) GetSigners() []sdk.AccAddress { + return []sdk.AccAddress{ + sdk.MustAccAddressFromBech32(AccAddress()), + sdk.MustAccAddressFromBech32(AccAddress()), + } +} diff --git a/x/authority/keeper/authorization_list.go b/x/authority/keeper/authorization_list.go index 0e7fba9163..24255bd089 100644 --- a/x/authority/keeper/authorization_list.go +++ b/x/authority/keeper/authorization_list.go @@ -1,6 +1,9 @@ package keeper import ( + "fmt" + + "cosmossdk.io/errors" "github.com/cosmos/cosmos-sdk/store/prefix" sdk "github.com/cosmos/cosmos-sdk/types" @@ -24,3 +27,49 @@ func (k Keeper) GetAuthorizationList(ctx sdk.Context) (val types.AuthorizationLi k.cdc.MustUnmarshal(b, &val) return val, true } + +// IsAuthorized checks if the address is authorized for the given policy type +func (k Keeper) IsAuthorized(ctx sdk.Context, address string, policyType types.PolicyType) bool { + policies, found := k.GetPolicies(ctx) + if !found { + return false + } + for _, policy := range policies.Items { + if policy.Address == address && policy.PolicyType == policyType { + return true + } + } + return false +} + +// CheckAuthorization checks if the signer is authorized to sign the message +// It uses both the authorization list and the policies to check if the signer is authorized +func (k Keeper) CheckAuthorization(ctx sdk.Context, msg sdk.Msg) error { + // Policy transactions must have only one signer + if len(msg.GetSigners()) != 1 { + return errors.Wrapf(types.ErrSigners, "msg: %v", sdk.MsgTypeURL(msg)) + } + + signer := msg.GetSigners()[0].String() + msgURL := sdk.MsgTypeURL(msg) + + authorizationsList, found := k.GetAuthorizationList(ctx) + if !found { + return types.ErrAuthorizationListNotFound + } + + policyRequired, err := authorizationsList.GetAuthorizedPolicy(msgURL) + if err != nil { + return errors.Wrap(types.ErrAuthorizationNotFound, fmt.Sprintf("msg: %v", msgURL)) + } + if policyRequired == types.PolicyType_groupEmpty { + return errors.Wrap(types.ErrInvalidPolicyType, fmt.Sprintf("Empty policy for msg: %v", msgURL)) + } + + policies, found := k.GetPolicies(ctx) + if !found { + return errors.Wrap(types.ErrPoliciesNotFound, fmt.Sprintf("msg: %v", msgURL)) + } + + return policies.CheckSigner(signer, policyRequired) +} diff --git a/x/authority/keeper/authorization_list_test.go b/x/authority/keeper/authorization_list_test.go index 008be61f25..a4c0d0e7bd 100644 --- a/x/authority/keeper/authorization_list_test.go +++ b/x/authority/keeper/authorization_list_test.go @@ -1,13 +1,16 @@ package keeper_test import ( + "fmt" "testing" + sdk "github.com/cosmos/cosmos-sdk/types" "github.com/stretchr/testify/require" keepertest "github.com/zeta-chain/zetacore/testutil/keeper" "github.com/zeta-chain/zetacore/testutil/sample" "github.com/zeta-chain/zetacore/x/authority/types" + lightclienttypes "github.com/zeta-chain/zetacore/x/lightclient/types" ) func TestKeeper_GetAuthorizationList(t *testing.T) { @@ -47,3 +50,250 @@ func TestKeeper_SetAuthorizationList(t *testing.T) { require.Equal(t, newAuthorizationList, list) }) } + +func TestKeeper_CheckAuthorization(t *testing.T) { + t.Run("successfully check authorization", func(t *testing.T) { + k, ctx := keepertest.AuthorityKeeper(t) + signer := sample.AccAddress() + msg := lightclienttypes.MsgDisableHeaderVerification{ + Creator: signer, + } + authorizationList := types.AuthorizationList{Authorizations: []types.Authorization{ + { + MsgUrl: sdk.MsgTypeURL(&msg), + AuthorizedPolicy: types.PolicyType_groupOperational, + }, + }, + } + policies := types.Policies{ + Items: []*types.Policy{ + { + Address: signer, + PolicyType: types.PolicyType_groupOperational, + }, + }, + } + + k.SetPolicies(ctx, policies) + k.SetAuthorizationList(ctx, authorizationList) + + err := k.CheckAuthorization(ctx, &msg) + require.NoError(t, err) + }) + + t.Run("successfully check authorization against large authorization list", func(t *testing.T) { + k, ctx := keepertest.AuthorityKeeper(t) + signer := sample.AccAddress() + msg := lightclienttypes.MsgDisableHeaderVerification{ + Creator: signer, + } + authorizationList := types.DefaultAuthorizationsList() + // Add 300 more authorizations to the list + for i := 0; i < 100; i++ { + authorizationList.Authorizations = append( + authorizationList.Authorizations, + sample.AuthorizationList(fmt.Sprintf("sample%d", i)).Authorizations...) + } + policies := types.Policies{ + Items: []*types.Policy{ + { + Address: signer, + PolicyType: types.PolicyType_groupEmergency, + }, + }, + } + + k.SetPolicies(ctx, policies) + k.SetAuthorizationList(ctx, authorizationList) + + err := k.CheckAuthorization(ctx, &msg) + require.NoError(t, err) + + list, found := k.GetAuthorizationList(ctx) + require.True(t, found) + require.Equal(t, authorizationList, list) + }) + + t.Run("check authorization against fails against large authorization list", func(t *testing.T) { + k, ctx := keepertest.AuthorityKeeper(t) + signer := sample.AccAddress() + msg := lightclienttypes.MsgDisableHeaderVerification{ + Creator: signer, + } + authorizationList := types.AuthorizationList{} + // Add 300 more authorizations to the list + for i := 0; i < 100; i++ { + authorizationList.Authorizations = append( + authorizationList.Authorizations, + sample.AuthorizationList(fmt.Sprintf("sample%d", i)).Authorizations...) + } + policies := types.Policies{ + Items: []*types.Policy{ + { + Address: signer, + PolicyType: types.PolicyType_groupEmergency, + }, + }, + } + + k.SetPolicies(ctx, policies) + k.SetAuthorizationList(ctx, authorizationList) + + err := k.CheckAuthorization(ctx, &msg) + require.ErrorIs(t, err, types.ErrAuthorizationNotFound) + + list, found := k.GetAuthorizationList(ctx) + require.True(t, found) + require.Equal(t, authorizationList, list) + }) + + t.Run("unable to check authorization with multiple signers", func(t *testing.T) { + k, ctx := keepertest.AuthorityKeeper(t) + signer := sample.AccAddress() + msg := &sample.MultipleSignerMessage{} + authorizationList := types.AuthorizationList{Authorizations: []types.Authorization{ + { + MsgUrl: sdk.MsgTypeURL(msg), + AuthorizedPolicy: types.PolicyType_groupOperational, + }, + }, + } + policies := types.Policies{ + Items: []*types.Policy{ + { + Address: signer, + PolicyType: types.PolicyType_groupOperational, + }, + }, + } + k.SetPolicies(ctx, policies) + k.SetAuthorizationList(ctx, authorizationList) + + err := k.CheckAuthorization(ctx, msg) + require.ErrorIs(t, err, types.ErrSigners) + }) + + t.Run("unable to check authorization with no authorization list", func(t *testing.T) { + k, ctx := keepertest.AuthorityKeeper(t) + signer := sample.AccAddress() + msg := lightclienttypes.MsgDisableHeaderVerification{ + Creator: signer, + } + + policies := types.Policies{ + Items: []*types.Policy{ + { + Address: signer, + PolicyType: types.PolicyType_groupOperational, + }, + }, + } + k.SetPolicies(ctx, policies) + + err := k.CheckAuthorization(ctx, &msg) + require.ErrorIs(t, err, types.ErrAuthorizationListNotFound) + }) + + t.Run("unable to check authorization with no policies", func(t *testing.T) { + k, ctx := keepertest.AuthorityKeeper(t) + signer := sample.AccAddress() + msg := lightclienttypes.MsgDisableHeaderVerification{ + Creator: signer, + } + authorizationList := types.AuthorizationList{Authorizations: []types.Authorization{ + { + MsgUrl: sdk.MsgTypeURL(&msg), + AuthorizedPolicy: types.PolicyType_groupOperational, + }, + }, + } + k.SetAuthorizationList(ctx, authorizationList) + + err := k.CheckAuthorization(ctx, &msg) + require.ErrorIs(t, err, types.ErrPoliciesNotFound) + }) + + t.Run("unable to check authorization when the required authorization doesnt exist", func(t *testing.T) { + k, ctx := keepertest.AuthorityKeeper(t) + signer := sample.AccAddress() + msg := lightclienttypes.MsgDisableHeaderVerification{ + Creator: signer, + } + authorizationList := types.AuthorizationList{Authorizations: []types.Authorization{ + { + MsgUrl: "/zetachain.zetacore.observer.MsgDisableCCTX", + AuthorizedPolicy: types.PolicyType_groupOperational, + }, + }, + } + policies := types.Policies{ + Items: []*types.Policy{ + { + Address: signer, + PolicyType: types.PolicyType_groupOperational, + }, + }, + } + k.SetPolicies(ctx, policies) + k.SetAuthorizationList(ctx, authorizationList) + + err := k.CheckAuthorization(ctx, &msg) + require.ErrorIs(t, err, types.ErrAuthorizationNotFound) + }) + + t.Run("unable to check authorization when check signer fails", func(t *testing.T) { + k, ctx := keepertest.AuthorityKeeper(t) + signer := sample.AccAddress() + msg := lightclienttypes.MsgDisableHeaderVerification{ + Creator: signer, + } + authorizationList := types.AuthorizationList{Authorizations: []types.Authorization{ + { + MsgUrl: sdk.MsgTypeURL(&msg), + AuthorizedPolicy: types.PolicyType_groupOperational, + }, + }, + } + policies := types.Policies{ + Items: []*types.Policy{ + { + Address: signer, + PolicyType: types.PolicyType_groupAdmin, + }, + }, + } + k.SetPolicies(ctx, policies) + k.SetAuthorizationList(ctx, authorizationList) + + err := k.CheckAuthorization(ctx, &msg) + require.ErrorIs(t, err, types.ErrSignerDoesntMatch) + }) + + t.Run("unable to check authorization when the required policy is empty", func(t *testing.T) { + k, ctx := keepertest.AuthorityKeeper(t) + signer := sample.AccAddress() + msg := lightclienttypes.MsgDisableHeaderVerification{ + Creator: signer, + } + authorizationList := types.AuthorizationList{Authorizations: []types.Authorization{ + { + MsgUrl: sdk.MsgTypeURL(&msg), + AuthorizedPolicy: types.PolicyType_groupEmpty, + }, + }, + } + policies := types.Policies{ + Items: []*types.Policy{ + { + Address: signer, + PolicyType: types.PolicyType_groupOperational, + }, + }, + } + k.SetPolicies(ctx, policies) + k.SetAuthorizationList(ctx, authorizationList) + + err := k.CheckAuthorization(ctx, &msg) + require.ErrorIs(t, err, types.ErrInvalidPolicyType) + }) +} diff --git a/x/authority/keeper/policies.go b/x/authority/keeper/policies.go index 448f922749..cd04a23d34 100644 --- a/x/authority/keeper/policies.go +++ b/x/authority/keeper/policies.go @@ -24,17 +24,3 @@ func (k Keeper) GetPolicies(ctx sdk.Context) (val types.Policies, found bool) { k.cdc.MustUnmarshal(b, &val) return val, true } - -// IsAuthorized checks if the address is authorized for the given policy type -func (k Keeper) IsAuthorized(ctx sdk.Context, address string, policyType types.PolicyType) bool { - policies, found := k.GetPolicies(ctx) - if !found { - return false - } - for _, policy := range policies.Items { - if policy.Address == address && policy.PolicyType == policyType { - return true - } - } - return false -} diff --git a/x/authority/types/authorizations.go b/x/authority/types/authorization_list.go similarity index 100% rename from x/authority/types/authorizations.go rename to x/authority/types/authorization_list.go diff --git a/x/authority/types/authorizations_test.go b/x/authority/types/authorization_list_test.go similarity index 100% rename from x/authority/types/authorizations_test.go rename to x/authority/types/authorization_list_test.go diff --git a/x/authority/types/errors.go b/x/authority/types/errors.go index 9d9509e9d4..776132d525 100644 --- a/x/authority/types/errors.go +++ b/x/authority/types/errors.go @@ -7,4 +7,9 @@ var ( ErrInvalidAuthorizationList = errorsmod.Register(ModuleName, 1103, "invalid authorization list") ErrAuthorizationNotFound = errorsmod.Register(ModuleName, 1104, "authorization not found") ErrAuthorizationListNotFound = errorsmod.Register(ModuleName, 1105, "authorization list not found") + ErrSigners = errorsmod.Register(ModuleName, 1106, "policy transactions must have only one signer") + ErrMsgNotAuthorized = errorsmod.Register(ModuleName, 1107, "msg type is not authorized") + ErrPoliciesNotFound = errorsmod.Register(ModuleName, 1108, "policies not found") + ErrSignerDoesntMatch = errorsmod.Register(ModuleName, 1109, "signer doesn't match required policy") + ErrInvalidPolicyType = errorsmod.Register(ModuleName, 1110, "invalid policy type") ) diff --git a/x/authority/types/policies.go b/x/authority/types/policies.go index af3642bfd7..c11bfbba6d 100644 --- a/x/authority/types/policies.go +++ b/x/authority/types/policies.go @@ -3,6 +3,7 @@ package types import ( "fmt" + "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" ) @@ -54,3 +55,14 @@ func (p Policies) Validate() error { return nil } + +// CheckSigner checks if the signer is authorized for the given policy type +func (p Policies) CheckSigner(signer string, policyRequired PolicyType) error { + for _, policy := range p.Items { + if policy.Address == signer && policy.PolicyType == policyRequired { + return nil + } + } + return errors.Wrap(ErrSignerDoesntMatch, fmt.Sprintf("signer: %s, policy required for message: %s ", + signer, policyRequired.String())) +} diff --git a/x/authority/types/policies_test.go b/x/authority/types/policies_test.go index 7f8d2ff52c..57749b35f6 100644 --- a/x/authority/types/policies_test.go +++ b/x/authority/types/policies_test.go @@ -130,3 +130,134 @@ func TestPolicies_Validate(t *testing.T) { }) } } + +func TestPolicies_CheckSigner(t *testing.T) { + signer := sample.AccAddress() + tt := []struct { + name string + policies types.Policies + signer string + policyRequired types.PolicyType + expectedErr error + }{ + { + name: "successfully check signer for policyType groupEmergency", + policies: types.Policies{ + Items: []*types.Policy{ + { + Address: signer, + PolicyType: types.PolicyType_groupEmergency, + }, + { + Address: signer, + PolicyType: types.PolicyType_groupAdmin, + }, + { + Address: signer, + PolicyType: types.PolicyType_groupOperational, + }, + }, + }, + signer: signer, + policyRequired: types.PolicyType_groupEmergency, + expectedErr: nil, + }, + { + name: "successfully check signer for policyType groupOperational", + policies: types.Policies{ + Items: []*types.Policy{ + { + Address: signer, + PolicyType: types.PolicyType_groupEmergency, + }, + { + Address: signer, + PolicyType: types.PolicyType_groupAdmin, + }, + { + Address: signer, + PolicyType: types.PolicyType_groupOperational, + }, + }, + }, + signer: signer, + policyRequired: types.PolicyType_groupOperational, + expectedErr: nil, + }, + { + name: "successfully check signer for policyType groupAdmin", + policies: types.Policies{ + Items: []*types.Policy{ + { + Address: signer, + PolicyType: types.PolicyType_groupEmergency, + }, + { + Address: signer, + PolicyType: types.PolicyType_groupAdmin, + }, + { + Address: signer, + PolicyType: types.PolicyType_groupOperational, + }, + }, + }, + signer: signer, + policyRequired: types.PolicyType_groupAdmin, + expectedErr: nil, + }, + { + name: "signer not found", + policies: types.Policies{ + Items: []*types.Policy{ + { + Address: signer, + PolicyType: types.PolicyType_groupEmergency, + }, + { + Address: signer, + PolicyType: types.PolicyType_groupAdmin, + }, + { + Address: signer, + PolicyType: types.PolicyType_groupOperational, + }, + }, + }, + signer: sample.AccAddress(), + policyRequired: types.PolicyType_groupEmergency, + expectedErr: types.ErrSignerDoesntMatch, + }, + { + name: "policy required not found", + policies: types.Policies{ + Items: []*types.Policy{ + { + Address: signer, + PolicyType: types.PolicyType_groupAdmin, + }, + { + Address: sample.AccAddress(), + PolicyType: types.PolicyType_groupOperational, + }, + }, + }, + signer: signer, + policyRequired: types.PolicyType_groupEmergency, + expectedErr: types.ErrSignerDoesntMatch, + }, + { + name: "empty policies", + policies: types.Policies{}, + signer: signer, + policyRequired: types.PolicyType_groupEmergency, + expectedErr: types.ErrSignerDoesntMatch, + }, + } + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + err := tc.policies.CheckSigner(tc.signer, tc.policyRequired) + require.ErrorIs(t, err, tc.expectedErr) + }) + } +}