From 4342402eb9fa1526d93bd86b979d9f412319f763 Mon Sep 17 00:00:00 2001 From: h3rt <94856309+SecretSaturn@users.noreply.github.com> Date: Sat, 13 Jul 2024 16:57:02 +0200 Subject: [PATCH] CPI implementation WIP & RNG contract fix --- .../programs/solana-gateway/src/errors.rs | 10 ++- .../programs/solana-gateway/src/lib.rs | 80 +++++++++++++------ TNLS-Relayers/relayer.py | 2 - TNLS-Relayers/sol_interface.py | 24 +++--- TNLS-Samples/RNG/src/contract.rs | 5 +- 5 files changed, 77 insertions(+), 44 deletions(-) diff --git a/TNLS-Gateways/solana-gateway/programs/solana-gateway/src/errors.rs b/TNLS-Gateways/solana-gateway/programs/solana-gateway/src/errors.rs index c575468..0017073 100644 --- a/TNLS-Gateways/solana-gateway/programs/solana-gateway/src/errors.rs +++ b/TNLS-Gateways/solana-gateway/programs/solana-gateway/src/errors.rs @@ -24,8 +24,14 @@ pub enum TaskError { InvalidIndex, #[msg("Task Id already pruned")] TaskIdAlreadyPruned, - #[msg("CallbackAddressesInvalid")] - CallbackAddressesInvalid + #[msg("Callback Addresses are invalid")] + InvalidCallbackAddresses, + #[msg("Borsh Data Serialization failed")] + BorshDataSerializationFailed, + #[msg("Invalid Callback Selector")] + InvalidCallbackSelector, + #[msg("MissingRequiredSignature")] + MissingRequiredSignature } #[error_code] diff --git a/TNLS-Gateways/solana-gateway/programs/solana-gateway/src/lib.rs b/TNLS-Gateways/solana-gateway/programs/solana-gateway/src/lib.rs index faa0348..776214a 100644 --- a/TNLS-Gateways/solana-gateway/programs/solana-gateway/src/lib.rs +++ b/TNLS-Gateways/solana-gateway/programs/solana-gateway/src/lib.rs @@ -9,8 +9,8 @@ use anchor_lang::{ system_program::{transfer, Transfer}, }; use base64::{engine::general_purpose::STANDARD, Engine}; -use sha3::{Digest, Keccak256}; use hex::decode; +use sha3::{Digest, Keccak256}; pub mod errors; use crate::errors::{GatewayError, TaskError}; @@ -22,14 +22,14 @@ declare_id!("5LWZAN7ZFE3Rmg4MdjqNTRkSbMxthyG8ouSa3cfn3R6V"); const TASK_DESTINATION_NETWORK: &str = "pulsar-3"; const CHAIN_ID: &str = "SolanaDevNet"; const SECRET_GATEWAY_PUBKEY: &str = - "0x047a267c6be1157040bd19b893a1fd96266e683da46f00b4ab3a959662aa31c191f2a8a9b17636a0a3e53072b6f102b80452a66ccd7e344fdc8a393124da979bd9"; + "0x04f0c3e600c7f7b3c483debe8f98a839c2d93230d8f857b3c298dc8763c208afcd62dcb34c9306302bf790d8c669674a57defa44c6a95b183d94f2e645526ffe5e"; const SEED: &[u8] = b"gateway_state"; #[program] mod solana_gateway { use super::*; - + pub fn initialize(ctx: Context, bump: u8) -> Result<()> { let gateway_state = &mut ctx.accounts.gateway_state; @@ -120,7 +120,8 @@ mod solana_gateway { transfer(cpi_context, estimated_price)?; //Hash the payload - let generated_payload_hash = solana_program::keccak::hash(&execution_info.payload).to_bytes(); + let generated_payload_hash = + solana_program::keccak::hash(&execution_info.payload).to_bytes(); // Persist the task let task = Task { @@ -210,7 +211,7 @@ mod solana_gateway { // Check if the task is already completed require!(!task.completed, TaskError::TaskAlreadyCompleted); - // Concatenate packet data elements, + // Concatenate packet data elements, // use saved in contract payload_hash to verify that the payload hash wasn't manipulated let data = [ source_network.as_bytes(), @@ -269,10 +270,18 @@ mod solana_gateway { let callback_data = CallbackData { task_id: task_id, - result: post_execution_info.result, + result: post_execution_info.result.clone(), }; - let borsh_data = callback_data.try_to_vec().unwrap(); + let borsh_data = callback_data + .try_to_vec() + .map_err(|_| TaskError::BorshDataSerializationFailed)?; + + // Verify that the recovered public key matches the expected public key + require!( + post_execution_info.callback_selector.len() == 40, + TaskError::InvalidCallbackSelector + ); // Extract and concatenate the program ID and function identifier let program_id_bytes = &post_execution_info.callback_selector[0..32]; @@ -284,34 +293,57 @@ mod solana_gateway { callback_data.extend_from_slice(&borsh_data); // Concatenate all addresses that will be accessed - let callback_address_bytes = &post_execution_info.callback_address; + let callback_address_bytes = post_execution_info.callback_address.clone(); require!( callback_address_bytes.len() % 32 == 0, - TaskError::CallbackAddressesInvalid + TaskError::InvalidCallbackAddresses ); - let callback_addresses: Vec = callback_address_bytes - .chunks(32) // Assuming each address is 32 bytes - .map(|address| { - AccountMeta::new(Pubkey::new(address), false) - }) + for chunk in callback_address_bytes.chunks(32) { + let pubkey = Pubkey::try_from(chunk).expect("Invalid callback pubkey"); + if ctx.remaining_accounts.iter().find(|account| account.key == &pubkey).is_none() { + return Err(TaskError::MissingRequiredSignature.into()); + } + } + + // Map callback_address_bytes to AccountInfo + let callback_addresses: Vec = callback_address_bytes + .chunks(32) + .map(|address| { + let pubkey = Pubkey::try_from(address).expect("Invalid callback pubkey"); + ctx.remaining_accounts + .iter() + .find(|account| account.key == &pubkey) + .expect("Callback account not found") + .clone() + }) + .collect(); + + let system_program = ctx.accounts.system_program.to_account_info(); + + // Collect the callback addresses into a vector + let mut callback_account_metas: Vec = callback_addresses.iter() + .map(|account| AccountMeta::new(*account.key, account.is_signer)) .collect(); + // Add the system_program account to the vector + callback_account_metas.push(AccountMeta::new_readonly(*system_program.key, false)); + // Execute the callback with signed seeds let callback_result = invoke_signed( &Instruction { - program_id: Pubkey::new(program_id_bytes), - accounts: callback_addresses, + program_id: Pubkey::try_from(program_id_bytes).expect("Invalid Pubkey"), + accounts: callback_account_metas, data: callback_data, }, - &[], + &callback_addresses, &[&[SEED.as_ref(), &[bump]]], ); let task_completed = TaskCompleted { task_id, - callback_successful: true, //callback_result.is_ok() + callback_successful: callback_result.is_ok(), }; msg!(&format!( @@ -322,18 +354,16 @@ mod solana_gateway { Ok(()) } - pub fn callback_test ( - ctx: Context, - task_id: u64, - result: String, - ) -> Result<()> { - + pub fn callback_test(ctx: Context, task_id: u64, result: String) -> Result<()> { + msg!("Callback invoked with task_id: {} and result: {}", task_id, result); + Ok(()) } } + #[derive(Accounts)] pub struct CallbackTest<'info> { #[account(mut)] - pub user: Signer<'info>, + pub secretpath_gateway: Signer<'info>, pub system_program: Program<'info, System>, } diff --git a/TNLS-Relayers/relayer.py b/TNLS-Relayers/relayer.py index a26519c..c575fc3 100644 --- a/TNLS-Relayers/relayer.py +++ b/TNLS-Relayers/relayer.py @@ -79,8 +79,6 @@ def process_chain(name): prev_height = curr_height - 1 def fetch_transactions(block_num): - block_num = 309524872 - sleep(0.2) transactions = chain_interface.get_transactions(contract_interface, height=block_num) tasks_tmp = [] for transaction in transactions: diff --git a/TNLS-Relayers/sol_interface.py b/TNLS-Relayers/sol_interface.py index 2a44378..dd194a6 100644 --- a/TNLS-Relayers/sol_interface.py +++ b/TNLS-Relayers/sol_interface.py @@ -108,13 +108,16 @@ def get_transactions(self, contract_interface, height): """ Get transactions for a given address. """ + jump = 10 + if height % jump != 0: + return [] filtered_transactions = [] try: - response = self.provider.get_signatures_for_address(account=contract_interface.address, limit=1, + response = self.provider.get_signatures_for_address(account=contract_interface.address, limit=10, commitment=Confirmed) if response.value: # Filter transactions by slot height - filtered_transactions = [tx.signature for tx in response.value if tx.slot == height] + filtered_transactions = [tx.signature for tx in response.value if tx.slot >= height-jump] else: return [] except Exception as e: @@ -182,33 +185,32 @@ def call_function(self, function_name, *args): Create a transaction with the given instructions and signers. """ # Create context - keys: list[AccountMeta] = [ + accounts: list[AccountMeta] = [ AccountMeta(pubkey=self.address, is_signer=False, is_writable=True), AccountMeta(pubkey=self.interface.address, is_signer=True, is_writable=True), AccountMeta(pubkey=SYS_PROGRAM_ID, is_signer=False, is_writable=False), ] - callback_address_bytes = bytes.fromhex(args[2][2][2:]) + if len(args) == 1: + args = json.loads(args[0]) - # Ensure the byte data length is a multiple of 32 + # Ensure the callback_address_bytes length is a multiple of 32 + callback_address_bytes = bytes.fromhex(args[2][2][2:]) if len(callback_address_bytes) % 32 != 0: raise ValueError("callback_address_bytes length is not a multiple of 32") - # Step 1-3: Process the addresses callback_addresses: List[AccountMeta] = [ AccountMeta(pubkey=Pubkey(callback_address_bytes[i:i + 32]), is_signer=False, is_writable=True) for i in range(0, len(callback_address_bytes), 32) ] + print(callback_addresses) if callback_addresses is not None: - keys += callback_addresses + accounts += callback_addresses # The Identifier of the post execution function identifier = bytes([52, 46, 67, 194, 153, 197, 69, 168]) - if len(args) == 1: - args = json.loads(args[0]) - print(args) encoded_args = PostExecution.layout.build( { "task_id": args[0], @@ -225,7 +227,7 @@ def call_function(self, function_name, *args): } ) data = identifier + encoded_args - tx = Instruction(self.program_id, data, keys) + tx = Instruction(program_id=self.program_id, data=data, accounts=accounts) submitted_txn = self.interface.sign_and_send_transaction(tx) return submitted_txn diff --git a/TNLS-Samples/RNG/src/contract.rs b/TNLS-Samples/RNG/src/contract.rs index a858cea..476d537 100644 --- a/TNLS-Samples/RNG/src/contract.rs +++ b/TNLS-Samples/RNG/src/contract.rs @@ -83,10 +83,7 @@ fn try_handle( let input_hash_safe = sha_256(&[msg.input_values.as_bytes(), msg.task.task_id.as_bytes(),&[0u8]].concat()); let input_hash_unsafe = sha_256(&[msg.input_values.as_bytes(), msg.task.task_id.as_bytes(),&[1u8]].concat()); - if msg.input_hash.as_slice() != input_hash_safe.as_slice() { - if msg.input_hash.as_slice() == input_hash_unsafe.as_slice() { - return Err(StdError::generic_err("Payload was marked as unsafe, not executing")); - } + if msg.input_hash.as_slice() != input_hash_safe.as_slice() || msg.input_hash.as_slice() != input_hash_unsafe.as_slice(){ return Err(StdError::generic_err("Safe input hash does not match provided input hash")); } // determine which function to call based on the included handle