Skip to content

Commit

Permalink
add mint (SPL coin address) to the withdraw instr, and the decimals f…
Browse files Browse the repository at this point in the history
…or extra validation
  • Loading branch information
brewmaster012 committed Oct 5, 2024
1 parent 812ec18 commit 4720534
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 12 deletions.
30 changes: 24 additions & 6 deletions programs/protocol-contracts-solana/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use anchor_lang::prelude::*;
use anchor_lang::system_program;
use anchor_spl::token::{transfer, Token, TokenAccount};
use anchor_spl::token::{transfer, Token, TokenAccount, Mint};
use solana_program::keccak::hash;
use solana_program::secp256k1_recover::secp256k1_recover;
use std::mem::size_of;
use anchor_spl::associated_token::get_associated_token_address;

#[error_code]
pub enum Errors {
Expand All @@ -25,12 +26,17 @@ pub enum Errors {
MemoLengthTooShort,
#[msg("DepositPaused")]
DepositPaused,
#[msg("SPLAtaAndMintAddressMismatch")]
SPLAtaAndMintAddressMismatch,

}

declare_id!("ZETAjseVjuFsxdRxo6MmTCvqFwb3ZHUx56Co3vCmGis");

#[program]
pub mod gateway {
use anchor_spl::token::transfer_checked;

use super::*;

pub fn initialize(
Expand Down Expand Up @@ -235,6 +241,7 @@ pub mod gateway {
// concatenated_buffer vec.
pub fn withdraw_spl_token(
ctx: Context<WithdrawSPLToken>,
decimals: u8,
amount: u64,
signature: [u8; 64],
recovery_id: u8,
Expand All @@ -253,7 +260,7 @@ pub mod gateway {
concatenated_buffer.extend_from_slice(&pda.chain_id.to_be_bytes());
concatenated_buffer.extend_from_slice(&nonce.to_be_bytes());
concatenated_buffer.extend_from_slice(&amount.to_be_bytes());
concatenated_buffer.extend_from_slice(&ctx.accounts.from.key().to_bytes());
concatenated_buffer.extend_from_slice(&ctx.accounts.mint_account.key().to_bytes());
concatenated_buffer.extend_from_slice(&ctx.accounts.to.key().to_bytes());
require!(
message_hash == hash(&concatenated_buffer[..]).to_bytes(),
Expand All @@ -267,13 +274,21 @@ pub mod gateway {
return err!(Errors::TSSAuthenticationFailed);
}

let pda_ata = get_associated_token_address(&pda.key(), &ctx.accounts.mint_account.key());
require!(
pda_ata == ctx.accounts.pda_ata.to_account_info().key(),
Errors::SPLAtaAndMintAddressMismatch
);

let token = &ctx.accounts.token_program;
let signer_seeds: &[&[&[u8]]] = &[&[b"meta", &[ctx.bumps.pda]]];


let xfer_ctx = CpiContext::new_with_signer(
token.to_account_info(),
anchor_spl::token::Transfer {
from: ctx.accounts.from.to_account_info(),
anchor_spl::token::TransferChecked {
from: ctx.accounts.pda_ata.to_account_info(),
mint: ctx.accounts.mint_account.to_account_info(),
to: ctx.accounts.to.to_account_info(),
authority: pda.to_account_info(),
},
Expand All @@ -282,7 +297,7 @@ pub mod gateway {

pda.nonce += 1;

transfer(xfer_ctx, amount)?;
transfer_checked(xfer_ctx, amount, decimals)?;
msg!("withdraw spl token successfully");

Ok(())
Expand Down Expand Up @@ -364,7 +379,10 @@ pub struct WithdrawSPLToken<'info> {
pub pda: Account<'info, Pda>,

#[account(mut)]
pub from: Account<'info, TokenAccount>,
pub pda_ata: Account<'info, TokenAccount>, // associated token address of PDA

#[account()]
pub mint_account: Account<'info, Mint>,

#[account(mut)]
pub to: Account<'info, TokenAccount>,
Expand Down
16 changes: 10 additions & 6 deletions tests/protocol-contracts-solana.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ const ec = new EC('secp256k1');
// read private key from hex dump
const keyPair = ec.keyFromPrivate('5b81cdf52ba0766983acf8dd0072904733d92afe4dd3499e83e879b43ccb73e8');

const usdcDecimals = 6;

describe("some tests", () => {
// Configure the client to use the local cluster.
anchor.setProvider(anchor.AnchorProvider.env());
Expand Down Expand Up @@ -86,7 +88,7 @@ describe("some tests", () => {
}),
spl.createInitializeMintInstruction(
mint.publicKey,
6,
usdcDecimals,
wallet.publicKey,
null,
)
Expand Down Expand Up @@ -208,7 +210,7 @@ describe("some tests", () => {
chain_id_bn.toArrayLike(Buffer, 'be', 8),
nonce.toArrayLike(Buffer, 'be', 8),
amount.toArrayLike(Buffer, 'be', 8),
pda_ata.address.toBuffer(),
mint.publicKey.toBuffer(),
wallet_ata.toBuffer(),
]);
const message_hash = keccak256(buffer);
Expand All @@ -219,9 +221,10 @@ describe("some tests", () => {
s.toArrayLike(Buffer, 'be', 32),
]);

await gatewayProgram.methods.withdrawSplToken(amount, Array.from(signatureBuffer), Number(recoveryParam), Array.from(message_hash), nonce)
await gatewayProgram.methods.withdrawSplToken(usdcDecimals,amount, Array.from(signatureBuffer), Number(recoveryParam), Array.from(message_hash), nonce)
.accounts({
from: pda_ata.address,
pdaAta: pda_ata.address,
mintAccount: mint.publicKey,
to: wallet_ata,
}).rpc();

Expand All @@ -230,9 +233,10 @@ describe("some tests", () => {


try {
(await gatewayProgram.methods.withdrawSplToken(new anchor.BN(500_000), Array.from(signatureBuffer), Number(recoveryParam), Array.from(message_hash), nonce)
(await gatewayProgram.methods.withdrawSplToken(usdcDecimals,new anchor.BN(500_000), Array.from(signatureBuffer), Number(recoveryParam), Array.from(message_hash), nonce)
.accounts({
from: pda_ata.address,
pdaAta: pda_ata.address,
mintAccount: mint.publicKey,
to: wallet_ata,
}).rpc());
throw new Error("Expected error not thrown"); // This line will make the test fail if no error is thrown
Expand Down

0 comments on commit 4720534

Please sign in to comment.