Skip to content

Commit

Permalink
dex: refactor internal position update logic (#4188)
Browse files Browse the repository at this point in the history
## Describe your changes

This PR:
- break out indexing methods from the `Inner` position manager trait
into submodules with crate visibility
- refactor `update_position` to delegate indexing checks to
`update_*_index` methods
- add a redundant guard against invalid transitions in
`PositionManager::update_position`
- streamline the base liquidity index, fixing a bug with double counting
closed positions

This PR contain changes to critical parts of the DEX engine internals.

## Checklist before requesting a review

- [x] If this code contains consensus-breaking changes, I have added the
"consensus-breaking" label. Otherwise, I declare my belief that there
are not consensus-breaking changes, for the following reason:

> This is technically consensus breaking because we fix a bug in the
base liquidity index logic, which could trickle down into different DEX
execution in some rare cases.
  • Loading branch information
erwanor authored Apr 16, 2024
1 parent 9285efc commit c8c980d
Show file tree
Hide file tree
Showing 9 changed files with 486 additions and 332 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl<T: StateWrite + ?Sized> ValueCircuitBreaker for T {}
mod tests {
use std::sync::Arc;

use crate::component::position_manager::Inner as _;
use crate::component::position_manager::price_index::PositionByPriceIndex;
use crate::component::router::HandleBatchSwaps as _;
use crate::component::{StateReadExt as _, StateWriteExt as _};
use crate::lp::plan::PositionWithdrawPlan;
Expand Down Expand Up @@ -225,11 +225,9 @@ mod tests {
let id = buy_1.id();

let position = buy_1;
state_tx.index_position_by_price(&position, &position.id());
state_tx
.update_available_liquidity(&None, &position)
.await
.expect("able to update liquidity");
.update_position_by_price_index(&None, &position, &position.id())
.expect("can update price index");
state_tx.put(state_key::position_by_id(&id), position);

// Now there's a position in the state, but the circuit breaker is not aware of it.
Expand Down
5 changes: 2 additions & 3 deletions crates/core/component/dex/src/component/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@ mod arb;
pub(crate) mod circuit_breaker;
mod dex;
mod flow;
pub(crate) mod position_counter;
pub(crate) mod position_manager;
mod position_manager;
mod swap_manager;

pub use self::metrics::register_metrics;
pub use arb::Arbitrage;
pub(crate) use arb::Arbitrage;
pub use circuit_breaker::ExecutionCircuitBreaker;
pub(crate) use circuit_breaker::ValueCircuitBreaker;
pub use dex::{Dex, StateReadExt, StateWriteExt};
Expand Down
339 changes: 56 additions & 283 deletions crates/core/component/dex/src/component/position_manager.rs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
use anyhow::Result;
use cnidarium::StateWrite;
use penumbra_num::Amount;
use position::State::*;

use crate::lp::position::{self, Position};
use crate::state_key::engine;
use crate::DirectedTradingPair;
use penumbra_proto::{StateReadProto, StateWriteProto};

pub(crate) trait AssetByLiquidityIndex: StateWrite {
/// Update the base liquidity index, used by the DEX engine during path search.
///
/// # Overview
/// Given a directed trading pair `A -> B`, the index tracks the amount of
/// liquidity available to convert the quote asset B, into a base asset A.
///
/// # Index schema
/// The liquidity index schema is as follow:
/// - A primary index that maps a "start" asset A (aka. base asset)
/// to an "end" asset B (aka. quote asset) ordered by the amount of
/// liquidity available for B -> A (not a typo).
/// - An auxilliary index that maps a directed trading pair `A -> B`
/// to the aggregate liquidity for B -> A (used in the primary composite key)
///
/// # Diagram
///
/// Liquidity index:
/// For an asset `A`, surface asset
/// `B` with the best liquidity
/// score.
/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐
///
/// ┌──┐ ▼ ┌─────────┐ │
/// ▲ │ │ ┌──────────────────┐ │ │
/// │ │ ─┼───▶│{asset_A}{agg_liq}│──▶│{asset_B}│ │
/// │ ├──┤ └──────────────────┘ │ │
/// sorted │ │ └─────────┘ │
/// by agg │ │
/// liq ├──┤ │
/// │ │ │ used in the
/// │ ├──┤ composite
/// │ │ │ key
/// │ │ │ Auxiliary look-up index: │
/// │ │ │ "Find the aggregate liquidity
/// │ │ │ per directed trading pair" │
/// │ │ │ ┌───────┐ ┌─────────┐
/// │ │ │ ├───────┤ ┌──────────────────┐ │ │
/// │ │ │ │ ────┼─▶│{asset_A}{asset_B}│────▶│{agg_liq}│
/// │ ├──┤ ├───────┤ └──────────────────┘ │ │
/// │ │ │ ├───────┤ └─────────┘
/// │ │ │ ├───────┤
/// │ │ │ ├───────┤
/// │ ├──┤ └───────┘
/// │ │ │
/// │ │ │
/// │ └──┘
async fn update_asset_by_base_liquidity_index(
&mut self,
prev_state: &Option<Position>,
new_state: &Position,
id: &position::Id,
) -> Result<()> {
// We need to reconstruct the position's previous contribution and compute
// its new contribution to the index. We do this for each asset in the pair
// and short-circuit if all contributions are zero.
let canonical_pair = new_state.phi.pair;
let pair_ab = DirectedTradingPair::new(canonical_pair.asset_1(), canonical_pair.asset_2());

// We reconstruct the position's *previous* contribution so that we can deduct them later:
let (prev_a, prev_b) = match prev_state {
// The position was just created, so its previous contributions are zero.
None => (Amount::zero(), Amount::zero()),
Some(prev) => match prev.state {
// The position was previously closed or withdrawn, so its previous contributions are zero.
Closed | Withdrawn { sequence: _ } => (Amount::zero(), Amount::zero()),
// The position's previous contributions are the reserves for the start and end assets.
_ => (
prev.reserves_for(pair_ab.start)
.expect("asset ids match for start"),
prev.reserves_for(pair_ab.end)
.expect("asset ids match for end"),
),
},
};

// For each asset, we compute the new position's contribution to the index:
let (new_a, new_b) = if matches!(new_state.state, Closed | Withdrawn { sequence: _ }) {
// The position is being closed or withdrawn, so its new contributions are zero.
// Note a withdrawn position MUST have zero reserves, so hardcoding this is extra.
(Amount::zero(), Amount::zero())
} else {
(
// The new amount of asset A:
new_state
.reserves_for(pair_ab.start)
.expect("asset ids match for start"),
// The new amount of asset B:
new_state
.reserves_for(pair_ab.end)
.expect("asset ids match for end"),
)
};

// If all contributions are zero, we can skip the update.
// This can happen if we're processing inactive transitions like `Closed -> Withdrawn`.
if prev_a == Amount::zero()
&& new_a == Amount::zero()
&& prev_b == Amount::zero()
&& new_b == Amount::zero()
{
return Ok(());
}

// A -> B
self.update_asset_by_base_liquidity_index_inner(id, pair_ab, prev_a, new_a)
.await?;
// B -> A
self.update_asset_by_base_liquidity_index_inner(id, pair_ab.flip(), prev_b, new_b)
.await?;

Ok(())
}
}

impl<T: StateWrite + ?Sized> AssetByLiquidityIndex for T {}

trait Inner: StateWrite {
async fn update_asset_by_base_liquidity_index_inner(
&mut self,
id: &position::Id,
pair: DirectedTradingPair,
old_contrib: Amount,
new_contrib: Amount,
) -> Result<()> {
let aggregate_key = &engine::routable_assets::lookup_base_liquidity_by_pair(&pair);

let prev_tally: Amount = self
.nonverifiable_get(aggregate_key)
.await?
.unwrap_or_default();

// To compute the new aggregate liquidity, we deduct the old contribution
// and add the new contribution. We use saturating arithmetic defensively.
let new_tally = prev_tally
.saturating_sub(&old_contrib)
.saturating_add(&new_contrib);

// If the update operation is a no-op, we can skip the update and return early.
if prev_tally == new_tally {
tracing::debug!(
?prev_tally,
?pair,
?id,
"skipping routable asset index update"
);
return Ok(());
}

// Update the primary and auxiliary indices:
let old_primary_key = engine::routable_assets::key(&pair.start, prev_tally).to_vec();
// This could make the `StateDelta` more expensive to scan, but this doesn't show on profiles yet.
self.nonverifiable_delete(old_primary_key);

let new_primary_key = engine::routable_assets::key(&pair.start, new_tally).to_vec();
self.nonverifiable_put(new_primary_key, pair.end);
tracing::debug!(?pair, ?new_tally, "base liquidity entry updated");

let auxiliary_key = engine::routable_assets::lookup_base_liquidity_by_pair(&pair).to_vec();
self.nonverifiable_put(auxiliary_key, new_tally);
tracing::debug!(
?pair,
"base liquidity heuristic marked directed pair as routable"
);

Ok(())
}
}

impl<T: StateWrite + ?Sized> Inner for T {}
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
use anyhow::bail;
use async_trait::async_trait;
use cnidarium::StateWrite;
use cnidarium::{StateRead, StateWrite};

use crate::lp::position::{self, Position};
use crate::state_key::engine;
use crate::TradingPair;
use anyhow::Result;

#[async_trait]
pub(crate) trait PositionCounter: StateWrite {
pub(super) trait PositionCounterRead: StateRead {
/// Returns the number of position for a [`TradingPair`].
/// If there were no counter initialized for a given pair, this default to zero.
async fn get_position_count(&self, trading_pair: &TradingPair) -> u16 {
async fn get_position_count(&self, trading_pair: &TradingPair) -> u32 {
let path = engine::counter::num_positions::by_trading_pair(trading_pair);
self.get_position_count_from_key(path).await
}

async fn get_position_count_from_key(&self, path: [u8; 99]) -> u16 {
async fn get_position_count_from_key(&self, path: [u8; 99]) -> u32 {
let Some(raw_count) = self
.nonverifiable_get_raw(&path)
.await
Expand All @@ -24,16 +25,47 @@ pub(crate) trait PositionCounter: StateWrite {
return 0;
};

// This is safe because we only increment the counter via a [`Self::increase_position_counter`].
let raw_count: [u8; 2] = raw_count
// This is safe because we only increment the counter via [`Self::increase_position_counter`].
let raw_count: [u8; 4] = raw_count
.try_into()
.expect("position counter is at most two bytes");
u16::from_be_bytes(raw_count)
u32::from_be_bytes(raw_count)
}
}

impl<T: StateRead + ?Sized> PositionCounterRead for T {}

#[async_trait]
pub(crate) trait PositionCounter: StateWrite {
async fn update_trading_pair_position_counter(
&mut self,
prev_state: &Option<Position>,
new_state: &Position,
_id: &position::Id,
) -> Result<()> {
use position::State::*;
let trading_pair = new_state.phi.pair;
match (prev_state.as_ref().map(|p| p.state), new_state.state) {
// Increment the counter whenever a new position is opened
(None, Opened) => {
let _ = self.increment_position_counter(&trading_pair).await?;
}
// Decrement the counter whenever an opened position is closed
(Some(Opened), Closed) => {
let _ = self.decrement_position_counter(&trading_pair).await?;
}
// Other state transitions don't affect the opened position counter
_ => {}
}
Ok(())
}
}
impl<T: StateWrite + ?Sized> PositionCounter for T {}

trait Inner: StateWrite {
/// Increment the number of position for a [`TradingPair`].
/// Returns the updated total, or an error if overflow occurred.
async fn increment_position_counter(&mut self, trading_pair: &TradingPair) -> Result<u16> {
async fn increment_position_counter(&mut self, trading_pair: &TradingPair) -> Result<u32> {
let path = engine::counter::num_positions::by_trading_pair(trading_pair);
let prev = self.get_position_count_from_key(path).await;

Expand All @@ -46,7 +78,7 @@ pub(crate) trait PositionCounter: StateWrite {

/// Decrement the number of positions for a [`TradingPair`], unless it would underflow.
/// Returns the updated total, or an error if underflow occurred.
async fn decrement_position_counter(&mut self, trading_pair: &TradingPair) -> Result<u16> {
async fn decrement_position_counter(&mut self, trading_pair: &TradingPair) -> Result<u32> {
let path = engine::counter::num_positions::by_trading_pair(trading_pair);
let prev = self.get_position_count_from_key(path).await;

Expand All @@ -57,16 +89,20 @@ pub(crate) trait PositionCounter: StateWrite {
Ok(new_total)
}
}
impl<T: StateWrite + ?Sized> PositionCounter for T {}

impl<T: StateWrite + ?Sized> Inner for T {}

// For some reason, `rust-analyzer` is complaining about used imports.
// Silence the warnings until I find a fix.
#[allow(unused_imports)]
mod tests {
use cnidarium::{StateDelta, TempStorage};
use cnidarium::{StateDelta, StateWrite, TempStorage};
use penumbra_asset::{asset::REGISTRY, Value};

use crate::component::position_counter::PositionCounter;
use crate::component::position_manager::counter::{
Inner, PositionCounter, PositionCounterRead,
};
use crate::state_key::engine;
use crate::TradingPair;

#[tokio::test]
Expand All @@ -78,22 +114,20 @@ mod tests {

let storage = TempStorage::new().await?;
let mut delta = StateDelta::new(storage.latest_snapshot());
let path = engine::counter::num_positions::by_trading_pair(&trading_pair);
// Manually set the counter to the maximum value
delta.nonverifiable_put_raw(path.to_vec(), u32::MAX.to_be_bytes().to_vec());

for i in 0..u16::MAX {
let total = delta.increment_position_counter(&trading_pair).await?;

anyhow::ensure!(
total == i + 1,
"the total amount should be total={}, found={total}",
i + 1
);
}
// Check that the counter is at the maximum value
let total = delta.get_position_count(&trading_pair).await;
assert_eq!(total, u32::MAX);

// Check that we can handle an overflow
assert!(delta
.increment_position_counter(&trading_pair)
.await
.is_err());
assert_eq!(delta.get_position_count(&trading_pair).await, u16::MAX);
assert_eq!(delta.get_position_count(&trading_pair).await, u32::MAX);

Ok(())
}
Expand All @@ -112,7 +146,7 @@ mod tests {
assert!(maybe_total.is_err());

let counter = delta.get_position_count(&trading_pair).await;
assert_eq!(counter, 0u16);
assert_eq!(counter, 0u32);
Ok(())
}
}
Loading

0 comments on commit c8c980d

Please sign in to comment.