Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sequencer)!: fix TOCTOU issues by merging check and execution #1332

Merged
merged 13 commits into from
Aug 20, 2024
Merged
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion crates/astria-sequencer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ borsh = { version = "1", features = ["derive"] }
cnidarium = { git = "https://github.com/penumbra-zone/penumbra.git", tag = "v0.78.0", features = [
"metrics",
] }
cnidarium-component = { git = "https://github.com/penumbra-zone/penumbra.git", tag = "v0.78.0" }
ibc-proto = { version = "0.41.0", features = ["server"] }
matchit = "0.7.2"
priority-queue = "2.0.2"
Expand Down
195 changes: 104 additions & 91 deletions crates/astria-sequencer/src/accounts/action.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,125 @@ use anyhow::{
Result,
};
use astria_core::{
primitive::v1::Address,
primitive::v1::ADDRESS_LEN,
protocol::transaction::v1alpha1::action::TransferAction,
Protobuf,
};
use tracing::instrument;
use cnidarium::{
StateRead,
StateWrite,
};

use super::AddressBytes;
use crate::{
accounts::{
self,
StateReadExt as _,
StateWriteExt as _,
},
address::StateReadExt as _,
app::ActionHandler,
assets::{
StateReadExt as _,
StateWriteExt as _,
},
address,
assets,
bridge::StateReadExt as _,
transaction::action_handler::ActionHandler,
transaction::StateReadExt as _,
};

pub(crate) async fn transfer_check_stateful<S>(
#[async_trait::async_trait]
impl ActionHandler for TransferAction {
type CheckStatelessContext = ();

async fn check_stateless(&self, _context: Self::CheckStatelessContext) -> Result<()> {
Ok(())
}

async fn check_and_execute<S: StateWrite>(&self, state: S) -> Result<()> {
let from = state
.get_current_source()
.expect("transaction source must be present in state when executing an action")
.address_bytes();

ensure!(
state
.get_bridge_account_rollup_id(from)
.await
.context("failed to get bridge account rollup id")?
.is_none(),
"cannot transfer out of bridge account; BridgeUnlock must be used",
);

check_transfer(self, from, &state).await?;
execute_transfer(self, from, state).await?;

Ok(())
}
}

pub(crate) async fn execute_transfer<S: StateWrite>(
action: &TransferAction,
from: [u8; ADDRESS_LEN],
mut state: S,
) -> anyhow::Result<()> {
let fee = state
.get_transfer_base_fee()
.await
.context("failed to get transfer base fee")?;
state
.get_and_increase_block_fees(&action.fee_asset, fee, TransferAction::full_name())
.await
.context("failed to add to block fees")?;

// if fee payment asset is same asset as transfer asset, deduct fee
// from same balance as asset transferred
if action.asset.to_ibc_prefixed() == action.fee_asset.to_ibc_prefixed() {
// check_stateful should have already checked this arithmetic
let payment_amount = action
.amount
.checked_add(fee)
.expect("transfer amount plus fee should not overflow");

state
.decrease_balance(from, &action.asset, payment_amount)
.await
.context("failed decreasing `from` account balance")?;
state
.increase_balance(action.to, &action.asset, action.amount)
.await
.context("failed increasing `to` account balance")?;
} else {
// otherwise, just transfer the transfer asset and deduct fee from fee asset balance
// later
state
.decrease_balance(from, &action.asset, action.amount)
.await
.context("failed decreasing `from` account balance")?;
state
.increase_balance(action.to, &action.asset, action.amount)
.await
.context("failed increasing `to` account balance")?;

// deduct fee from fee asset balance
state
.decrease_balance(from, &action.fee_asset, fee)
.await
.context("failed decreasing `from` account balance for fee payment")?;
}
Ok(())
}

pub(crate) async fn check_transfer<S, TAddress>(
action: &TransferAction,
from: TAddress,
state: &S,
from: Address,
) -> Result<()>
where
S: accounts::StateReadExt + assets::StateReadExt + 'static,
S: StateRead,
TAddress: AddressBytes,
{
state.ensure_base_prefix(&action.to).await.context(
"failed ensuring that the destination address matches the permitted base prefix",
)?;
ensure!(
state
.is_allowed_fee_asset(&action.fee_asset)
Expand All @@ -44,7 +138,7 @@ where
let transfer_asset = action.asset.clone();

let from_fee_balance = state
.get_account_balance(from, &action.fee_asset)
.get_account_balance(&from, &action.fee_asset)
.await
.context("failed getting `from` account balance for fee payment")?;

Expand Down Expand Up @@ -80,84 +174,3 @@ where

Ok(())
}

#[async_trait::async_trait]
impl ActionHandler for TransferAction {
async fn check_stateless(&self) -> Result<()> {
Ok(())
}

async fn check_stateful<S>(&self, state: &S, from: Address) -> Result<()>
where
S: accounts::StateReadExt + address::StateReadExt + 'static,
{
state.ensure_base_prefix(&self.to).await.context(
"failed ensuring that the destination address matches the permitted base prefix",
)?;
ensure!(
state
.get_bridge_account_rollup_id(&from)
.await
.context("failed to get bridge account rollup id")?
.is_none(),
"cannot transfer out of bridge account; BridgeUnlock must be used",
);

transfer_check_stateful(self, state, from)
.await
.context("stateful transfer check failed")
}

#[instrument(skip_all)]
async fn execute<S>(&self, state: &mut S, from: Address) -> Result<()>
where
S: accounts::StateWriteExt + assets::StateWriteExt,
{
let fee = state
.get_transfer_base_fee()
.await
.context("failed to get transfer base fee")?;
state
.get_and_increase_block_fees(&self.fee_asset, fee, Self::full_name())
.await
.context("failed to add to block fees")?;

// if fee payment asset is same asset as transfer asset, deduct fee
// from same balance as asset transferred
if self.asset.to_ibc_prefixed() == self.fee_asset.to_ibc_prefixed() {
// check_stateful should have already checked this arithmetic
let payment_amount = self
.amount
.checked_add(fee)
.expect("transfer amount plus fee should not overflow");

state
.decrease_balance(from, &self.asset, payment_amount)
.await
.context("failed decreasing `from` account balance")?;
state
.increase_balance(self.to, &self.asset, self.amount)
.await
.context("failed increasing `to` account balance")?;
} else {
// otherwise, just transfer the transfer asset and deduct fee from fee asset balance
// later
state
.decrease_balance(from, &self.asset, self.amount)
.await
.context("failed decreasing `from` account balance")?;
state
.increase_balance(self.to, &self.asset, self.amount)
.await
.context("failed increasing `to` account balance")?;

// deduct fee from fee asset balance
state
.decrease_balance(from, &self.fee_asset, fee)
.await
.context("failed decreasing `from` account balance for fee payment")?;
}

Ok(())
}
}
54 changes: 54 additions & 0 deletions crates/astria-sequencer/src/accounts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,61 @@ pub(crate) mod component;
pub(crate) mod query;
mod state_ext;

use astria_core::{
crypto::{
SigningKey,
VerificationKey,
},
primitive::v1::{
Address,
ADDRESS_LEN,
},
protocol::transaction::v1alpha1::SignedTransaction,
};
pub(crate) use state_ext::{
StateReadExt,
StateWriteExt,
};

pub(crate) trait AddressBytes: Send + Sync {
fn address_bytes(&self) -> [u8; ADDRESS_LEN];
}

impl AddressBytes for Address {
fn address_bytes(&self) -> [u8; ADDRESS_LEN] {
self.bytes()
}
}

impl AddressBytes for [u8; ADDRESS_LEN] {
fn address_bytes(&self) -> [u8; ADDRESS_LEN] {
*self
}
}

impl AddressBytes for SignedTransaction {
fn address_bytes(&self) -> [u8; ADDRESS_LEN] {
self.address_bytes()
}
}

impl AddressBytes for SigningKey {
fn address_bytes(&self) -> [u8; ADDRESS_LEN] {
self.address_bytes()
}
}

impl AddressBytes for VerificationKey {
fn address_bytes(&self) -> [u8; ADDRESS_LEN] {
self.address_bytes()
}
}

impl<'a, T> AddressBytes for &'a T
where
T: AddressBytes,
{
fn address_bytes(&self) -> [u8; ADDRESS_LEN] {
(*self).address_bytes()
}
}
Loading
Loading