Skip to content

Commit

Permalink
feat: use DeserializableTxPacket to standardize crate
Browse files Browse the repository at this point in the history
  • Loading branch information
flame4 committed Oct 14, 2024
1 parent c121119 commit 32ce898
Show file tree
Hide file tree
Showing 7 changed files with 396 additions and 1,509 deletions.
12 changes: 6 additions & 6 deletions core/src/banking_stage/immutable_deserialized_packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ pub enum DeserializedPacketError {

#[derive(Debug, Eq)]
pub struct ImmutableDeserializedPacket {
original_packet: Packet,
transaction: SanitizedVersionedTransaction,
message_hash: Hash,
is_simple_vote: bool,
compute_unit_price: u64,
compute_unit_limit: u32,
pub original_packet: Packet,
pub transaction: SanitizedVersionedTransaction,
pub message_hash: Hash,
pub is_simple_vote: bool,
pub compute_unit_price: u64,
pub compute_unit_limit: u32,
}

impl ImmutableDeserializedPacket {
Expand Down
46 changes: 45 additions & 1 deletion prio-graph-scheduler/src/deserializable_packet.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use ahash::HashSet;
use std::collections::HashSet;
use solana_core::banking_stage::immutable_deserialized_packet::{DeserializedPacketError, ImmutableDeserializedPacket};
use solana_sdk::hash::Hash;
use solana_sdk::message::AddressLoader;
use solana_sdk::packet::Packet;
Expand Down Expand Up @@ -37,3 +38,46 @@ pub trait DeserializableTxPacket: PartialEq + PartialOrd + Eq + Sized {

fn compute_unit_limit(&self) -> u64;
}


/// TODO: migrate to solana_core
impl DeserializableTxPacket for ImmutableDeserializedPacket {
type DeserializeError = DeserializedPacketError;

fn from_packet(packet: Packet) -> Result<Self, Self::DeserializeError> {
ImmutableDeserializedPacket::new(packet)
}

fn build_sanitized_transaction(
&self,
votes_only: bool,
address_loader: impl AddressLoader,
reserved_account_keys: &HashSet<Pubkey>,
) -> Option<SanitizedTransaction> {
self.build_sanitized_transaction(votes_only, address_loader, reserved_account_keys)
}

fn original_packet(&self) -> &Packet {
&self.original_packet
}

fn transaction(&self) -> &SanitizedVersionedTransaction {
&self.transaction
}

fn message_hash(&self) -> &Hash {
&self.message_hash
}

fn is_simple_vote(&self) -> bool {
self.is_simple_vote
}

fn compute_unit_price(&self) -> u64 {
self.compute_unit_price
}

fn compute_unit_limit(&self) -> u64 {
u64::from(self.compute_unit_limit)
}
}
41 changes: 22 additions & 19 deletions prio-graph-scheduler/src/prio_graph_scheduler.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
use {
crate::scheduler_messages::{
ConsumeWork, FinishedConsumeWork, TransactionBatchId, TransactionId,
},
crate::transaction_priority_id::TransactionPriorityId,
crate::transaction_state::TransactionState,
crate::{
deserializable_packet::DeserializableTxPacket,
in_flight_tracker::InFlightTracker,
scheduler_error::SchedulerError,
scheduler_messages::{ConsumeWork, FinishedConsumeWork, TransactionBatchId, TransactionId},
thread_aware_account_locks::{ThreadAwareAccountLocks, ThreadId, ThreadSet},
transaction_state::SanitizedTransactionTTL,
transaction_priority_id::TransactionPriorityId,
transaction_state::{SanitizedTransactionTTL, TransactionState},
transaction_state_container::TransactionStateContainer,
},
crossbeam_channel::{Receiver, Sender, TryRecvError},
Expand All @@ -25,15 +23,16 @@ use {
},
};

pub struct PrioGraphScheduler {
pub struct PrioGraphScheduler<P: DeserializableTxPacket> {
in_flight_tracker: InFlightTracker,
account_locks: ThreadAwareAccountLocks,
consume_work_senders: Vec<Sender<ConsumeWork>>,
finished_consume_work_receiver: Receiver<FinishedConsumeWork>,
look_ahead_window_size: usize,
phantom: std::marker::PhantomData<P>,
}

impl PrioGraphScheduler {
impl<P: DeserializableTxPacket> PrioGraphScheduler<P> {
pub fn new(
consume_work_senders: Vec<Sender<ConsumeWork>>,
finished_consume_work_receiver: Receiver<FinishedConsumeWork>,
Expand All @@ -45,6 +44,7 @@ impl PrioGraphScheduler {
consume_work_senders,
finished_consume_work_receiver,
look_ahead_window_size: 2048,
phantom: std::marker::PhantomData,
}
}

Expand All @@ -66,7 +66,7 @@ impl PrioGraphScheduler {
/// not cause conflicts in the near future.
pub fn schedule(
&mut self,
container: &mut TransactionStateContainer,
container: &mut TransactionStateContainer<P>,
pre_graph_filter: impl Fn(&[&SanitizedTransaction], &mut [bool]),
pre_lock_filter: impl Fn(&SanitizedTransaction) -> bool,
) -> Result<SchedulingSummary, SchedulerError> {
Expand Down Expand Up @@ -102,7 +102,7 @@ impl PrioGraphScheduler {
let mut total_filter_time_us: u64 = 0;

let mut window_budget = self.look_ahead_window_size;
let mut chunked_pops = |container: &mut TransactionStateContainer,
let mut chunked_pops = |container: &mut TransactionStateContainer<P>,
prio_graph: &mut PrioGraph<_, _, _, _>,
window_budget: &mut usize| {
while *window_budget > 0 {
Expand Down Expand Up @@ -281,7 +281,7 @@ impl PrioGraphScheduler {
/// Returns (num_transactions, num_retryable_transactions) on success.
pub fn receive_completed(
&mut self,
container: &mut TransactionStateContainer,
container: &mut TransactionStateContainer<P>,
) -> Result<(usize, usize), SchedulerError> {
let mut total_num_transactions: usize = 0;
let mut total_num_retryable: usize = 0;
Expand All @@ -300,7 +300,7 @@ impl PrioGraphScheduler {
/// Returns `Ok((num_transactions, num_retryable))` if a batch was received, `Ok((0, 0))` if no batch was received.
fn try_receive_completed(
&mut self,
container: &mut TransactionStateContainer,
container: &mut TransactionStateContainer<P>,
) -> Result<(usize, usize), SchedulerError> {
match self.finished_consume_work_receiver.try_recv() {
Ok(FinishedConsumeWork {
Expand Down Expand Up @@ -535,8 +535,8 @@ enum TransactionSchedulingError {
UnschedulableConflicts,
}

fn try_schedule_transaction(
transaction_state: &mut TransactionState,
fn try_schedule_transaction<P: DeserializableTxPacket>(
transaction_state: &mut TransactionState<P>,
pre_lock_filter: impl Fn(&SanitizedTransaction) -> bool,
blocking_locks: &mut ReadWriteAccountSet,
account_locks: &mut ThreadAwareAccountLocks,
Expand Down Expand Up @@ -621,15 +621,17 @@ mod tests {
fn create_test_frame(
num_threads: usize,
) -> (
PrioGraphScheduler,
PrioGraphScheduler<ImmutableDeserializedPacket>,
Vec<Receiver<ConsumeWork>>,
Sender<FinishedConsumeWork>,
) {
let (consume_work_senders, consume_work_receivers) =
(0..num_threads).map(|_| unbounded()).unzip();
let (finished_consume_work_sender, finished_consume_work_receiver) = unbounded();
let scheduler =
PrioGraphScheduler::new(consume_work_senders, finished_consume_work_receiver);
let scheduler = PrioGraphScheduler::<ImmutableDeserializedPacket>::new(
consume_work_senders,
finished_consume_work_receiver,
);
(
scheduler,
consume_work_receivers,
Expand Down Expand Up @@ -666,8 +668,9 @@ mod tests {
u64,
),
>,
) -> TransactionStateContainer {
let mut container = TransactionStateContainer::with_capacity(10 * 1024);
) -> TransactionStateContainer<ImmutableDeserializedPacket> {
let mut container =
TransactionStateContainer::<ImmutableDeserializedPacket>::with_capacity(10 * 1024);
for (index, (from_keypair, to_pubkeys, lamports, compute_unit_price)) in
tx_infos.into_iter().enumerate()
{
Expand Down
Loading

0 comments on commit 32ce898

Please sign in to comment.