diff --git a/cmd/zetaclientd/encrypt_tss.go b/cmd/zetaclientd/encrypt_tss.go index 6fca9064cb..e8e4a69807 100644 --- a/cmd/zetaclientd/encrypt_tss.go +++ b/cmd/zetaclientd/encrypt_tss.go @@ -1,17 +1,14 @@ package main import ( - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "crypto/sha256" "encoding/json" - "errors" - "io" "os" "path/filepath" + "github.com/pkg/errors" "github.com/spf13/cobra" + + "github.com/zeta-chain/zetacore/pkg/crypto" ) var encTssCmd = &cobra.Command{ @@ -25,6 +22,7 @@ func init() { RootCmd.AddCommand(encTssCmd) } +// EncryptTSSFile encrypts the given file with the given secret key func EncryptTSSFile(_ *cobra.Command, args []string) error { filePath := args[0] secretKey := args[1] @@ -39,29 +37,11 @@ func EncryptTSSFile(_ *cobra.Command, args []string) error { return errors.New("file does not contain valid json, may already be encrypted") } - block, err := aes.NewCipher(getFragmentSeed(secretKey)) - if err != nil { - return err - } - - // Creating GCM mode - gcm, err := cipher.NewGCM(block) + // encrypt the data + cipherText, err := crypto.EncryptAES256GCM(data, secretKey) if err != nil { - return err - } - // Generating random nonce - nonce := make([]byte, gcm.NonceSize()) - if _, err := io.ReadFull(rand.Reader, nonce); err != nil { - return err + return errors.Wrap(err, "failed to encrypt data") } - cipherText := gcm.Seal(nonce, nonce, data, nil) return os.WriteFile(filePath, cipherText, 0o600) } - -func getFragmentSeed(password string) []byte { - h := sha256.New() - h.Write([]byte(password)) - seed := h.Sum(nil) - return seed -} diff --git a/cmd/zetaclientd/import_relayer_keys.go b/cmd/zetaclientd/import_relayer_keys.go new file mode 100644 index 0000000000..fef20f6ed1 --- /dev/null +++ b/cmd/zetaclientd/import_relayer_keys.go @@ -0,0 +1,178 @@ +package main + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + + "github.com/pkg/errors" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + + "github.com/zeta-chain/zetacore/pkg/crypto" + zetaos "github.com/zeta-chain/zetacore/pkg/os" + "github.com/zeta-chain/zetacore/zetaclient/keys" +) + +var CmdImportRelayerKey = &cobra.Command{ + Use: "import-relayer-key [network] [private-key] [password] [relayer-key-path]", + Short: "Import a relayer private key", + Example: `zetaclientd import-relayer-key --network=7 --private-key=3EMjCcCJg53fMEGVj13UPQpo6py9AKKyLE2qroR4yL1SvAN2tUznBvDKRYjntw7m6Jof1R2CSqjTddL27rEb6sFQ --password=my_password`, + RunE: ImportRelayerKey, +} + +var CmdRelayerAddress = &cobra.Command{ + Use: "relayer-address [network] [password] [relayer-key-path]", + Short: "Show the relayer address", + Example: `zetaclientd relayer-address --network=7 --password=my_password`, + RunE: ShowRelayerAddress, +} + +var importArgs = importRelayerKeyArguments{} +var addressArgs = relayerAddressArguments{} + +// importRelayerKeyArguments is the struct that holds the arguments for the import command +type importRelayerKeyArguments struct { + network int32 + privateKey string + password string + relayerKeyPath string +} + +// relayerAddressArguments is the struct that holds the arguments for the show command +type relayerAddressArguments struct { + network int32 + password string + relayerKeyPath string +} + +func init() { + RootCmd.AddCommand(CmdImportRelayerKey) + RootCmd.AddCommand(CmdRelayerAddress) + + // resolve default relayer key path + defaultRelayerKeyPath := "~/.zetacored/relayer-keys" + defaultRelayerKeyPath, err := zetaos.ExpandHomeDir(defaultRelayerKeyPath) + if err != nil { + log.Fatal().Err(err).Msg("failed to resolve default relayer key path") + } + + CmdImportRelayerKey.Flags().Int32Var(&importArgs.network, "network", 7, "network id, (7: solana)") + CmdImportRelayerKey.Flags(). + StringVar(&importArgs.privateKey, "private-key", "", "the relayer private key to import") + CmdImportRelayerKey.Flags(). + StringVar(&importArgs.password, "password", "", "the password to encrypt the private key") + CmdImportRelayerKey.Flags(). + StringVar(&importArgs.relayerKeyPath, "relayer-key-path", defaultRelayerKeyPath, "path to relayer keys") + + CmdRelayerAddress.Flags().Int32Var(&addressArgs.network, "network", 7, "network id, (7:solana)") + CmdRelayerAddress.Flags(). + StringVar(&addressArgs.password, "password", "", "the password to decrypt the private key") + CmdRelayerAddress.Flags(). + StringVar(&addressArgs.relayerKeyPath, "relayer-key-path", defaultRelayerKeyPath, "path to relayer keys") +} + +// ImportRelayerKey imports a relayer private key +func ImportRelayerKey(_ *cobra.Command, _ []string) error { + // validate private key and password + if importArgs.privateKey == "" { + return errors.New("must provide a private key") + } + if importArgs.password == "" { + return errors.New("must provide a password") + } + + // resolve the relayer key file path + keyPath, fileName, err := resolveRelayerKeyPath(importArgs.network, importArgs.relayerKeyPath) + if err != nil { + return errors.Wrap(err, "failed to resolve relayer key file path") + } + + // create path (owner `rwx` permissions) if it does not exist + if _, err := os.Stat(keyPath); os.IsNotExist(err) { + if err := os.MkdirAll(keyPath, 0o700); err != nil { + return errors.Wrapf(err, "failed to create relayer key path: %s", keyPath) + } + } + + // avoid overwriting existing key file + if zetaos.FileExists(fileName) { + return errors.Errorf( + "relayer key %s already exists, please backup and remove it before importing a new key", + fileName, + ) + } + + // encrypt the private key + ciphertext, err := crypto.EncryptAES256GCMBase64(importArgs.privateKey, importArgs.password) + if err != nil { + return errors.Wrap(err, "private key encryption failed") + } + + // construct the relayer key struct and write to file as json + keyData, err := json.Marshal(keys.RelayerKey{PrivateKey: ciphertext}) + if err != nil { + return errors.Wrap(err, "failed to marshal relayer key") + } + + // create relay key file (owner `rw` permissions) + err = os.WriteFile(fileName, keyData, 0o600) + if err != nil { + return errors.Wrapf(err, "failed to create relayer key file: %s", fileName) + } + fmt.Printf("successfully imported relayer key: %s\n", fileName) + + return nil +} + +// ShowRelayerAddress shows the relayer address +func ShowRelayerAddress(_ *cobra.Command, _ []string) error { + // resolve the relayer key file path + _, fileName, err := resolveRelayerKeyPath(addressArgs.network, addressArgs.relayerKeyPath) + if err != nil { + return errors.Wrap(err, "failed to resolve relayer key file path") + } + + // read the relayer key file + relayerKey, err := keys.ReadRelayerKeyFromFile(fileName) + if err != nil { + return err + } + + // decrypt the private key + privateKey, err := crypto.DecryptAES256GCMBase64(relayerKey.PrivateKey, addressArgs.password) + if err != nil { + return errors.Wrap(err, "private key decryption failed") + } + relayerKey.PrivateKey = privateKey + + // resolve the address + networkName, address, err := relayerKey.ResolveAddress(addressArgs.network) + if err != nil { + return errors.Wrap(err, "failed to resolve relayer address") + } + fmt.Printf("relayer address (%s): %s\n", networkName, address) + + return nil +} + +// resolveRelayerKeyPath is a helper function to resolve the relayer key file path and name +func resolveRelayerKeyPath(network int32, relayerKeyPath string) (string, string, error) { + // get relayer key file name by network + name, err := keys.GetRelayerKeyFileByNetwork(network) + if err != nil { + return "", "", errors.Wrap(err, "failed to get relayer key file name") + } + + // resolve relayer key path if it contains a tilde + keyPath, err := zetaos.ExpandHomeDir(relayerKeyPath) + if err != nil { + return "", "", errors.Wrap(err, "failed to resolve relayer key path") + } + + // build file name + fileName := filepath.Join(keyPath, name) + + return keyPath, fileName, err +} diff --git a/contrib/localnet/orchestrator/start-zetae2e.sh b/contrib/localnet/orchestrator/start-zetae2e.sh index c885264c6f..614d53bc47 100644 --- a/contrib/localnet/orchestrator/start-zetae2e.sh +++ b/contrib/localnet/orchestrator/start-zetae2e.sh @@ -44,52 +44,52 @@ sleep 2 # unlock the default account account address=$(yq -r '.default_account.evm_address' config.yml) echo "funding deployer address ${address} with 10000 Ether" -geth --exec "eth.sendTransaction({from: eth.coinbase, to: '${address}', value: web3.toWei(10000,'ether')})" attach http://eth:8545 +geth --exec "eth.sendTransaction({from: eth.coinbase, to: '${address}', value: web3.toWei(10000,'ether')})" attach http://eth:8545 > /dev/null # unlock erc20 tester accounts address=$(yq -r '.additional_accounts.user_erc20.evm_address' config.yml) echo "funding erc20 address ${address} with 10000 Ether" -geth --exec "eth.sendTransaction({from: eth.coinbase, to: '${address}', value: web3.toWei(10000,'ether')})" attach http://eth:8545 +geth --exec "eth.sendTransaction({from: eth.coinbase, to: '${address}', value: web3.toWei(10000,'ether')})" attach http://eth:8545 > /dev/null # unlock zeta tester accounts address=$(yq -r '.additional_accounts.user_zeta_test.evm_address' config.yml) echo "funding zeta tester address ${address} with 10000 Ether" -geth --exec "eth.sendTransaction({from: eth.coinbase, to: '${address}', value: web3.toWei(10000,'ether')})" attach http://eth:8545 +geth --exec "eth.sendTransaction({from: eth.coinbase, to: '${address}', value: web3.toWei(10000,'ether')})" attach http://eth:8545 > /dev/null # unlock zevm message passing tester accounts address=$(yq -r '.additional_accounts.user_zevm_mp_test.evm_address' config.yml) echo "funding zevm mp tester address ${address} with 10000 Ether" -geth --exec "eth.sendTransaction({from: eth.coinbase, to: '${address}', value: web3.toWei(10000,'ether')})" attach http://eth:8545 +geth --exec "eth.sendTransaction({from: eth.coinbase, to: '${address}', value: web3.toWei(10000,'ether')})" attach http://eth:8545 > /dev/null # unlock bitcoin tester accounts address=$(yq -r '.additional_accounts.user_bitcoin.evm_address' config.yml) echo "funding bitcoin tester address ${address} with 10000 Ether" -geth --exec "eth.sendTransaction({from: eth.coinbase, to: '${address}', value: web3.toWei(10000,'ether')})" attach http://eth:8545 +geth --exec "eth.sendTransaction({from: eth.coinbase, to: '${address}', value: web3.toWei(10000,'ether')})" attach http://eth:8545 > /dev/null # unlock solana tester accounts address=$(yq -r '.additional_accounts.user_solana.evm_address' config.yml) echo "funding solana tester address ${address} with 10000 Ether" -geth --exec "eth.sendTransaction({from: eth.coinbase, to: '${address}', value: web3.toWei(10000,'ether')})" attach http://eth:8545 +geth --exec "eth.sendTransaction({from: eth.coinbase, to: '${address}', value: web3.toWei(10000,'ether')})" attach http://eth:8545 > /dev/null # unlock ethers tester accounts address=$(yq -r '.additional_accounts.user_ether.evm_address' config.yml) echo "funding ether tester address ${address} with 10000 Ether" -geth --exec "eth.sendTransaction({from: eth.coinbase, to: '${address}', value: web3.toWei(10000,'ether')})" attach http://eth:8545 +geth --exec "eth.sendTransaction({from: eth.coinbase, to: '${address}', value: web3.toWei(10000,'ether')})" attach http://eth:8545 > /dev/null # unlock miscellaneous tests accounts address=$(yq -r '.additional_accounts.user_misc.evm_address' config.yml) echo "funding misc tester address ${address} with 10000 Ether" -geth --exec "eth.sendTransaction({from: eth.coinbase, to: '${address}', value: web3.toWei(10000,'ether')})" attach http://eth:8545 +geth --exec "eth.sendTransaction({from: eth.coinbase, to: '${address}', value: web3.toWei(10000,'ether')})" attach http://eth:8545 > /dev/null # unlock admin erc20 tests accounts address=$(yq -r '.additional_accounts.user_admin.evm_address' config.yml) echo "funding admin tester address ${address} with 10000 Ether" -geth --exec "eth.sendTransaction({from: eth.coinbase, to: '${address}', value: web3.toWei(10000,'ether')})" attach http://eth:8545 +geth --exec "eth.sendTransaction({from: eth.coinbase, to: '${address}', value: web3.toWei(10000,'ether')})" attach http://eth:8545 > /dev/null # unlock migration tests accounts address=$(yq -r '.additional_accounts.user_migration.evm_address' config.yml) echo "funding migration tester address ${address} with 10000 Ether" -geth --exec "eth.sendTransaction({from: eth.coinbase, to: '${address}', value: web3.toWei(10000,'ether')})" attach http://eth:8545 +geth --exec "eth.sendTransaction({from: eth.coinbase, to: '${address}', value: web3.toWei(10000,'ether')})" attach http://eth:8545 > /dev/null # unlock local solana relayer accounts solana_url=$(yq -r '.rpcs.solana' config.yml) diff --git a/pkg/chains/chain.go b/pkg/chains/chain.go index 221da26a32..06ce4af289 100644 --- a/pkg/chains/chain.go +++ b/pkg/chains/chain.go @@ -147,6 +147,12 @@ func (chain Chain) IsEmpty() bool { return strings.TrimSpace(chain.String()) == "" } +// GetNetworkName returns the network name from the network ID +func GetNetworkName(network int32) (string, bool) { + name, found := Network_name[network] + return name, found +} + // GetChainFromChainID returns the chain from the chain ID // additionalChains is a list of additional chains to search from // in practice, it is used in the protocol to dynamically support new chains without doing an upgrade diff --git a/pkg/chains/chain_test.go b/pkg/chains/chain_test.go index 23bc6adf18..0bbfbc5263 100644 --- a/pkg/chains/chain_test.go +++ b/pkg/chains/chain_test.go @@ -396,11 +396,19 @@ func TestChain_IsEmpty(t *testing.T) { require.False(t, chains.ZetaChainMainnet.IsEmpty()) } +func TestGetNetworkName(t *testing.T) { + network := int32(chains.Network_solana) + name, found := chains.GetNetworkName(network) + nameExpected, foundExpected := chains.Network_name[network] + require.Equal(t, nameExpected, name) + require.Equal(t, foundExpected, found) +} + func TestGetChainFromChainID(t *testing.T) { chain, found := chains.GetChainFromChainID(chains.ZetaChainMainnet.ChainId, []chains.Chain{}) require.EqualValues(t, chains.ZetaChainMainnet, chain) require.True(t, found) - chain, found = chains.GetChainFromChainID(9999, []chains.Chain{}) + _, found = chains.GetChainFromChainID(9999, []chains.Chain{}) require.False(t, found) } diff --git a/pkg/crypto/aes256_gcm.go b/pkg/crypto/aes256_gcm.go new file mode 100644 index 0000000000..f615fa79ea --- /dev/null +++ b/pkg/crypto/aes256_gcm.go @@ -0,0 +1,118 @@ +package crypto + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + io "io" + + "github.com/pkg/errors" +) + +// EncryptAES256GCMBase64 encrypts the given string plaintext using AES-256-GCM with the given key and returns the base64-encoded ciphertext. +func EncryptAES256GCMBase64(plaintext string, encryptKey string) (string, error) { + // validate the input + if plaintext == "" { + return "", errors.New("plaintext must not be empty") + } + if encryptKey == "" { + return "", errors.New("encrypt key must not be empty") + } + + // encrypt the plaintext + ciphertext, err := EncryptAES256GCM([]byte(plaintext), encryptKey) + if err != nil { + return "", errors.Wrap(err, "failed to encrypt string plaintext") + } + return base64.StdEncoding.EncodeToString(ciphertext), nil +} + +// DecryptAES256GCMBase64 decrypts the given base64-encoded ciphertext using AES-256-GCM with the given key. +func DecryptAES256GCMBase64(ciphertextBase64 string, decryptKey string) (string, error) { + // validate the input + if ciphertextBase64 == "" { + return "", errors.New("ciphertext must not be empty") + } + if decryptKey == "" { + return "", errors.New("decrypt key must not be empty") + } + + // decode the base64-encoded ciphertext + ciphertext, err := base64.StdEncoding.DecodeString(ciphertextBase64) + if err != nil { + return "", errors.Wrap(err, "failed to decode base64 ciphertext") + } + + // decrypt the ciphertext + plaintext, err := DecryptAES256GCM(ciphertext, decryptKey) + if err != nil { + return "", errors.Wrap(err, "failed to decrypt ciphertext") + } + return string(plaintext), nil +} + +// EncryptAES256GCM encrypts the given plaintext using AES-256-GCM with the given key. +func EncryptAES256GCM(plaintext []byte, encryptKey string) ([]byte, error) { + block, err := aes.NewCipher(getAESKey(encryptKey)) + if err != nil { + return nil, err + } + + // create GCM mode + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + // generate random nonce + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + + // encrypt the plaintext + ciphertext := gcm.Seal(nonce, nonce, plaintext, nil) + + return ciphertext, nil +} + +// DecryptAES256GCM decrypts the given ciphertext using AES-256-GCM with the given key. +func DecryptAES256GCM(ciphertext []byte, encryptKey string) ([]byte, error) { + block, err := aes.NewCipher(getAESKey(encryptKey)) + if err != nil { + return nil, err + } + + // create GCM mode + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + // get the nonce size + nonceSize := gcm.NonceSize() + if len(ciphertext) < nonceSize { + return nil, err + } + + // extract the nonce from the ciphertext + nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] + + // decrypt the ciphertext + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, err + } + + return plaintext, nil +} + +// getAESKey uses SHA-256 to create a 32-byte key AES encryption. +func getAESKey(key string) []byte { + h := sha256.New() + h.Write([]byte(key)) + + return h.Sum(nil) +} diff --git a/pkg/crypto/aes256_gcm_test.go b/pkg/crypto/aes256_gcm_test.go new file mode 100644 index 0000000000..ff83698ddb --- /dev/null +++ b/pkg/crypto/aes256_gcm_test.go @@ -0,0 +1,187 @@ +package crypto_test + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/pkg/crypto" +) + +func Test_EncryptDecryptAES256GCM(t *testing.T) { + tests := []struct { + name string + plaintext string + encryptKey string + decryptKey string + modifyFunc func([]byte) []byte + fail bool + }{ + { + name: "Successful encryption and decryption", + plaintext: "Hello, World!", + encryptKey: "my_password", + decryptKey: "my_password", + fail: false, + }, + { + name: "Decryption with incorrect key should fail", + plaintext: "Hello, World!", + encryptKey: "my_password", + decryptKey: "my_password2", + fail: true, + }, + { + name: "Decryption with corrupted ciphertext should fail", + plaintext: "Hello, World!", + encryptKey: "my_password", + decryptKey: "my_password", + modifyFunc: func(ciphertext []byte) []byte { + // flip the last bit of the ciphertext + ciphertext[len(ciphertext)-1] ^= 0x01 + return ciphertext + }, + fail: true, + }, + { + name: "Decryption with incorrect nonce should fail", + plaintext: "Hello, World!", + encryptKey: "my_password", + decryptKey: "my_password", + modifyFunc: func(ciphertext []byte) []byte { + // flip the first bit of the nonce + ciphertext[0] ^= 0x01 + return ciphertext + }, + fail: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + encrypted, err := crypto.EncryptAES256GCM([]byte(tt.plaintext), tt.encryptKey) + require.NoError(t, err) + + // modify the encrypted data if needed + if tt.modifyFunc != nil { + encrypted = tt.modifyFunc(encrypted) + } + + // decrypt the data + decrypted, err := crypto.DecryptAES256GCM(encrypted, tt.decryptKey) + if tt.fail { + require.Error(t, err) + return + } + + require.True(t, bytes.Equal(decrypted, []byte(tt.plaintext)), "decrypted plaintext does not match") + }) + } +} + +func Test_EncryptAES256GCMBase64(t *testing.T) { + tests := []struct { + name string + plaintext string + encryptKey string + decryptKey string + errorMessage string + }{ + { + name: "Successful encryption and decryption", + plaintext: "Hello, World!", + encryptKey: "my_password", + decryptKey: "my_password", + }, + { + name: "Encryption with empty plaintext should fail", + plaintext: "", + errorMessage: "plaintext must not be empty", + }, + { + name: "Encryption with empty encrypt key should fail", + plaintext: "Hello, World!", + encryptKey: "", + errorMessage: "encrypt key must not be empty", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // encrypt the data + ciphertextBase64, err := crypto.EncryptAES256GCMBase64(tt.plaintext, tt.encryptKey) + if tt.errorMessage != "" { + require.ErrorContains(t, err, tt.errorMessage) + return + } + + // decrypt the data + decrypted, err := crypto.DecryptAES256GCMBase64(ciphertextBase64, tt.decryptKey) + require.NoError(t, err) + + require.Equal(t, tt.plaintext, decrypted) + }) + } +} + +func Test_DecryptAES256GCMBase64(t *testing.T) { + tests := []struct { + name string + ciphertextBase64 string + plaintext string + decryptKey string + modifyFunc func(string) string + errorMessage string + }{ + { + name: "Successful decryption", + ciphertextBase64: "CXLWgHdVeZQwVOZZyHeZ5n5VB+eVSLaWFF0v0QOm9DyB7XSiHDwhNwQ=", + plaintext: "Hello, World!", + decryptKey: "my_password", + }, + { + name: "Decryption with empty ciphertext should fail", + ciphertextBase64: "", + decryptKey: "my_password", + errorMessage: "ciphertext must not be empty", + }, + { + name: "Decryption with empty decrypt key should fail", + ciphertextBase64: "CXLWgHdVeZQwVOZZyHeZ5n5VB+eVSLaWFF0v0QOm9DyB7XSiHDwhNwQ=", + decryptKey: "", + errorMessage: "decrypt key must not be empty", + }, + { + name: "Decryption with invalid base64 ciphertext should fail", + ciphertextBase64: "CXLWgHdVeZQwVOZZyHeZ5n5VB*eVSLaWFF0v0QOm9DyB7XSiHDwhNwQ=", // use '*' instead of '+' + decryptKey: "my_password", + errorMessage: "failed to decode base64 ciphertext", + }, + { + name: "Decryption with incorrect decrypt key should fail", + ciphertextBase64: "CXLWgHdVeZQwVOZZyHeZ5n5VB+eVSLaWFF0v0QOm9DyB7XSiHDwhNwQ=", + decryptKey: "my_password2", + errorMessage: "failed to decrypt ciphertext", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ciphertextBase64 := tt.ciphertextBase64 + + // modify the encrypted data if needed + if tt.modifyFunc != nil { + ciphertextBase64 = tt.modifyFunc(ciphertextBase64) + } + + // decrypt the data + decrypted, err := crypto.DecryptAES256GCMBase64(ciphertextBase64, tt.decryptKey) + if tt.errorMessage != "" { + require.ErrorContains(t, err, tt.errorMessage) + return + } + + require.Equal(t, tt.plaintext, decrypted) + }) + } +} diff --git a/pkg/os/path.go b/pkg/os/path.go new file mode 100644 index 0000000000..abf8368c64 --- /dev/null +++ b/pkg/os/path.go @@ -0,0 +1,33 @@ +package os + +import ( + "os" + "os/user" + "path/filepath" + "strings" +) + +// ExpandHomeDir expands a leading tilde in the path to the home directory of the current user. +// ~someuser/tmp will not be expanded. +func ExpandHomeDir(p string) (string, error) { + if p == "~" || + strings.HasPrefix(p, "~/") || + strings.HasPrefix(p, "~\\") { + usr, err := user.Current() + if err != nil { + return p, err + } + + p = filepath.Join(usr.HomeDir, p[1:]) + } + return filepath.Clean(p), nil +} + +// FileExists checks if a file exists. +func FileExists(filePath string) bool { + _, err := os.Stat(filePath) + if os.IsNotExist(err) { + return false + } + return err == nil +} diff --git a/pkg/os/path_test.go b/pkg/os/path_test.go new file mode 100644 index 0000000000..d02c55ef4e --- /dev/null +++ b/pkg/os/path_test.go @@ -0,0 +1,83 @@ +package os_test + +import ( + "os" + "os/user" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + zetaos "github.com/zeta-chain/zetacore/pkg/os" + "github.com/zeta-chain/zetacore/testutil/sample" +) + +func TestResolveHome(t *testing.T) { + usr, err := user.Current() + require.NoError(t, err) + + testCases := []struct { + name string + pathIn string + expected string + fail bool + }{ + { + name: `should resolve home with leading "~/"`, + pathIn: "~/tmp/file.json", + expected: filepath.Clean(filepath.Join(usr.HomeDir, "tmp/file.json")), + }, + { + name: "should resolve '~'", + pathIn: `~`, + expected: filepath.Clean(filepath.Join(usr.HomeDir, "")), + }, + { + name: "should not resolve '~someuser/tmp'", + pathIn: `~someuser/tmp`, + expected: `~someuser/tmp`, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + pathOut, err := zetaos.ExpandHomeDir(tc.pathIn) + require.NoError(t, err) + require.Equal(t, tc.expected, pathOut) + }) + } +} + +func TestFileExists(t *testing.T) { + path := sample.CreateTempDir(t) + + // create a test file + existingFile := filepath.Join(path, "test.txt") + _, err := os.Create(existingFile) + require.NoError(t, err) + + testCases := []struct { + name string + file string + expected bool + }{ + { + name: "should return true for existing file", + file: existingFile, + expected: true, + }, + { + name: "should return false for non-existing file", + file: filepath.Join(path, "non-existing.txt"), + expected: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + exists := zetaos.FileExists(tc.file) + require.Equal(t, tc.expected, exists) + }) + } +} diff --git a/rpc/namespaces/ethereum/debug/api.go b/rpc/namespaces/ethereum/debug/api.go index 3496da92d8..1b07828b80 100644 --- a/rpc/namespaces/ethereum/debug/api.go +++ b/rpc/namespaces/ethereum/debug/api.go @@ -37,6 +37,7 @@ import ( evmtypes "github.com/evmos/ethermint/x/evm/types" stderrors "github.com/pkg/errors" + zetaos "github.com/zeta-chain/zetacore/pkg/os" "github.com/zeta-chain/zetacore/rpc/backend" rpctypes "github.com/zeta-chain/zetacore/rpc/types" ) @@ -199,7 +200,7 @@ func (a *API) StartCPUProfile(file string) error { a.logger.Debug("CPU profiling already in progress") return errors.New("CPU profiling already in progress") default: - fp, err := ExpandHome(file) + fp, err := zetaos.ExpandHomeDir(file) if err != nil { a.logger.Debug("failed to get filepath for the CPU profile file", "error", err.Error()) return err diff --git a/rpc/namespaces/ethereum/debug/trace.go b/rpc/namespaces/ethereum/debug/trace.go index 28ba1c8043..ae35b16fc2 100644 --- a/rpc/namespaces/ethereum/debug/trace.go +++ b/rpc/namespaces/ethereum/debug/trace.go @@ -25,6 +25,8 @@ import ( "runtime/trace" stderrors "github.com/pkg/errors" + + zetaos "github.com/zeta-chain/zetacore/pkg/os" ) // StartGoTrace turns on tracing, writing to the given file. @@ -37,7 +39,7 @@ func (a *API) StartGoTrace(file string) error { a.logger.Debug("trace already in progress") return errors.New("trace already in progress") } - fp, err := ExpandHome(file) + fp, err := zetaos.ExpandHomeDir(file) if err != nil { a.logger.Debug("failed to get filepath for the CPU profile file", "error", err.Error()) return err diff --git a/rpc/namespaces/ethereum/debug/utils.go b/rpc/namespaces/ethereum/debug/utils.go index 277c37df56..ae3f0c5ba5 100644 --- a/rpc/namespaces/ethereum/debug/utils.go +++ b/rpc/namespaces/ethereum/debug/utils.go @@ -17,13 +17,12 @@ package debug import ( "os" - "os/user" - "path/filepath" "runtime/pprof" - "strings" "github.com/cometbft/cometbft/libs/log" "github.com/cosmos/cosmos-sdk/server" + + zetaos "github.com/zeta-chain/zetacore/pkg/os" ) // isCPUProfileConfigurationActivated checks if cpuprofile was configured via flag @@ -33,25 +32,11 @@ func isCPUProfileConfigurationActivated(ctx *server.Context) bool { return ctx.Viper.GetString("cpu-profile") != "" } -// ExpandHome expands home directory in file paths. -// ~someuser/tmp will not be expanded. -func ExpandHome(p string) (string, error) { - if strings.HasPrefix(p, "~/") || strings.HasPrefix(p, "~\\") { - usr, err := user.Current() - if err != nil { - return p, err - } - home := usr.HomeDir - p = home + p[1:] - } - return filepath.Clean(p), nil -} - // writeProfile writes the data to a file func writeProfile(name, file string, log log.Logger) error { p := pprof.Lookup(name) log.Info("Writing profile records", "count", p.Count(), "type", name, "dump", file) - fp, err := ExpandHome(file) + fp, err := zetaos.ExpandHomeDir(file) if err != nil { return err } diff --git a/server/start.go b/server/start.go index 64bb1db4e6..79abab78bf 100644 --- a/server/start.go +++ b/server/start.go @@ -58,7 +58,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - ethdebug "github.com/zeta-chain/zetacore/rpc/namespaces/ethereum/debug" + zetaos "github.com/zeta-chain/zetacore/pkg/os" "github.com/zeta-chain/zetacore/server/config" srvflags "github.com/zeta-chain/zetacore/server/flags" ) @@ -337,7 +337,7 @@ func startInProcess(ctx *server.Context, clientCtx client.Context, opts StartOpt logger := ctx.Logger if cpuProfile := ctx.Viper.GetString(srvflags.CPUProfile); cpuProfile != "" { - fp, err := ethdebug.ExpandHome(cpuProfile) + fp, err := zetaos.ExpandHomeDir(cpuProfile) if err != nil { ctx.Logger.Debug("failed to get filepath for the CPU profile file", "error", err.Error()) return err diff --git a/zetaclient/keys/relayer_keys.go b/zetaclient/keys/relayer_keys.go index cf7b20f724..97dd7ca8b2 100644 --- a/zetaclient/keys/relayer_keys.go +++ b/zetaclient/keys/relayer_keys.go @@ -2,14 +2,15 @@ package keys import ( "encoding/json" - "fmt" "io" "os" "path" + "github.com/gagliardetto/solana-go" "github.com/pkg/errors" "github.com/zeta-chain/zetacore/pkg/chains" + zetaos "github.com/zeta-chain/zetacore/pkg/os" ) const ( @@ -22,6 +23,26 @@ type RelayerKey struct { PrivateKey string `json:"private_key"` } +// ResolveAddress returns the network name and address of the relayer key +func (rk RelayerKey) ResolveAddress(network int32) (string, string, error) { + // get network name + networkName, found := chains.GetNetworkName(network) + if !found { + return "", "", errors.Errorf("network name not found for network %d", network) + } + + switch chains.Network(network) { + case chains.Network_solana: + privKey, err := solana.PrivateKeyFromBase58(rk.PrivateKey) + if err != nil { + return "", "", errors.Wrap(err, "unable to construct solana private key") + } + return networkName, privKey.PublicKey().String(), nil + default: + return "", "", errors.Errorf("cannot derive relayer address for unsupported network %d", network) + } +} + // LoadRelayerKey loads a relayer key from given path and chain func LoadRelayerKey(keyPath string, chain chains.Chain) (RelayerKey, error) { // determine relayer key file name based on chain @@ -44,18 +65,23 @@ func LoadRelayerKey(keyPath string, chain chains.Chain) (RelayerKey, error) { // ReadRelayerKeyFromFile reads the relayer key file and returns the key func ReadRelayerKeyFromFile(fileName string) (RelayerKey, error) { - fileName = "/root/.zetacored/relayer-keys/solana.json" - fmt.Println("Reading relayer key from file: ", fileName) - file, err := os.Open(fileName) + // expand home directory in the file path if it exists + fileNameFull, err := zetaos.ExpandHomeDir(fileName) if err != nil { - return RelayerKey{}, errors.Wrapf(err, "unable to open relayer key file: %s", fileName) + return RelayerKey{}, errors.Wrapf(err, "ExpandHome failed for file: %s", fileName) + } + + // open the file + file, err := os.Open(fileNameFull) + if err != nil { + return RelayerKey{}, errors.Wrapf(err, "unable to open relayer key file: %s", fileNameFull) } defer file.Close() // read the file contents fileData, err := io.ReadAll(file) if err != nil { - return RelayerKey{}, errors.Wrapf(err, "unable to read relayer key data: %s", fileName) + return RelayerKey{}, errors.Wrapf(err, "unable to read relayer key data: %s", fileNameFull) } // unmarshal the JSON data into the struct @@ -67,3 +93,20 @@ func ReadRelayerKeyFromFile(fileName string) (RelayerKey, error) { return key, nil } + +// GetRelayerKeyFileByNetwork returns the relayer key file name based on network +func GetRelayerKeyFileByNetwork(network int32) (string, error) { + // get network name + networkName, found := chains.GetNetworkName(network) + if !found { + return "", errors.Errorf("network name not found for network %d", network) + } + + // return file name for supported networks only + switch chains.Network(network) { + case chains.Network_solana: + return networkName + ".json", nil + default: + return "", errors.Errorf("network %d does not support relayer key", network) + } +}