Skip to content

Commit

Permalink
Fetch fee policies with a single query
Browse files Browse the repository at this point in the history
  • Loading branch information
squadgazzz committed Jul 16, 2024
1 parent 74934a1 commit e5a33c8
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 55 deletions.
99 changes: 70 additions & 29 deletions crates/database/src/fee_policies.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use {
crate::{auction::AuctionId, OrderUid},
sqlx::{PgConnection, QueryBuilder},
std::collections::HashMap,
};

#[derive(Debug, Clone, PartialEq, sqlx::FromRow)]
Expand Down Expand Up @@ -54,22 +55,38 @@ pub async fn insert_batch(

pub async fn fetch(
ex: &mut PgConnection,
auction_id: AuctionId,
order_uid: OrderUid,
) -> Result<Vec<FeePolicy>, sqlx::Error> {
const QUERY: &str = r#"
SELECT * FROM fee_policies
WHERE auction_id = $1 AND order_uid = $2
ORDER BY application_order
"#;
let rows = sqlx::query_as::<_, FeePolicy>(QUERY)
.bind(auction_id)
.bind(order_uid)
.fetch_all(ex)
.await?
.into_iter()
.collect();
Ok(rows)
keys_filter: &[(AuctionId, OrderUid)],
) -> Result<HashMap<(AuctionId, OrderUid), Vec<FeePolicy>>, sqlx::Error> {
if keys_filter.is_empty() {
return Ok(HashMap::new());
}

let mut query_builder = QueryBuilder::new("SELECT * FROM fee_policies WHERE ");
for (i, (auction_id, order_uid)) in keys_filter.iter().enumerate() {
if i > 0 {
query_builder.push(" OR ");
}
query_builder
.push("(")
.push("auction_id = ")
.push_bind(auction_id)
.push(" AND ")
.push("order_uid = ")
.push_bind(order_uid)
.push(")");
}

query_builder.push(" ORDER BY auction_id, order_uid, application_order");

let query = query_builder.build_query_as::<FeePolicy>();
let rows = query.fetch_all(ex).await?;
let mut result: HashMap<(AuctionId, OrderUid), Vec<FeePolicy>> = HashMap::new();
for row in rows {
let key = (row.auction_id, row.order_uid);
result.entry(key).or_default().push(row);
}

Ok(result)
}

#[cfg(test)]
Expand All @@ -84,12 +101,13 @@ mod tests {
crate::clear_DANGER_(&mut db).await.unwrap();

// same primary key for all fee policies
let (auction_id, order_uid) = (1, ByteArray([1; 56]));
let (auction_id_a, order_uid_a) = (1, ByteArray([1; 56]));
let (auction_id_b, order_uid_b) = (2, ByteArray([2; 56]));

// surplus fee policy without caps
let fee_policy_1 = FeePolicy {
auction_id,
order_uid,
auction_id: auction_id_a,
order_uid: order_uid_a,
kind: FeePolicyKind::Surplus,
surplus_factor: Some(0.1),
surplus_max_volume_factor: Some(0.99999),
Expand All @@ -99,8 +117,8 @@ mod tests {
};
// surplus fee policy with caps
let fee_policy_2 = FeePolicy {
auction_id,
order_uid,
auction_id: auction_id_b,
order_uid: order_uid_b,
kind: FeePolicyKind::Surplus,
surplus_factor: Some(0.2),
surplus_max_volume_factor: Some(0.05),
Expand All @@ -110,8 +128,8 @@ mod tests {
};
// volume based fee policy
let fee_policy_3 = FeePolicy {
auction_id,
order_uid,
auction_id: auction_id_b,
order_uid: order_uid_b,
kind: FeePolicyKind::Volume,
surplus_factor: None,
surplus_max_volume_factor: None,
Expand All @@ -121,8 +139,8 @@ mod tests {
};
// price improvement fee policy
let fee_policy_4 = FeePolicy {
auction_id,
order_uid,
auction_id: auction_id_a,
order_uid: order_uid_a,
kind: FeePolicyKind::PriceImprovement,
surplus_factor: None,
surplus_max_volume_factor: None,
Expand All @@ -131,11 +149,34 @@ mod tests {
price_improvement_max_volume_factor: Some(0.99999),
};

let fee_policies = vec![fee_policy_1, fee_policy_2, fee_policy_3, fee_policy_4];

let fee_policies = vec![
fee_policy_1.clone(),
fee_policy_2.clone(),
fee_policy_3.clone(),
fee_policy_4.clone(),
];
insert_batch(&mut db, fee_policies.clone()).await.unwrap();

let output = fetch(&mut db, 1, order_uid).await.unwrap();
assert_eq!(output, fee_policies);
let mut expected = HashMap::new();
expected.insert(
(auction_id_a, order_uid_a),
vec![fee_policy_1, fee_policy_4],
);
let output = fetch(&mut db, &[(auction_id_a, order_uid_a)])
.await
.unwrap();
assert_eq!(output, expected);

expected.insert(
(auction_id_b, order_uid_b),
vec![fee_policy_2, fee_policy_3],
);
let output = fetch(
&mut db,
&[(auction_id_a, order_uid_a), (auction_id_b, order_uid_b)],
)
.await
.unwrap();
assert_eq!(output, expected);
}
}
20 changes: 13 additions & 7 deletions crates/orderbook/src/database/fee_policies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,33 @@ use {
database::{auction::AuctionId, OrderUid},
model::fee_policy::{FeePolicy, Quote},
number::conversions::big_decimal_to_u256,
std::collections::HashMap,
};

impl super::Postgres {
pub async fn fee_policies(
&self,
auction_id: AuctionId,
order_uid: OrderUid,
quote: Option<&database::orders::Quote>,
) -> anyhow::Result<Vec<FeePolicy>> {
keys_filter: &[(AuctionId, OrderUid)],
quotes: HashMap<OrderUid, database::orders::Quote>,
) -> anyhow::Result<HashMap<(AuctionId, OrderUid), Vec<FeePolicy>>> {
let mut ex = self.pool.acquire().await?;

let _timer = super::Metrics::get()
.database_queries
.with_label_values(&["fee_policies"])
.start_timer();

let fee_policies = database::fee_policies::fetch(&mut ex, auction_id, order_uid).await?;
let fee_policies = database::fee_policies::fetch(&mut ex, keys_filter).await?;
fee_policies
.into_iter()
.map(|db_fee_policy| fee_policy_from(db_fee_policy, quote, order_uid))
.collect::<Result<Vec<_>, _>>()
.map(|((auction_id, order_uid), policies)| {
policies
.into_iter()
.map(|policy| fee_policy_from(policy, quotes.get(&order_uid), order_uid))
.collect::<anyhow::Result<Vec<_>>>()
.map(|policies| ((auction_id, order_uid), policies))
})
.collect::<anyhow::Result<HashMap<_, _>>>()
}
}

Expand Down
44 changes: 25 additions & 19 deletions crates/orderbook/src/database/trades.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use {
anyhow::{Context, Result},
database::{byte_array::ByteArray, trades::TradesQueryRow},
ethcontract::H160,
futures::{future::try_join_all, stream::TryStreamExt},
futures::stream::TryStreamExt,
model::{fee_policy::FeePolicy, order::OrderUid, trade::Trade},
number::conversions::big_decimal_to_big_uint,
primitive_types::H256,
Expand Down Expand Up @@ -53,24 +53,30 @@ impl TradeRetrieving for Postgres {
.collect::<HashMap<_, _>>();
timer.stop_and_record();

try_join_all(
trades
.into_iter()
.map(|trade| {
let quote = quotes.get(&trade.order_uid);
async move {
match trade.auction_id {
Some(auction_id) => {
self.fee_policies(auction_id, trade.order_uid, quote).await
}
None => Ok(vec![]),
}
.and_then(|fee_policies| trade_from(trade, fee_policies))
}
})
.collect::<Vec<_>>(),
)
.await
let auction_order_uids = trades
.iter()
.filter_map(|t| t.auction_id.map(|auction_id| (auction_id, t.order_uid)))
.collect::<Vec<_>>();
let fee_policies = self
.fee_policies(auction_order_uids.as_slice(), quotes)
.await?;

trades
.into_iter()
.map(|trade| {
let fee_policies = trade
.auction_id
.into_iter()
.flat_map(|auction_id| {
fee_policies
.get(&(auction_id, trade.order_uid))
.cloned()
.unwrap_or_default()
})
.collect();
trade_from(trade, fee_policies)
})
.collect::<Result<Vec<_>>>()
}
}

Expand Down

0 comments on commit e5a33c8

Please sign in to comment.