From 29f98d935b6205467fb6051293693b0ad6795de9 Mon Sep 17 00:00:00 2001 From: Jun Kimura Date: Fri, 14 Jun 2024 21:48:15 +0900 Subject: [PATCH] fix `updateOperators` and add `update-operators` command Signed-off-by: Jun Kimura --- light-clients/lcp/types/codec.go | 1 + light-clients/lcp/types/update.go | 28 ++++++++-- relay/cmd.go | 89 +++++++++++++++++++++++++++++-- relay/lcp.go | 8 +++ relay/operator.go | 49 +++++++++++++++++ 5 files changed, 167 insertions(+), 8 deletions(-) diff --git a/light-clients/lcp/types/codec.go b/light-clients/lcp/types/codec.go index 6b1fe25..8e1de5d 100644 --- a/light-clients/lcp/types/codec.go +++ b/light-clients/lcp/types/codec.go @@ -20,5 +20,6 @@ func RegisterInterfaces(registry codectypes.InterfaceRegistry) { (*exported.ClientMessage)(nil), &UpdateClientMessage{}, &RegisterEnclaveKeyMessage{}, + &UpdateOperatorsMessage{}, ) } diff --git a/light-clients/lcp/types/update.go b/light-clients/lcp/types/update.go index 4f2b992..1b277bc 100644 --- a/light-clients/lcp/types/update.go +++ b/light-clients/lcp/types/update.go @@ -156,9 +156,13 @@ func (cs ClientState) verifyUpdateOperators(ctx sdk.Context, store storetypes.KV if err != nil { return err } + nextNonce := cs.OperatorsNonce + 1 + if message.Nonce != nextNonce { + return errorsmod.Wrapf(clienttypes.ErrInvalidHeader, "invalid nonce: expected=%v actual=%v clientID=%v", nextNonce, message.Nonce, clientID) + } newOperators, err := message.GetNewOperators() if err != nil { - return errorsmod.Wrapf(clienttypes.ErrInvalidHeader, "failed to get new operators: %v", err) + return errorsmod.Wrapf(clienttypes.ErrInvalidHeader, "failed to get new operators: %v clientID=%v", err, clientID) } signBytes, err := ComputeEIP712UpdateOperators( ctx.ChainID(), @@ -170,10 +174,25 @@ func (cs ClientState) verifyUpdateOperators(ctx sdk.Context, store storetypes.KV message.NewOperatorsThresholdDenominator, ) if err != nil { - return errorsmod.Wrapf(clienttypes.ErrInvalidHeader, "failed to compute sign bytes: %v", err) + return errorsmod.Wrapf(clienttypes.ErrInvalidHeader, "failed to compute sign bytes: err=%v clientID=%v", err, clientID) + } + commitment := crypto.Keccak256Hash(signBytes) + var success uint64 = 0 + for i, op := range cs.GetOperators() { + if len(message.Signatures[i]) == 0 { + continue + } + addr, err := RecoverAddress(commitment, message.Signatures[i]) + if err != nil { + return errorsmod.Wrapf(clienttypes.ErrInvalidHeader, "failed to recover operator address: err=%v clientID=%v", err, clientID) + } + if addr != op { + return errorsmod.Wrapf(clienttypes.ErrInvalidHeader, "invalid operator: expected=%v actual=%v clientID=%v", op, addr, clientID) + } + success++ } - if err := cs.VerifySignatures(ctx, store, crypto.Keccak256Hash(signBytes), message.Signatures); err != nil { - return errorsmod.Wrapf(clienttypes.ErrInvalidHeader, "failed to verify signatures: %v", err) + if success*cs.OperatorsThresholdDenominator < cs.OperatorsThresholdDenominator*uint64(len(cs.Operators)) { + return errorsmod.Wrapf(clienttypes.ErrInvalidHeader, "insufficient signatures: expected=%v actual=%v clientID=%v", cs.OperatorsThresholdDenominator, success, clientID) } return nil } @@ -259,6 +278,7 @@ func (cs ClientState) updateOperators(cdc codec.BinaryCodec, clientStore storety cs.Operators = message.NewOperators cs.OperatorsThresholdNumerator = message.NewOperatorsThresholdNumerator cs.OperatorsThresholdDenominator = message.NewOperatorsThresholdDenominator + cs.OperatorsNonce = message.Nonce setClientState(clientStore, cdc, &cs) return nil } diff --git a/relay/cmd.go b/relay/cmd.go index 3078bd9..8527736 100644 --- a/relay/cmd.go +++ b/relay/cmd.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/cosmos/cosmos-sdk/client/flags" + "github.com/ethereum/go-ethereum/common" "github.com/hyperledger-labs/yui-relayer/config" "github.com/hyperledger-labs/yui-relayer/core" "github.com/spf13/cobra" @@ -13,9 +14,13 @@ import ( ) const ( - flagSrc = "src" - flagHeight = "height" - flagELCClientID = "elc_client_id" + flagSrc = "src" + flagHeight = "height" + flagELCClientID = "elc_client_id" + flagOperators = "operators" + flagNonce = "nonce" + flagThresholdNumerator = "threshold_numerator" + flagThresholdDenominator = "threshold_denominator" ) func LCPCmd(ctx *config.Context) *cobra.Command { @@ -33,6 +38,7 @@ func LCPCmd(ctx *config.Context) *cobra.Command { updateEnclaveKeyCmd(ctx), activateClientCmd(ctx), removeEnclaveKeyInfoCmd(ctx), + updateOperatorsCmd(ctx), ) return cmd @@ -276,6 +282,53 @@ func removeEnclaveKeyInfoCmd(ctx *config.Context) *cobra.Command { return srcFlag(cmd) } +func updateOperatorsCmd(ctx *config.Context) *cobra.Command { + cmd := &cobra.Command{ + Use: "update-operators [path]", + Short: "Update operators", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + c, src, dst, err := ctx.Config.ChainsFromPath(args[0]) + if err != nil { + return err + } + var ( + target *core.ProvableChain + counterparty *core.ProvableChain + ) + if viper.GetBool(flagSrc) { + target = c[src] + counterparty = c[dst] + } else { + target = c[dst] + counterparty = c[src] + } + prover := target.Prover.(*Prover) + operators := viper.GetStringSlice(flagOperators) + threshold := Fraction{ + Numerator: viper.GetUint64(flagThresholdNumerator), + Denominator: viper.GetUint64(flagThresholdDenominator), + } + nonce := viper.GetUint64(flagNonce) + + var newOperators []common.Address + for _, op := range operators { + if !common.IsHexAddress(op) { + return fmt.Errorf("invalid operator address: %s", op) + } + newOperators = append(newOperators, common.HexToAddress(op)) + } + return prover.updateOperators(counterparty, nonce, newOperators, threshold) + }, + } + cmd = thresholdFlag(nonceFlag(operatorsFlag(srcFlag(cmd)))) + cmd.MarkFlagRequired(flagOperators) + cmd.MarkFlagRequired(flagThresholdNumerator) + cmd.MarkFlagRequired(flagThresholdDenominator) + cmd.MarkFlagRequired(flagNonce) + return cmd +} + func srcFlag(cmd *cobra.Command) *cobra.Command { cmd.Flags().BoolP(flagSrc, "", true, "a boolean value whether src is the target chain") if err := viper.BindPFlag(flagSrc, cmd.Flags().Lookup(flagSrc)); err != nil { @@ -293,9 +346,37 @@ func heightFlag(cmd *cobra.Command) *cobra.Command { } func elcClientIDFlag(cmd *cobra.Command) *cobra.Command { - cmd.Flags().StringP("elc_client_id", "", "", "a client ID of the ELC client") + cmd.Flags().StringP(flagELCClientID, "", "", "a client ID of the ELC client") if err := viper.BindPFlag(flagELCClientID, cmd.Flags().Lookup(flagELCClientID)); err != nil { panic(err) } return cmd } + +func operatorsFlag(cmd *cobra.Command) *cobra.Command { + cmd.Flags().StringSliceP(flagOperators, "", nil, "new operator addresses") + if err := viper.BindPFlag(flagOperators, cmd.Flags().Lookup(flagOperators)); err != nil { + panic(err) + } + return cmd +} + +func nonceFlag(cmd *cobra.Command) *cobra.Command { + cmd.Flags().Uint64P(flagNonce, "", 0, "a nonce") + if err := viper.BindPFlag(flagNonce, cmd.Flags().Lookup(flagNonce)); err != nil { + panic(err) + } + return cmd +} + +func thresholdFlag(cmd *cobra.Command) *cobra.Command { + cmd.Flags().Uint64P(flagThresholdNumerator, "", 0, "a numerator of new threshold") + cmd.Flags().Uint64P(flagThresholdDenominator, "", 0, "a denominator of new threshold") + if err := viper.BindPFlag(flagThresholdNumerator, cmd.Flags().Lookup(flagThresholdNumerator)); err != nil { + panic(err) + } + if err := viper.BindPFlag(flagThresholdDenominator, cmd.Flags().Lookup(flagThresholdDenominator)); err != nil { + panic(err) + } + return cmd +} diff --git a/relay/lcp.go b/relay/lcp.go index 729193f..9bfb62a 100644 --- a/relay/lcp.go +++ b/relay/lcp.go @@ -370,6 +370,14 @@ func (pr *Prover) ComputeEIP712RegisterEnclaveKeyHash(report string) (common.Has return crypto.Keccak256Hash(bz), nil } +func (pr *Prover) ComputeEIP712UpdateOperatorsHash(nonce uint64, newOperators []common.Address, thresholdNumerator, thresholdDenominator uint64) (common.Hash, error) { + bz, err := lcptypes.ComputeEIP712UpdateOperatorsWithSalt(pr.computeEIP712ChainSalt(), pr.path.ClientID, nonce, newOperators, thresholdNumerator, thresholdDenominator) + if err != nil { + return common.Hash{}, err + } + return crypto.Keccak256Hash(bz), nil +} + func (pr *Prover) computeEIP712ChainSalt() common.Hash { switch pr.config.ChainType() { case ChainTypeEVM: diff --git a/relay/operator.go b/relay/operator.go index b1cd557..df1b433 100644 --- a/relay/operator.go +++ b/relay/operator.go @@ -5,8 +5,12 @@ import ( "fmt" "strings" + sdk "github.com/cosmos/cosmos-sdk/types" + clienttypes "github.com/cosmos/ibc-go/v8/modules/core/02-client/types" + lcptypes "github.com/datachainlab/lcp-go/light-clients/lcp/types" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto/secp256k1" + "github.com/hyperledger-labs/yui-relayer/core" ) func (pr *Prover) OperatorSign(commitment [32]byte) ([]byte, error) { @@ -44,3 +48,48 @@ func (pr *Prover) GetOperatorsThreshold() Fraction { } return pr.config.OperatorsThreshold } + +func (pr *Prover) updateOperators(verifier core.Chain, nonce uint64, newOperators []common.Address, threshold Fraction) error { + if nonce == 0 { + return fmt.Errorf("invalid nonce: %v", nonce) + } + if threshold.Numerator == 0 || threshold.Denominator == 0 { + return fmt.Errorf("invalid threshold: %v", threshold) + } + commitment, err := pr.ComputeEIP712UpdateOperatorsHash( + nonce, + newOperators, + threshold.Numerator, + threshold.Denominator, + ) + if err != nil { + return err + } + sig, err := pr.OperatorSign(commitment) + if err != nil { + return err + } + var ops [][]byte + for _, operator := range newOperators { + ops = append(ops, operator.Bytes()) + } + message := &lcptypes.UpdateOperatorsMessage{ + Nonce: nonce, + NewOperators: ops, + NewOperatorsThresholdNumerator: threshold.Numerator, + NewOperatorsThresholdDenominator: threshold.Denominator, + Signatures: [][]byte{sig}, + } + signer, err := verifier.GetAddress() + if err != nil { + return err + } + msg, err := clienttypes.NewMsgUpdateClient(verifier.Path().ClientID, message, signer.String()) + if err != nil { + return err + } + if _, err := verifier.SendMsgs([]sdk.Msg{msg}); err != nil { + return err + } + return nil +}