From 01c34a630b0bd6f5b9638328243fcf50eb9dd7e4 Mon Sep 17 00:00:00 2001
From: lewis
Date: Mon, 28 Oct 2024 21:56:41 +0800
Subject: [PATCH 1/4] feat: refactor scheduler submodule

---
 .../mod.rs}                                   |  2 ++
 scheduler/src/impls/prio_graph_scheduler.rs   |  1 -
 .../src/impls/prio_graph_scheduler/mod.rs     | 22 +++++++++++++++++++
 scheduler/src/scheduler.rs                    |  2 ++
 4 files changed, 26 insertions(+), 1 deletion(-)
 rename scheduler/src/impls/{no_lock_scheduler.rs => no_lock_scheduler/mod.rs} (94%)
 delete mode 100644 scheduler/src/impls/prio_graph_scheduler.rs
 create mode 100644 scheduler/src/impls/prio_graph_scheduler/mod.rs

diff --git a/scheduler/src/impls/no_lock_scheduler.rs b/scheduler/src/impls/no_lock_scheduler/mod.rs
similarity index 94%
rename from scheduler/src/impls/no_lock_scheduler.rs
rename to scheduler/src/impls/no_lock_scheduler/mod.rs
index 808221f..c90f36a 100644
--- a/scheduler/src/impls/no_lock_scheduler.rs
+++ b/scheduler/src/impls/no_lock_scheduler/mod.rs
@@ -35,4 +35,6 @@ impl Scheduler for NoLockScheduler {
             self.task_senders[worker_id].send(batch).unwrap();
         });
     }
+
+    fn receive_complete(&mut self, _receipt: SchedulingBatchResult) {}
 }

diff --git a/scheduler/src/impls/prio_graph_scheduler.rs b/scheduler/src/impls/prio_graph_scheduler.rs
deleted file mode 100644
index 8b13789..0000000
--- a/scheduler/src/impls/prio_graph_scheduler.rs
+++ /dev/null
@@ -1 +0,0 @@
-

diff --git a/scheduler/src/impls/prio_graph_scheduler/mod.rs b/scheduler/src/impls/prio_graph_scheduler/mod.rs
new file mode 100644
index 0000000..cb23897
--- /dev/null
+++ b/scheduler/src/impls/prio_graph_scheduler/mod.rs
@@ -0,0 +1,22 @@
+use crate::scheduler::Scheduler;
+use crate::scheduler_messages::{SchedulingBatch, SchedulingBatchResult};
+use crossbeam_channel::{Receiver, Sender};
+
+pub struct PrioGraphSchedulerWrapper {}
+
+impl Scheduler for PrioGraphSchedulerWrapper {
+    fn new(
+        schedule_task_senders: Vec<Sender<SchedulingBatch>>,
+        task_finished_receivers: Receiver<SchedulingBatchResult>,
+    ) -> Self {
+        todo!()
+    }
+
+    fn schedule_batch(&mut self, txs: SchedulingBatch) {
+        todo!()
+    }
+
+    fn receive_complete(&mut self, receipt: SchedulingBatchResult) {
+        todo!()
+    }
+}

diff --git a/scheduler/src/scheduler.rs b/scheduler/src/scheduler.rs
index b375d1a..a15a7fe 100644
--- a/scheduler/src/scheduler.rs
+++ b/scheduler/src/scheduler.rs
@@ -18,4 +18,6 @@ pub trait Scheduler {
     ) -> Self;
 
     fn schedule_batch(&mut self, txs: SchedulingBatch);
+
+    fn receive_complete(&mut self, receipt: SchedulingBatchResult);
 }

From 0f04ac2984fd248e03ec19425ebf9bac61f58442 Mon Sep 17 00:00:00 2001
From: lewis
Date: Tue, 29 Oct 2024 14:27:05 +0800
Subject: [PATCH 2/4] feat: simplify prio-graph scheduler

---
 .../prio_graph_scheduler/in_flight_tracker.rs | 123 +++
 .../src/impls/prio_graph_scheduler/mod.rs     |  12 +
 .../prio_graph_scheduler.rs                   | 910 ++++++++++++++++++
 .../read_write_account_set.rs                 | 287 ++++++
 .../prio_graph_scheduler/scheduler_error.rs   |   9 +
 .../prio_graph_scheduler/scheduler_metrics.rs | 409 ++++++++
 .../thread_aware_account_locks.rs             | 742 ++++++++++++++
 .../transaction_priority_id.rs                |  69 ++
 .../prio_graph_scheduler/transaction_state.rs | 323 +++++++
 .../transaction_state_container.rs            | 263 +++++
 scheduler/src/scheduler_messages.rs           |   9 +
 11 files changed, 3156 insertions(+)
 create mode 100644 scheduler/src/impls/prio_graph_scheduler/in_flight_tracker.rs
 create mode 100644 scheduler/src/impls/prio_graph_scheduler/prio_graph_scheduler.rs
 create mode 100644 scheduler/src/impls/prio_graph_scheduler/read_write_account_set.rs
 create mode 100644 scheduler/src/impls/prio_graph_scheduler/scheduler_error.rs
 create mode 100644 scheduler/src/impls/prio_graph_scheduler/scheduler_metrics.rs
 create mode 100644 scheduler/src/impls/prio_graph_scheduler/thread_aware_account_locks.rs
 create mode 100644 scheduler/src/impls/prio_graph_scheduler/transaction_priority_id.rs
 create mode 100644 scheduler/src/impls/prio_graph_scheduler/transaction_state.rs
 create mode 100644 scheduler/src/impls/prio_graph_scheduler/transaction_state_container.rs

diff --git a/scheduler/src/impls/prio_graph_scheduler/in_flight_tracker.rs b/scheduler/src/impls/prio_graph_scheduler/in_flight_tracker.rs
new file mode 100644
index 0000000..13575b4
--- /dev/null
+++ b/scheduler/src/impls/prio_graph_scheduler/in_flight_tracker.rs
@@ -0,0 +1,123 @@
+use {
+    crate::id_generator::IdGenerator,
+    crate::impls::prio_graph_scheduler::thread_aware_account_locks::ThreadId,
+    crate::scheduler_messages::TransactionBatchId, std::collections::HashMap,
+};
+
+/// Tracks the number of transactions that are in flight for each thread.
+pub struct InFlightTracker {
+    num_in_flight_per_thread: Vec<usize>,
+    cus_in_flight_per_thread: Vec<u64>,
+    batches: HashMap<TransactionBatchId, BatchEntry>,
+    batch_id_generator: IdGenerator,
+}
+
+struct BatchEntry {
+    thread_id: ThreadId,
+    num_transactions: usize,
+    total_cus: u64,
+}
+
+impl InFlightTracker {
+    pub fn new(num_threads: usize) -> Self {
+        Self {
+            num_in_flight_per_thread: vec![0; num_threads],
+            cus_in_flight_per_thread: vec![0; num_threads],
+            batches: HashMap::new(),
+            batch_id_generator: IdGenerator::default(),
+        }
+    }
+
+    /// Returns the number of transactions that are in flight for each thread.
+    pub fn num_in_flight_per_thread(&self) -> &[usize] {
+        &self.num_in_flight_per_thread
+    }
+
+    /// Returns the number of CUs that are in flight for each thread.
+    pub fn cus_in_flight_per_thread(&self) -> &[u64] {
+        &self.cus_in_flight_per_thread
+    }
+
+    /// Tracks number of transactions and CUs in-flight for the `thread_id`.
+    /// Returns a `TransactionBatchId` that can be used to stop tracking the batch
+    /// when it is complete.
+    pub fn track_batch(
+        &mut self,
+        num_transactions: usize,
+        total_cus: u64,
+        thread_id: ThreadId,
+    ) -> TransactionBatchId {
+        let batch_id = self.batch_id_generator.next();
+        self.num_in_flight_per_thread[thread_id] += num_transactions;
+        self.cus_in_flight_per_thread[thread_id] += total_cus;
+        self.batches.insert(
+            batch_id,
+            BatchEntry {
+                thread_id,
+                num_transactions,
+                total_cus,
+            },
+        );
+
+        batch_id
+    }
+
+    /// Stop tracking the batch with given `batch_id`.
+    /// Removes the number of transactions for the scheduled thread.
+    /// Returns the thread id that the batch was scheduled on.
+    ///
+    /// # Panics
+    /// Panics if the batch id does not exist in the tracker.
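+    ///
+    /// Intended pairing with `track_batch` (a sketch distilled from the tests
+    /// below; `tracker` is a hypothetical binding):
+    ///
+    /// ```ignore
+    /// let batch_id = tracker.track_batch(2, 10_000, 0); // 2 txs, 10k CUs on thread 0
+    /// // ... the batch executes on thread 0 ...
+    /// assert_eq!(tracker.complete_batch(batch_id), 0);  // releases the in-flight counts
+    /// ```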
+    pub fn complete_batch(&mut self, batch_id: TransactionBatchId) -> ThreadId {
+        let Some(BatchEntry {
+            thread_id,
+            num_transactions,
+            total_cus,
+        }) = self.batches.remove(&batch_id)
+        else {
+            panic!("batch id {batch_id} is not being tracked");
+        };
+        self.num_in_flight_per_thread[thread_id] -= num_transactions;
+        self.cus_in_flight_per_thread[thread_id] -= total_cus;
+
+        thread_id
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    #[should_panic(expected = "is not being tracked")]
+    fn test_in_flight_tracker_untracked_batch() {
+        let mut in_flight_tracker = InFlightTracker::new(2);
+        in_flight_tracker.complete_batch(TransactionBatchId::new(5));
+    }
+
+    #[test]
+    fn test_in_flight_tracker() {
+        let mut in_flight_tracker = InFlightTracker::new(2);
+
+        // Add a batch with 2 transactions, 10 kCUs to thread 0.
+        let batch_id_0 = in_flight_tracker.track_batch(2, 10_000, 0);
+        assert_eq!(in_flight_tracker.num_in_flight_per_thread(), &[2, 0]);
+        assert_eq!(in_flight_tracker.cus_in_flight_per_thread(), &[10_000, 0]);
+
+        // Add a batch with 1 transaction, 15 kCUs to thread 1.
+        let batch_id_1 = in_flight_tracker.track_batch(1, 15_000, 1);
+        assert_eq!(in_flight_tracker.num_in_flight_per_thread(), &[2, 1]);
+        assert_eq!(
+            in_flight_tracker.cus_in_flight_per_thread(),
+            &[10_000, 15_000]
+        );
+
+        in_flight_tracker.complete_batch(batch_id_0);
+        assert_eq!(in_flight_tracker.num_in_flight_per_thread(), &[0, 1]);
+        assert_eq!(in_flight_tracker.cus_in_flight_per_thread(), &[0, 15_000]);
+
+        in_flight_tracker.complete_batch(batch_id_1);
+        assert_eq!(in_flight_tracker.num_in_flight_per_thread(), &[0, 0]);
+        assert_eq!(in_flight_tracker.cus_in_flight_per_thread(), &[0, 0]);
+    }
+}

diff --git a/scheduler/src/impls/prio_graph_scheduler/mod.rs b/scheduler/src/impls/prio_graph_scheduler/mod.rs
index cb23897..55e0179 100644
--- a/scheduler/src/impls/prio_graph_scheduler/mod.rs
+++ b/scheduler/src/impls/prio_graph_scheduler/mod.rs
@@ -1,7 +1,19 @@
+pub mod in_flight_tracker;
+pub mod prio_graph_scheduler;
+pub mod read_write_account_set;
+pub mod scheduler_error;
+pub mod scheduler_metrics;
+pub mod thread_aware_account_locks;
+pub mod transaction_priority_id;
+pub mod transaction_state_container;
+pub mod transaction_state;
+
 use crate::scheduler::Scheduler;
 use crate::scheduler_messages::{SchedulingBatch, SchedulingBatchResult};
 use crossbeam_channel::{Receiver, Sender};
 
+pub const TARGET_NUM_TRANSACTIONS_PER_BATCH: usize = 128;
+
 pub struct PrioGraphSchedulerWrapper {}
 
 impl Scheduler for PrioGraphSchedulerWrapper {

diff --git a/scheduler/src/impls/prio_graph_scheduler/prio_graph_scheduler.rs b/scheduler/src/impls/prio_graph_scheduler/prio_graph_scheduler.rs
new file mode 100644
index 0000000..c4991ed
--- /dev/null
+++ b/scheduler/src/impls/prio_graph_scheduler/prio_graph_scheduler.rs
@@ -0,0 +1,910 @@
+// use {
+//     crate::{
+//         impls::prio_graph_scheduler::in_flight_tracker::InFlightTracker,
+//         impls::prio_graph_scheduler::read_write_account_set::ReadWriteAccountSet,
+//         impls::prio_graph_scheduler::scheduler_error::SchedulerError,
+//         impls::prio_graph_scheduler::thread_aware_account_locks::{
+//             ThreadAwareAccountLocks, ThreadId, ThreadSet,
+//         },
+//         impls::prio_graph_scheduler::transaction_priority_id::TransactionPriorityId,
+//         impls::prio_graph_scheduler::TARGET_NUM_TRANSACTIONS_PER_BATCH,
+//         scheduler_messages::{
+//             SchedulingBatch, SchedulingBatchResult, TransactionBatchId, TransactionId,
+//         },
+//         transaction_state::{SanitizedTransactionTTL, TransactionState},
+//         transaction_state_container::TransactionStateContainer,
+//     },
+//     crossbeam_channel::{Receiver, Sender, TryRecvError},
+//     itertools::izip,
+//     prio_graph::{AccessKind, PrioGraph},
+//     solana_cost_model::block_cost_limits::MAX_BLOCK_UNITS,
+//     solana_measure::measure_us,
+//     solana_sdk::{pubkey::Pubkey, saturating_add_assign, transaction::SanitizedTransaction},
+// };
+//
+// pub struct PrioGraphScheduler {
+//     in_flight_tracker: InFlightTracker,
+//     account_locks: ThreadAwareAccountLocks,
+//     consume_work_senders: Vec<Sender<ConsumeWork>>,
+//     finished_consume_work_receiver: Receiver<FinishedConsumeWork>,
+//     look_ahead_window_size: usize,
+// }
+//
+// impl PrioGraphScheduler {
+//     pub fn new(
+//         consume_work_senders: Vec<Sender<ConsumeWork>>,
+//         finished_consume_work_receiver: Receiver<FinishedConsumeWork>,
+//     ) -> Self {
+//         let num_threads = consume_work_senders.len();
+//         Self {
+//             in_flight_tracker: InFlightTracker::new(num_threads),
+//             account_locks: ThreadAwareAccountLocks::new(num_threads),
+//             consume_work_senders,
+//             finished_consume_work_receiver,
+//             look_ahead_window_size: 2048,
+//         }
+//     }
+//
+//     /// Schedule transactions from the given `TransactionStateContainer` to be
+//     /// consumed by the worker threads. Returns summary of scheduling, or an
+//     /// error.
+//     /// `pre_graph_filter` is used to filter out transactions that should be
+//     /// skipped and dropped before insertion to the prio-graph. This fn should
+//     /// set `false` for transactions that should be dropped, and `true`
+//     /// otherwise.
+//     /// `pre_lock_filter` is used to filter out transactions after they have
+//     /// made it to the top of the prio-graph, and immediately before locks are
+//     /// checked and taken. This fn should return `true` for transactions that
+//     /// should be scheduled, and `false` otherwise.
+//     ///
+//     /// Uses a `PrioGraph` to perform look-ahead during the scheduling of transactions.
+//     /// This, combined with internal tracking of threads' in-flight transactions, allows
+//     /// for load-balancing while prioritizing scheduling transactions onto threads that will
+//     /// not cause conflicts in the near future.
+//     pub fn schedule(
+//         &mut self,
+//         container: &mut TransactionStateContainer,
+//         pre_graph_filter: impl Fn(&[&SanitizedTransaction], &mut [bool]),
+//         pre_lock_filter: impl Fn(&SanitizedTransaction) -> bool,
+//     ) -> Result<SchedulingSummary, SchedulerError> {
+//         let num_threads = self.consume_work_senders.len();
+//         let max_cu_per_thread = MAX_BLOCK_UNITS / num_threads as u64;
+//
+//         let mut schedulable_threads = ThreadSet::any(num_threads);
+//         for thread_id in 0..num_threads {
+//             if self.in_flight_tracker.cus_in_flight_per_thread()[thread_id] >= max_cu_per_thread {
+//                 schedulable_threads.remove(thread_id);
+//             }
+//         }
+//         if schedulable_threads.is_empty() {
+//             return Ok(SchedulingSummary {
+//                 num_scheduled: 0,
+//                 num_unschedulable: 0,
+//                 num_filtered_out: 0,
+//                 filter_time_us: 0,
+//             });
+//         }
+//
+//         let mut batches = Batches::new(num_threads);
+//         // Some transactions may be unschedulable due to multi-thread conflicts.
+//         // These transactions cannot be scheduled until some conflicting work is completed.
+//         // However, the scheduler should not allow other transactions that conflict with
+//         // these transactions to be scheduled before them.
+//         let mut unschedulable_ids = Vec::new();
+//         let mut blocking_locks = ReadWriteAccountSet::default();
+//         let mut prio_graph = PrioGraph::new(|id: &TransactionPriorityId, _graph_node| *id);
+//
+//         // Track metrics on filter.
+// let mut num_filtered_out: usize = 0; +// let mut total_filter_time_us: u64 = 0; +// +// let mut window_budget = self.look_ahead_window_size; +// let mut chunked_pops = |container: &mut TransactionStateContainer

, +// prio_graph: &mut PrioGraph<_, _, _, _>, +// window_budget: &mut usize| { +// while *window_budget > 0 { +// const MAX_FILTER_CHUNK_SIZE: usize = 128; +// let mut filter_array = [true; MAX_FILTER_CHUNK_SIZE]; +// let mut ids = Vec::with_capacity(MAX_FILTER_CHUNK_SIZE); +// let mut txs = Vec::with_capacity(MAX_FILTER_CHUNK_SIZE); +// +// let chunk_size = (*window_budget).min(MAX_FILTER_CHUNK_SIZE); +// for _ in 0..chunk_size { +// if let Some(id) = container.pop() { +// ids.push(id); +// } else { +// break; +// } +// } +// *window_budget = window_budget.saturating_sub(chunk_size); +// +// ids.iter().for_each(|id| { +// let transaction = container.get_transaction_ttl(&id.id).unwrap(); +// txs.push(&transaction.transaction); +// }); +// +// let (_, filter_us) = +// measure_us!(pre_graph_filter(&txs, &mut filter_array[..chunk_size])); +// saturating_add_assign!(total_filter_time_us, filter_us); +// +// for (id, filter_result) in ids.iter().zip(&filter_array[..chunk_size]) { +// if *filter_result { +// let transaction = container.get_transaction_ttl(&id.id).unwrap(); +// prio_graph.insert_transaction( +// *id, +// Self::get_transaction_account_access(transaction), +// ); +// } else { +// saturating_add_assign!(num_filtered_out, 1); +// container.remove_by_id(&id.id); +// } +// } +// +// if ids.len() != chunk_size { +// break; +// } +// } +// }; +// +// // Create the initial look-ahead window. +// // Check transactions against filter, remove from container if it fails. +// chunked_pops(container, &mut prio_graph, &mut window_budget); +// +// let mut unblock_this_batch = +// Vec::with_capacity(self.consume_work_senders.len() * TARGET_NUM_TRANSACTIONS_PER_BATCH); +// const MAX_TRANSACTIONS_PER_SCHEDULING_PASS: usize = 100_000; +// let mut num_scheduled: usize = 0; +// let mut num_sent: usize = 0; +// let mut num_unschedulable: usize = 0; +// while num_scheduled < MAX_TRANSACTIONS_PER_SCHEDULING_PASS { +// // If nothing is in the main-queue of the `PrioGraph` then there's nothing left to schedule. +// if prio_graph.is_empty() { +// break; +// } +// +// while let Some(id) = prio_graph.pop() { +// unblock_this_batch.push(id); +// +// // Should always be in the container, during initial testing phase panic. +// // Later, we can replace with a continue in case this does happen. 
+// let Some(transaction_state) = container.get_mut_transaction_state(&id.id) else { +// panic!("transaction state must exist") +// }; +// +// let maybe_schedule_info = try_schedule_transaction( +// transaction_state, +// &pre_lock_filter, +// &mut blocking_locks, +// &mut self.account_locks, +// num_threads, +// |thread_set| { +// Self::select_thread( +// thread_set, +// &batches.total_cus, +// self.in_flight_tracker.cus_in_flight_per_thread(), +// &batches.transactions, +// self.in_flight_tracker.num_in_flight_per_thread(), +// ) +// }, +// ); +// +// match maybe_schedule_info { +// Err(TransactionSchedulingError::Filtered) => { +// container.remove_by_id(&id.id); +// } +// Err(TransactionSchedulingError::UnschedulableConflicts) => { +// unschedulable_ids.push(id); +// saturating_add_assign!(num_unschedulable, 1); +// } +// Ok(TransactionSchedulingInfo { +// thread_id, +// transaction, +// max_age, +// cost, +// }) => { +// saturating_add_assign!(num_scheduled, 1); +// batches.transactions[thread_id].push(transaction); +// batches.ids[thread_id].push(id.id); +// batches.max_ages[thread_id].push(max_age); +// saturating_add_assign!(batches.total_cus[thread_id], cost); +// +// // If target batch size is reached, send only this batch. +// if batches.ids[thread_id].len() >= TARGET_NUM_TRANSACTIONS_PER_BATCH { +// saturating_add_assign!( +// num_sent, +// self.send_batch(&mut batches, thread_id)? +// ); +// } +// +// // if the thread is at max_cu_per_thread, remove it from the schedulable threads +// // if there are no more schedulable threads, stop scheduling. +// if self.in_flight_tracker.cus_in_flight_per_thread()[thread_id] +// + batches.total_cus[thread_id] +// >= max_cu_per_thread +// { +// schedulable_threads.remove(thread_id); +// if schedulable_threads.is_empty() { +// break; +// } +// } +// +// if num_scheduled >= MAX_TRANSACTIONS_PER_SCHEDULING_PASS { +// break; +// } +// } +// } +// } +// +// // Send all non-empty batches +// saturating_add_assign!(num_sent, self.send_batches(&mut batches)?); +// +// // Refresh window budget and do chunked pops +// saturating_add_assign!(window_budget, unblock_this_batch.len()); +// chunked_pops(container, &mut prio_graph, &mut window_budget); +// +// // Unblock all transactions that were blocked by the transactions that were just sent. +// for id in unblock_this_batch.drain(..) { +// prio_graph.unblock(&id); +// } +// } +// +// // Send batches for any remaining transactions +// saturating_add_assign!(num_sent, self.send_batches(&mut batches)?); +// +// // Push unschedulable ids back into the container +// for id in unschedulable_ids { +// container.push_id_into_queue(id); +// } +// +// // Push remaining transactions back into the container +// while let Some((id, _)) = prio_graph.pop_and_unblock() { +// container.push_id_into_queue(id); +// } +// +// assert_eq!( +// num_scheduled, num_sent, +// "number of scheduled and sent transactions must match" +// ); +// +// Ok(SchedulingSummary { +// num_scheduled, +// num_unschedulable, +// num_filtered_out, +// filter_time_us: total_filter_time_us, +// }) +// } +// +// /// Receive completed batches of transactions without blocking. +// /// Returns (num_transactions, num_retryable_transactions) on success. +// pub fn receive_completed( +// &mut self, +// container: &mut TransactionStateContainer

, +// ) -> Result<(usize, usize), SchedulerError> { +// let mut total_num_transactions: usize = 0; +// let mut total_num_retryable: usize = 0; +// loop { +// let (num_transactions, num_retryable) = self.try_receive_completed(container)?; +// if num_transactions == 0 { +// break; +// } +// saturating_add_assign!(total_num_transactions, num_transactions); +// saturating_add_assign!(total_num_retryable, num_retryable); +// } +// Ok((total_num_transactions, total_num_retryable)) +// } +// +// /// Receive completed batches of transactions. +// /// Returns `Ok((num_transactions, num_retryable))` if a batch was received, `Ok((0, 0))` if no batch was received. +// fn try_receive_completed( +// &mut self, +// container: &mut TransactionStateContainer

, +// ) -> Result<(usize, usize), SchedulerError> { +// match self.finished_consume_work_receiver.try_recv() { +// Ok(FinishedConsumeWork { +// work: +// ConsumeWork { +// batch_id, +// ids, +// transactions, +// max_ages, +// }, +// retryable_indexes, +// }) => { +// let num_transactions = ids.len(); +// let num_retryable = retryable_indexes.len(); +// +// // Free the locks +// self.complete_batch(batch_id, &transactions); +// +// // Retryable transactions should be inserted back into the container +// let mut retryable_iter = retryable_indexes.into_iter().peekable(); +// for (index, (id, transaction, max_age)) in +// izip!(ids, transactions, max_ages).enumerate() +// { +// if let Some(retryable_index) = retryable_iter.peek() { +// if *retryable_index == index { +// container.retry_transaction( +// id, +// SanitizedTransactionTTL { +// transaction, +// max_age, +// }, +// ); +// retryable_iter.next(); +// continue; +// } +// } +// container.remove_by_id(&id); +// } +// +// Ok((num_transactions, num_retryable)) +// } +// Err(TryRecvError::Empty) => Ok((0, 0)), +// Err(TryRecvError::Disconnected) => Err(SchedulerError::DisconnectedRecvChannel( +// "finished consume work", +// )), +// } +// } +// +// /// Mark a given `TransactionBatchId` as completed. +// /// This will update the internal tracking, including account locks. +// fn complete_batch( +// &mut self, +// batch_id: TransactionBatchId, +// transactions: &[SanitizedTransaction], +// ) { +// let thread_id = self.in_flight_tracker.complete_batch(batch_id); +// for transaction in transactions { +// let message = transaction.message(); +// let account_keys = message.account_keys(); +// let write_account_locks = account_keys +// .iter() +// .enumerate() +// .filter_map(|(index, key)| message.is_writable(index).then_some(key)); +// let read_account_locks = account_keys +// .iter() +// .enumerate() +// .filter_map(|(index, key)| (!message.is_writable(index)).then_some(key)); +// self.account_locks +// .unlock_accounts(write_account_locks, read_account_locks, thread_id); +// } +// } +// +// /// Send all batches of transactions to the worker threads. +// /// Returns the number of transactions sent. +// fn send_batches(&mut self, batches: &mut Batches) -> Result { +// (0..self.consume_work_senders.len()) +// .map(|thread_index| self.send_batch(batches, thread_index)) +// .sum() +// } +// +// /// Send a batch of transactions to the given thread's `ConsumeWork` channel. +// /// Returns the number of transactions sent. +// fn send_batch( +// &mut self, +// batches: &mut Batches, +// thread_index: usize, +// ) -> Result { +// if batches.ids[thread_index].is_empty() { +// return Ok(0); +// } +// +// let (ids, transactions, max_ages, total_cus) = batches.take_batch(thread_index); +// +// let batch_id = self +// .in_flight_tracker +// .track_batch(ids.len(), total_cus, thread_index); +// +// let num_scheduled = ids.len(); +// let work = ConsumeWork { +// batch_id, +// ids, +// transactions, +// max_ages, +// }; +// self.consume_work_senders[thread_index] +// .send(work) +// .map_err(|_| SchedulerError::DisconnectedSendChannel("consume work sender"))?; +// +// Ok(num_scheduled) +// } +// +// /// Given the schedulable `thread_set`, select the thread with the least amount +// /// of work queued up. +// /// Currently, "work" is just defined as the number of transactions. +// /// +// /// If the `chain_thread` is available, this thread will be selected, regardless of +// /// load-balancing. +// /// +// /// Panics if the `thread_set` is empty. 
This should never happen, see comment +// /// on `ThreadAwareAccountLocks::try_lock_accounts`. +// fn select_thread( +// thread_set: ThreadSet, +// batch_cus_per_thread: &[u64], +// in_flight_cus_per_thread: &[u64], +// batches_per_thread: &[Vec], +// in_flight_per_thread: &[usize], +// ) -> ThreadId { +// thread_set +// .contained_threads_iter() +// .map(|thread_id| { +// ( +// thread_id, +// batch_cus_per_thread[thread_id] + in_flight_cus_per_thread[thread_id], +// batches_per_thread[thread_id].len() + in_flight_per_thread[thread_id], +// ) +// }) +// .min_by(|a, b| a.1.cmp(&b.1).then_with(|| a.2.cmp(&b.2))) +// .map(|(thread_id, _, _)| thread_id) +// .unwrap() +// } +// +// /// Gets accessed accounts (resources) for use in `PrioGraph`. +// fn get_transaction_account_access( +// transaction: &SanitizedTransactionTTL, +// ) -> impl Iterator + '_ { +// let message = transaction.transaction.message(); +// message +// .account_keys() +// .iter() +// .enumerate() +// .map(|(index, key)| { +// if message.is_writable(index) { +// (*key, AccessKind::Write) +// } else { +// (*key, AccessKind::Read) +// } +// }) +// } +// } +// +// /// Metrics from scheduling transactions. +// #[derive(Debug, PartialEq, Eq)] +// pub struct SchedulingSummary { +// /// Number of transactions scheduled. +// pub num_scheduled: usize, +// /// Number of transactions that were not scheduled due to conflicts. +// pub num_unschedulable: usize, +// /// Number of transactions that were dropped due to filter. +// pub num_filtered_out: usize, +// /// Time spent filtering transactions +// pub filter_time_us: u64, +// } +// +// struct Batches { +// ids: Vec>, +// transactions: Vec>, +// max_ages: Vec>, +// total_cus: Vec, +// } +// +// impl Batches { +// fn new(num_threads: usize) -> Self { +// Self { +// ids: vec![Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH); num_threads], +// transactions: vec![Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH); num_threads], +// max_ages: vec![Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH); num_threads], +// total_cus: vec![0; num_threads], +// } +// } +// +// fn take_batch( +// &mut self, +// thread_id: ThreadId, +// ) -> ( +// Vec, +// Vec, +// Vec, +// u64, +// ) { +// ( +// core::mem::replace( +// &mut self.ids[thread_id], +// Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH), +// ), +// core::mem::replace( +// &mut self.transactions[thread_id], +// Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH), +// ), +// core::mem::replace( +// &mut self.max_ages[thread_id], +// Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH), +// ), +// core::mem::replace(&mut self.total_cus[thread_id], 0), +// ) +// } +// } +// +// /// A transaction has been scheduled to a thread. +// struct TransactionSchedulingInfo { +// thread_id: ThreadId, +// transaction: SanitizedTransaction, +// max_age: MaxAge, +// cost: u64, +// } +// +// /// Error type for reasons a transaction could not be scheduled. +// enum TransactionSchedulingError { +// /// Transaction was filtered out before locking. +// Filtered, +// /// Transaction cannot be scheduled due to conflicts, or +// /// higher priority conflicting transactions are unschedulable. +// UnschedulableConflicts, +// } +// +// fn try_schedule_transaction( +// transaction_state: &mut TransactionState

, +// pre_lock_filter: impl Fn(&SanitizedTransaction) -> bool, +// blocking_locks: &mut ReadWriteAccountSet, +// account_locks: &mut ThreadAwareAccountLocks, +// num_threads: usize, +// thread_selector: impl Fn(ThreadSet) -> ThreadId, +// ) -> Result { +// let transaction = &transaction_state.transaction_ttl().transaction; +// if !pre_lock_filter(transaction) { +// return Err(TransactionSchedulingError::Filtered); +// } +// +// // Check if this transaction conflicts with any blocked transactions +// let message = transaction.message(); +// if !blocking_locks.check_locks(message) { +// blocking_locks.take_locks(message); +// return Err(TransactionSchedulingError::UnschedulableConflicts); +// } +// +// // Schedule the transaction if it can be. +// let message = transaction.message(); +// let account_keys = message.account_keys(); +// let write_account_locks = account_keys +// .iter() +// .enumerate() +// .filter_map(|(index, key)| message.is_writable(index).then_some(key)) +// .collect::>(); +// let read_account_locks = account_keys +// .iter() +// .enumerate() +// .filter_map(|(index, key)| (!message.is_writable(index)).then_some(key)) +// .collect::>(); +// +// let Some(thread_id) = account_locks.try_lock_accounts( +// write_account_locks.into_iter(), +// read_account_locks.into_iter(), +// ThreadSet::any(num_threads), +// thread_selector, +// ) else { +// blocking_locks.take_locks(message); +// return Err(TransactionSchedulingError::UnschedulableConflicts); +// }; +// +// let sanitized_transaction_ttl = transaction_state.transition_to_pending(); +// let cost = transaction_state.cost(); +// +// Ok(TransactionSchedulingInfo { +// thread_id, +// transaction: sanitized_transaction_ttl.transaction, +// max_age: sanitized_transaction_ttl.max_age, +// cost, +// }) +// } +// +// #[cfg(test)] +// mod tests { +// use { +// super::*, +// crate::tests::MockImmutableDeserializedPacket, +// crate::TARGET_NUM_TRANSACTIONS_PER_BATCH, +// crossbeam_channel::{unbounded, Receiver}, +// itertools::Itertools, +// solana_sdk::{ +// clock::Slot, compute_budget::ComputeBudgetInstruction, hash::Hash, message::Message, +// packet::Packet, pubkey::Pubkey, signature::Keypair, signer::Signer, system_instruction, +// transaction::Transaction, +// }, +// std::{borrow::Borrow, sync::Arc}, +// }; +// +// macro_rules! txid { +// ($value:expr) => { +// TransactionId::new($value) +// }; +// } +// +// macro_rules! 
txids { +// ([$($element:expr),*]) => { +// vec![ $(txid!($element)),* ] +// }; +// } +// +// fn create_test_frame( +// num_threads: usize, +// ) -> ( +// PrioGraphScheduler, +// Vec>, +// Sender, +// ) { +// let (consume_work_senders, consume_work_receivers) = +// (0..num_threads).map(|_| unbounded()).unzip(); +// let (finished_consume_work_sender, finished_consume_work_receiver) = unbounded(); +// let scheduler = PrioGraphScheduler::::new( +// consume_work_senders, +// finished_consume_work_receiver, +// ); +// ( +// scheduler, +// consume_work_receivers, +// finished_consume_work_sender, +// ) +// } +// +// fn prioritized_tranfers( +// from_keypair: &Keypair, +// to_pubkeys: impl IntoIterator>, +// lamports: u64, +// priority: u64, +// ) -> SanitizedTransaction { +// let to_pubkeys_lamports = to_pubkeys +// .into_iter() +// .map(|pubkey| *pubkey.borrow()) +// .zip(std::iter::repeat(lamports)) +// .collect_vec(); +// let mut ixs = +// system_instruction::transfer_many(&from_keypair.pubkey(), &to_pubkeys_lamports); +// let prioritization = ComputeBudgetInstruction::set_compute_unit_price(priority); +// ixs.push(prioritization); +// let message = Message::new(&ixs, Some(&from_keypair.pubkey())); +// let tx = Transaction::new(&[from_keypair], message, Hash::default()); +// SanitizedTransaction::from_transaction_for_tests(tx) +// } +// +// fn create_container( +// tx_infos: impl IntoIterator< +// Item = ( +// impl Borrow, +// impl IntoIterator>, +// u64, +// u64, +// ), +// >, +// ) -> TransactionStateContainer { +// let mut container = +// TransactionStateContainer::::with_capacity(10 * 1024); +// for (index, (from_keypair, to_pubkeys, lamports, compute_unit_price)) in +// tx_infos.into_iter().enumerate() +// { +// let id = TransactionId::new(index as u64); +// let transaction = prioritized_tranfers( +// from_keypair.borrow(), +// to_pubkeys, +// lamports, +// compute_unit_price, +// ); +// let packet = Arc::new( +// MockImmutableDeserializedPacket::new( +// Packet::from_data(None, transaction.to_versioned_transaction()).unwrap(), +// ) +// .unwrap(), +// ); +// let transaction_ttl = SanitizedTransactionTTL { +// transaction, +// max_age: MaxAge { +// epoch_invalidation_slot: Slot::MAX, +// alt_invalidation_slot: Slot::MAX, +// }, +// }; +// const TEST_TRANSACTION_COST: u64 = 5000; +// container.insert_new_transaction( +// id, +// transaction_ttl, +// packet, +// compute_unit_price, +// TEST_TRANSACTION_COST, +// ); +// } +// +// container +// } +// +// fn collect_work( +// receiver: &Receiver, +// ) -> (Vec, Vec>) { +// receiver +// .try_iter() +// .map(|work| { +// let ids = work.ids.clone(); +// (work, ids) +// }) +// .unzip() +// } +// +// fn test_pre_graph_filter(_txs: &[&SanitizedTransaction], results: &mut [bool]) { +// results.fill(true); +// } +// +// fn test_pre_lock_filter(_tx: &SanitizedTransaction) -> bool { +// true +// } +// +// #[test] +// fn test_schedule_disconnected_channel() { +// let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(1); +// let mut container = create_container([(&Keypair::new(), &[Pubkey::new_unique()], 1, 1)]); +// +// drop(work_receivers); // explicitly drop receivers +// assert_matches!( +// scheduler.schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter), +// Err(SchedulerError::DisconnectedSendChannel(_)) +// ); +// } +// +// #[test] +// fn test_schedule_single_threaded_no_conflicts() { +// let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(1); +// let mut container = 
create_container([ +// (&Keypair::new(), &[Pubkey::new_unique()], 1, 1), +// (&Keypair::new(), &[Pubkey::new_unique()], 2, 2), +// ]); +// +// let scheduling_summary = scheduler +// .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter) +// .unwrap(); +// assert_eq!(scheduling_summary.num_scheduled, 2); +// assert_eq!(scheduling_summary.num_unschedulable, 0); +// assert_eq!(collect_work(&work_receivers[0]).1, vec![txids!([1, 0])]); +// } +// +// #[test] +// fn test_schedule_single_threaded_conflict() { +// let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(1); +// let pubkey = Pubkey::new_unique(); +// let mut container = create_container([ +// (&Keypair::new(), &[pubkey], 1, 1), +// (&Keypair::new(), &[pubkey], 1, 2), +// ]); +// +// let scheduling_summary = scheduler +// .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter) +// .unwrap(); +// assert_eq!(scheduling_summary.num_scheduled, 2); +// assert_eq!(scheduling_summary.num_unschedulable, 0); +// assert_eq!( +// collect_work(&work_receivers[0]).1, +// vec![txids!([1]), txids!([0])] +// ); +// } +// +// #[test] +// fn test_schedule_consume_single_threaded_multi_batch() { +// let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(1); +// let mut container = create_container( +// (0..4 * TARGET_NUM_TRANSACTIONS_PER_BATCH) +// .map(|i| (Keypair::new(), [Pubkey::new_unique()], i as u64, 1)), +// ); +// +// // expect 4 full batches to be scheduled +// let scheduling_summary = scheduler +// .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter) +// .unwrap(); +// assert_eq!( +// scheduling_summary.num_scheduled, +// 4 * TARGET_NUM_TRANSACTIONS_PER_BATCH +// ); +// assert_eq!(scheduling_summary.num_unschedulable, 0); +// +// let thread0_work_counts: Vec<_> = work_receivers[0] +// .try_iter() +// .map(|work| work.ids.len()) +// .collect(); +// assert_eq!(thread0_work_counts, [TARGET_NUM_TRANSACTIONS_PER_BATCH; 4]); +// } +// +// #[test] +// fn test_schedule_simple_thread_selection() { +// let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(2); +// let mut container = +// create_container((0..4).map(|i| (Keypair::new(), [Pubkey::new_unique()], 1, i))); +// +// let scheduling_summary = scheduler +// .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter) +// .unwrap(); +// assert_eq!(scheduling_summary.num_scheduled, 4); +// assert_eq!(scheduling_summary.num_unschedulable, 0); +// assert_eq!(collect_work(&work_receivers[0]).1, [txids!([3, 1])]); +// assert_eq!(collect_work(&work_receivers[1]).1, [txids!([2, 0])]); +// } +// +// #[test] +// fn test_schedule_priority_guard() { +// let (mut scheduler, work_receivers, finished_work_sender) = create_test_frame(2); +// // intentionally shorten the look-ahead window to cause unschedulable conflicts +// scheduler.look_ahead_window_size = 2; +// +// let accounts = (0..8).map(|_| Keypair::new()).collect_vec(); +// let mut container = create_container([ +// (&accounts[0], &[accounts[1].pubkey()], 1, 6), +// (&accounts[2], &[accounts[3].pubkey()], 1, 5), +// (&accounts[4], &[accounts[5].pubkey()], 1, 4), +// (&accounts[6], &[accounts[7].pubkey()], 1, 3), +// (&accounts[1], &[accounts[2].pubkey()], 1, 2), +// (&accounts[2], &[accounts[3].pubkey()], 1, 1), +// ]); +// +// // The look-ahead window is intentionally shortened, high priority transactions +// // [0, 1, 2, 3] do not conflict, and are scheduled onto threads in a +// // round-robin fashion. 
This leads to transaction [4] being unschedulable due +// // to conflicts with [0] and [1], which were scheduled to different threads. +// // Transaction [5] is technically schedulable, onto thread 1 since it only +// // conflicts with transaction [1]. However, [5] will not be scheduled because +// // it conflicts with a higher-priority transaction [4] that is unschedulable. +// // The full prio-graph can be visualized as: +// // [0] \ +// // -> [4] -> [5] +// // [1] / ------/ +// // [2] +// // [3] +// // Because the look-ahead window is shortened to a size of 4, the scheduler does +// // not have knowledge of the joining at transaction [4] until after [0] and [1] +// // have been scheduled. +// let scheduling_summary = scheduler +// .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter) +// .unwrap(); +// assert_eq!(scheduling_summary.num_scheduled, 4); +// assert_eq!(scheduling_summary.num_unschedulable, 2); +// let (thread_0_work, thread_0_ids) = collect_work(&work_receivers[0]); +// assert_eq!(thread_0_ids, [txids!([0]), txids!([2])]); +// assert_eq!( +// collect_work(&work_receivers[1]).1, +// [txids!([1]), txids!([3])] +// ); +// +// // Cannot schedule even on next pass because of lock conflicts +// let scheduling_summary = scheduler +// .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter) +// .unwrap(); +// assert_eq!(scheduling_summary.num_scheduled, 0); +// assert_eq!(scheduling_summary.num_unschedulable, 2); +// +// // Complete batch on thread 0. Remaining txs can be scheduled onto thread 1 +// finished_work_sender +// .send(FinishedConsumeWork { +// work: thread_0_work.into_iter().next().unwrap(), +// retryable_indexes: vec![], +// }) +// .unwrap(); +// scheduler.receive_completed(&mut container).unwrap(); +// let scheduling_summary = scheduler +// .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter) +// .unwrap(); +// assert_eq!(scheduling_summary.num_scheduled, 2); +// assert_eq!(scheduling_summary.num_unschedulable, 0); +// +// assert_eq!( +// collect_work(&work_receivers[1]).1, +// [txids!([4]), txids!([5])] +// ); +// } +// +// #[test] +// fn test_schedule_pre_lock_filter() { +// let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(1); +// let pubkey = Pubkey::new_unique(); +// let keypair = Keypair::new(); +// let mut container = create_container([ +// (&Keypair::new(), &[pubkey], 1, 1), +// (&keypair, &[pubkey], 1, 2), +// (&Keypair::new(), &[pubkey], 1, 3), +// ]); +// +// // 2nd transaction should be filtered out and dropped before locking. +// let pre_lock_filter = +// |tx: &SanitizedTransaction| tx.message().fee_payer() != &keypair.pubkey(); +// let scheduling_summary = scheduler +// .schedule(&mut container, test_pre_graph_filter, pre_lock_filter) +// .unwrap(); +// assert_eq!(scheduling_summary.num_scheduled, 2); +// assert_eq!(scheduling_summary.num_unschedulable, 0); +// assert_eq!( +// collect_work(&work_receivers[0]).1, +// vec![txids!([2]), txids!([0])] +// ); +// } +// } diff --git a/scheduler/src/impls/prio_graph_scheduler/read_write_account_set.rs b/scheduler/src/impls/prio_graph_scheduler/read_write_account_set.rs new file mode 100644 index 0000000..4e23919 --- /dev/null +++ b/scheduler/src/impls/prio_graph_scheduler/read_write_account_set.rs @@ -0,0 +1,287 @@ +use { + ahash::AHashSet, + solana_sdk::{message::SanitizedMessage, pubkey::Pubkey}, +}; + +/// Wrapper struct to accumulate locks for a batch of transactions. 
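+///
+/// A minimal sketch of the intended call pattern (`tx` is a hypothetical
+/// `SanitizedTransaction`; the scheduler above applies the same pair of
+/// calls to its `blocking_locks` set):
+///
+/// ```ignore
+/// let mut batch_locks = ReadWriteAccountSet::default();
+/// if batch_locks.check_locks(tx.message()) {
+///     // No conflict with locks already accumulated for this batch.
+///     batch_locks.take_locks(tx.message());
+/// }
+/// ```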
+#[derive(Debug, Default)]
+pub struct ReadWriteAccountSet {
+    /// Set of accounts that are locked for read
+    read_set: AHashSet<Pubkey>,
+    /// Set of accounts that are locked for write
+    write_set: AHashSet<Pubkey>,
+}
+
+impl ReadWriteAccountSet {
+    /// Returns true if all account locks were available and false otherwise.
+    pub fn check_locks(&self, message: &SanitizedMessage) -> bool {
+        message
+            .account_keys()
+            .iter()
+            .enumerate()
+            .all(|(index, pubkey)| {
+                if message.is_writable(index) {
+                    self.can_write(pubkey)
+                } else {
+                    self.can_read(pubkey)
+                }
+            })
+    }
+
+    /// Add all account locks.
+    /// Returns true if all account locks were available and false otherwise.
+    pub fn take_locks(&mut self, message: &SanitizedMessage) -> bool {
+        message
+            .account_keys()
+            .iter()
+            .enumerate()
+            .fold(true, |all_available, (index, pubkey)| {
+                if message.is_writable(index) {
+                    all_available & self.add_write(pubkey)
+                } else {
+                    all_available & self.add_read(pubkey)
+                }
+            })
+    }
+
+    /// Clears the read and write sets
+    #[allow(dead_code)]
+    pub fn clear(&mut self) {
+        self.read_set.clear();
+        self.write_set.clear();
+    }
+
+    /// Check if an account can be read-locked
+    fn can_read(&self, pubkey: &Pubkey) -> bool {
+        !self.write_set.contains(pubkey)
+    }
+
+    /// Check if an account can be write-locked
+    fn can_write(&self, pubkey: &Pubkey) -> bool {
+        !self.write_set.contains(pubkey) && !self.read_set.contains(pubkey)
+    }
+
+    /// Add an account to the read-set.
+    /// Returns true if the lock was available.
+    fn add_read(&mut self, pubkey: &Pubkey) -> bool {
+        let can_read = self.can_read(pubkey);
+        self.read_set.insert(*pubkey);
+
+        can_read
+    }
+
+    /// Add an account to the write-set.
+    /// Returns true if the lock was available.
+    fn add_write(&mut self, pubkey: &Pubkey) -> bool {
+        let can_write = self.can_write(pubkey);
+        self.write_set.insert(*pubkey);
+
+        can_write
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use {
+        super::ReadWriteAccountSet,
+        solana_ledger::genesis_utils::GenesisConfigInfo,
+        solana_runtime::{bank::Bank, bank_forks::BankForks, genesis_utils::create_genesis_config},
+        solana_sdk::{
+            account::AccountSharedData,
+            address_lookup_table::{
+                self,
+                state::{AddressLookupTable, LookupTableMeta},
+            },
+            hash::Hash,
+            message::{
+                v0::{self, MessageAddressTableLookup},
+                MessageHeader, VersionedMessage,
+            },
+            pubkey::Pubkey,
+            signature::Keypair,
+            signer::Signer,
+            transaction::{MessageHash, SanitizedTransaction, VersionedTransaction},
+        },
+        std::{
+            borrow::Cow,
+            sync::{Arc, RwLock},
+        },
+    };
+
+    fn create_test_versioned_message(
+        write_keys: &[Pubkey],
+        read_keys: &[Pubkey],
+        address_table_lookups: Vec<MessageAddressTableLookup>,
+    ) -> VersionedMessage {
+        VersionedMessage::V0(v0::Message {
+            header: MessageHeader {
+                num_required_signatures: write_keys.len() as u8,
+                num_readonly_signed_accounts: 0,
+                num_readonly_unsigned_accounts: read_keys.len() as u8,
+            },
+            recent_blockhash: Hash::default(),
+            account_keys: write_keys.iter().chain(read_keys.iter()).copied().collect(),
+            address_table_lookups,
+            instructions: vec![],
+        })
+    }
+
+    fn create_test_sanitized_transaction(
+        write_keypair: &Keypair,
+        read_keys: &[Pubkey],
+        address_table_lookups: Vec<MessageAddressTableLookup>,
+        bank: &Bank,
+    ) -> SanitizedTransaction {
+        let message = create_test_versioned_message(
+            &[write_keypair.pubkey()],
+            read_keys,
+            address_table_lookups,
+        );
+        SanitizedTransaction::try_create(
+            VersionedTransaction::try_new(message, &[write_keypair]).unwrap(),
+            MessageHash::Compute,
+            Some(false),
+            bank,
+            bank.get_reserved_account_keys(),
+        )
+        .unwrap()
+    }
+
+    fn create_test_address_lookup_table(
+        bank: Arc<Bank>,
+        num_addresses: usize,
+    ) -> (Arc<Bank>, Pubkey) {
+        let mut addresses = Vec::with_capacity(num_addresses);
+        addresses.resize_with(num_addresses, Pubkey::new_unique);
+        let address_lookup_table = AddressLookupTable {
+            meta: LookupTableMeta {
+                authority: None,
+                ..LookupTableMeta::default()
+            },
+            addresses: Cow::Owned(addresses),
+        };
+
+        let address_table_key = Pubkey::new_unique();
+        let data = address_lookup_table.serialize_for_tests().unwrap();
+        let mut account =
+            AccountSharedData::new(1, data.len(), &address_lookup_table::program::id());
+        account.set_data(data);
+        bank.store_account(&address_table_key, &account);
+
+        let slot = bank.slot() + 1;
+        (
+            Arc::new(Bank::new_from_parent(bank, &Pubkey::new_unique(), slot)),
+            address_table_key,
+        )
+    }
+
+    fn create_test_bank() -> (Arc<Bank>, Arc<RwLock<BankForks>>) {
+        let GenesisConfigInfo { genesis_config, .. } = create_genesis_config(10_000);
+        Bank::new_no_wallclock_throttle_for_tests(&genesis_config)
+    }
+
+    // Helper function (could potentially use test_case in future).
+    // conflict_index = 0 means write lock conflict with static key
+    // conflict_index = 1 means read lock conflict with static key
+    // conflict_index = 2 means write lock conflict with address table key
+    // conflict_index = 3 means read lock conflict with address table key
+    fn test_check_and_take_locks(conflict_index: usize, add_write: bool, expectation: bool) {
+        let (bank, _bank_forks) = create_test_bank();
+        let (bank, table_address) = create_test_address_lookup_table(bank, 2);
+        let tx = create_test_sanitized_transaction(
+            &Keypair::new(),
+            &[Pubkey::new_unique()],
+            vec![MessageAddressTableLookup {
+                account_key: table_address,
+                writable_indexes: vec![0],
+                readonly_indexes: vec![1],
+            }],
+            &bank,
+        );
+        let message = tx.message();
+
+        let mut account_locks = ReadWriteAccountSet::default();
+
+        let conflict_key = message.account_keys().get(conflict_index).unwrap();
+        if add_write {
+            account_locks.add_write(conflict_key);
+        } else {
+            account_locks.add_read(conflict_key);
+        }
+        assert_eq!(expectation, account_locks.check_locks(message));
+        assert_eq!(expectation, account_locks.take_locks(message));
+    }
+
+    #[test]
+    fn test_check_and_take_locks_write_write_conflict() {
+        test_check_and_take_locks(0, true, false); // static key conflict
+        test_check_and_take_locks(2, true, false); // lookup key conflict
+    }
+
+    #[test]
+    fn test_check_and_take_locks_read_write_conflict() {
+        test_check_and_take_locks(0, false, false); // static key conflict
+        test_check_and_take_locks(2, false, false); // lookup key conflict
+    }
+
+    #[test]
+    fn test_check_and_take_locks_write_read_conflict() {
+        test_check_and_take_locks(1, true, false); // static key conflict
+        test_check_and_take_locks(3, true, false); // lookup key conflict
+    }
+
+    #[test]
+    fn test_check_and_take_locks_read_read_non_conflict() {
+        test_check_and_take_locks(1, false, true); // static key conflict
+        test_check_and_take_locks(3, false, true); // lookup key conflict
+    }
+
+    #[test]
+    pub fn test_write_write_conflict() {
+        let mut account_locks = ReadWriteAccountSet::default();
+        let account = Pubkey::new_unique();
+        assert!(account_locks.can_write(&account));
+        account_locks.add_write(&account);
+        assert!(!account_locks.can_write(&account));
+    }
+
+    #[test]
+    pub fn test_read_write_conflict() {
+        let mut account_locks = ReadWriteAccountSet::default();
+        let account = Pubkey::new_unique();
+        assert!(account_locks.can_read(&account));
+        account_locks.add_read(&account);
+        assert!(!account_locks.can_write(&account));
+        assert!(account_locks.can_read(&account));
+    }
+
+    #[test]
+    pub fn test_write_read_conflict() {
+        let mut account_locks = ReadWriteAccountSet::default();
+        let account = Pubkey::new_unique();
+        assert!(account_locks.can_write(&account));
+        account_locks.add_write(&account);
+        assert!(!account_locks.can_write(&account));
+        assert!(!account_locks.can_read(&account));
+    }
+
+    #[test]
+    pub fn test_read_read_non_conflict() {
+        let mut account_locks = ReadWriteAccountSet::default();
+        let account = Pubkey::new_unique();
+        assert!(account_locks.can_read(&account));
+        account_locks.add_read(&account);
+        assert!(account_locks.can_read(&account));
+    }
+
+    #[test]
+    pub fn test_write_write_different_keys() {
+        let mut account_locks = ReadWriteAccountSet::default();
+        let account1 = Pubkey::new_unique();
+        let account2 = Pubkey::new_unique();
+        assert!(account_locks.can_write(&account1));
+        account_locks.add_write(&account1);
+        assert!(account_locks.can_write(&account2));
+        assert!(account_locks.can_read(&account2));
+    }
+}

diff --git a/scheduler/src/impls/prio_graph_scheduler/scheduler_error.rs b/scheduler/src/impls/prio_graph_scheduler/scheduler_error.rs
new file mode 100644
index 0000000..9b8d401
--- /dev/null
+++ b/scheduler/src/impls/prio_graph_scheduler/scheduler_error.rs
@@ -0,0 +1,9 @@
+use thiserror::Error;
+
+#[derive(Debug, Error)]
+pub enum SchedulerError {
+    #[error("Sending channel disconnected: {0}")]
+    DisconnectedSendChannel(&'static str),
+    #[error("Recv channel disconnected: {0}")]
+    DisconnectedRecvChannel(&'static str),
+}

diff --git a/scheduler/src/impls/prio_graph_scheduler/scheduler_metrics.rs b/scheduler/src/impls/prio_graph_scheduler/scheduler_metrics.rs
new file mode 100644
index 0000000..2c0e39d
--- /dev/null
+++ b/scheduler/src/impls/prio_graph_scheduler/scheduler_metrics.rs
@@ -0,0 +1,409 @@
+use solana_metrics::{create_datapoint, datapoint_info};
+use {
+    itertools::MinMaxResult,
+    solana_poh::poh_recorder::BankStart,
+    solana_sdk::{clock::Slot, timing::AtomicInterval},
+    std::time::Instant,
+};
+
+#[derive(Default)]
+pub struct SchedulerCountMetrics {
+    interval: IntervalSchedulerCountMetrics,
+    slot: SlotSchedulerCountMetrics,
+}
+
+impl SchedulerCountMetrics {
+    pub fn update(&mut self, update: impl Fn(&mut SchedulerCountMetricsInner)) {
+        update(&mut self.interval.metrics);
+        update(&mut self.slot.metrics);
+    }
+
+    pub fn maybe_report_and_reset_slot(&mut self, slot: Option<Slot>) {
+        self.slot.maybe_report_and_reset(slot);
+    }
+
+    pub fn maybe_report_and_reset_interval(&mut self, should_report: bool) {
+        self.interval.maybe_report_and_reset(should_report);
+    }
+
+    pub fn interval_has_data(&self) -> bool {
+        self.interval.metrics.has_data()
+    }
+}
+
+#[derive(Default)]
+struct IntervalSchedulerCountMetrics {
+    interval: AtomicInterval,
+    metrics: SchedulerCountMetricsInner,
+}
+
+#[derive(Default)]
+struct SlotSchedulerCountMetrics {
+    slot: Option<Slot>,
+    metrics: SchedulerCountMetricsInner,
+}
+
+#[derive(Default)]
+pub struct SchedulerCountMetricsInner {
+    /// Number of packets received.
+    pub num_received: usize,
+    /// Number of packets buffered.
+    pub num_buffered: usize,
+
+    /// Number of transactions scheduled.
+    pub num_scheduled: usize,
+    /// Number of transactions that were unschedulable.
+    pub num_unschedulable: usize,
+    /// Number of transactions that were filtered out during scheduling.
+    pub num_schedule_filtered_out: usize,
+    /// Number of completed transactions received from workers.
+    pub num_finished: usize,
+    /// Number of transactions that were retryable.
+    pub num_retryable: usize,
+    /// Number of transactions that were scheduled to be forwarded.
+    pub num_forwarded: usize,
+
+    /// Number of transactions that were immediately dropped on receive.
+    pub num_dropped_on_receive: usize,
+    /// Number of transactions that were dropped due to sanitization failure.
+    pub num_dropped_on_sanitization: usize,
+    /// Number of transactions that were dropped due to failed lock validation.
+    pub num_dropped_on_validate_locks: usize,
+    /// Number of transactions that were dropped due to failed transaction
+    /// checks during receive.
+    pub num_dropped_on_receive_transaction_checks: usize,
+    /// Number of transactions that were dropped due to clearing.
+    pub num_dropped_on_clear: usize,
+    /// Number of transactions that were dropped due to age and status checks.
+    pub num_dropped_on_age_and_status: usize,
+    /// Number of transactions that were dropped due to exceeded capacity.
+    pub num_dropped_on_capacity: usize,
+    /// Min prioritization fees in the transaction container
+    pub min_prioritization_fees: u64,
+    /// Max prioritization fees in the transaction container
+    pub max_prioritization_fees: u64,
+}
+
+impl IntervalSchedulerCountMetrics {
+    fn maybe_report_and_reset(&mut self, should_report: bool) {
+        const REPORT_INTERVAL_MS: u64 = 1000;
+        if self.interval.should_update(REPORT_INTERVAL_MS) {
+            if should_report {
+                self.metrics.report("banking_stage_scheduler_counts", None);
+            }
+            self.metrics.reset();
+        }
+    }
+}
+
+impl SlotSchedulerCountMetrics {
+    fn maybe_report_and_reset(&mut self, slot: Option<Slot>) {
+        if self.slot != slot {
+            // Only report if there was an assigned slot.
+            if self.slot.is_some() {
+                self.metrics
+                    .report("banking_stage_scheduler_slot_counts", self.slot);
+            }
+            self.metrics.reset();
+            self.slot = slot;
+        }
+    }
+}
+
+impl SchedulerCountMetricsInner {
+    fn report(&self, name: &'static str, slot: Option<Slot>) {
+        let mut datapoint = create_datapoint!(
+            @point name,
+            ("num_received", self.num_received, i64),
+            ("num_buffered", self.num_buffered, i64),
+            ("num_scheduled", self.num_scheduled, i64),
+            ("num_unschedulable", self.num_unschedulable, i64),
+            (
+                "num_schedule_filtered_out",
+                self.num_schedule_filtered_out,
+                i64
+            ),
+            ("num_finished", self.num_finished, i64),
+            ("num_retryable", self.num_retryable, i64),
+            ("num_forwarded", self.num_forwarded, i64),
+            ("num_dropped_on_receive", self.num_dropped_on_receive, i64),
+            (
+                "num_dropped_on_sanitization",
+                self.num_dropped_on_sanitization,
+                i64
+            ),
+            (
+                "num_dropped_on_validate_locks",
+                self.num_dropped_on_validate_locks,
+                i64
+            ),
+            (
+                "num_dropped_on_receive_transaction_checks",
+                self.num_dropped_on_receive_transaction_checks,
+                i64
+            ),
+            ("num_dropped_on_clear", self.num_dropped_on_clear, i64),
+            (
+                "num_dropped_on_age_and_status",
+                self.num_dropped_on_age_and_status,
+                i64
+            ),
+            ("num_dropped_on_capacity", self.num_dropped_on_capacity, i64),
+            ("min_priority", self.get_min_priority(), i64),
+            ("max_priority", self.get_max_priority(), i64)
+        );
+        if let Some(slot) = slot {
+            datapoint.add_field_i64("slot", slot as i64);
+        }
+        solana_metrics::submit(datapoint, log::Level::Info);
+    }
+
+    pub fn has_data(&self) -> bool {
+        self.num_received != 0
+            || self.num_buffered != 0
+            || self.num_scheduled != 0
+            || self.num_unschedulable != 0
+            || self.num_schedule_filtered_out != 0
+            || self.num_finished != 0
+            || self.num_retryable != 0
+            || self.num_forwarded != 0
+            || self.num_dropped_on_receive != 0
+            || self.num_dropped_on_sanitization != 0
+            || self.num_dropped_on_validate_locks != 0
+            || self.num_dropped_on_receive_transaction_checks != 0
+            || self.num_dropped_on_clear != 0
+            || self.num_dropped_on_age_and_status != 0
+            || self.num_dropped_on_capacity != 0
+    }
+
+    fn reset(&mut self) {
+        self.num_received = 0;
+        self.num_buffered = 0;
+        self.num_scheduled = 0;
+        self.num_unschedulable = 0;
+        self.num_schedule_filtered_out = 0;
+        self.num_finished = 0;
+        self.num_retryable = 0;
+        self.num_forwarded = 0;
+        self.num_dropped_on_receive = 0;
+        self.num_dropped_on_sanitization = 0;
+        self.num_dropped_on_validate_locks = 0;
+        self.num_dropped_on_receive_transaction_checks = 0;
+        self.num_dropped_on_clear = 0;
+        self.num_dropped_on_age_and_status = 0;
+        self.num_dropped_on_capacity = 0;
+        self.min_prioritization_fees = u64::MAX;
+        self.max_prioritization_fees = 0;
+    }
+
+    pub fn update_priority_stats(&mut self, min_max_fees: MinMaxResult<u64>) {
+        // update min/max priority
+        match min_max_fees {
+            itertools::MinMaxResult::NoElements => {
+                // do nothing
+            }
+            itertools::MinMaxResult::OneElement(e) => {
+                self.min_prioritization_fees = e;
+                self.max_prioritization_fees = e;
+            }
+            itertools::MinMaxResult::MinMax(min, max) => {
+                self.min_prioritization_fees = min;
+                self.max_prioritization_fees = max;
+            }
+        }
+    }
+
+    pub fn get_min_priority(&self) -> u64 {
+        // to avoid getting u64::max recorded by metrics / in case of edge cases
+        if self.min_prioritization_fees != u64::MAX {
+            self.min_prioritization_fees
+        } else {
+            0
+        }
+    }
+
+    pub fn get_max_priority(&self) -> u64 {
+        self.max_prioritization_fees
+    }
+}
+
+#[derive(Default)]
+pub struct SchedulerTimingMetrics {
+    interval: IntervalSchedulerTimingMetrics,
+    slot: SlotSchedulerTimingMetrics,
+}
+
+impl SchedulerTimingMetrics {
+    pub fn update(&mut self, update: impl Fn(&mut SchedulerTimingMetricsInner)) {
+        update(&mut self.interval.metrics);
+        update(&mut self.slot.metrics);
+    }
+
+    pub fn maybe_report_and_reset_slot(&mut self, slot: Option<Slot>) {
+        self.slot.maybe_report_and_reset(slot);
+    }
+
+    pub fn maybe_report_and_reset_interval(&mut self, should_report: bool) {
+        self.interval.maybe_report_and_reset(should_report);
+    }
+}
+
+#[derive(Default)]
+struct IntervalSchedulerTimingMetrics {
+    interval: AtomicInterval,
+    metrics: SchedulerTimingMetricsInner,
+}
+
+#[derive(Default)]
+struct SlotSchedulerTimingMetrics {
+    slot: Option<Slot>,
+    metrics: SchedulerTimingMetricsInner,
+}
+
+#[derive(Default)]
+pub struct SchedulerTimingMetricsInner {
+    /// Time spent making processing decisions.
+    pub decision_time_us: u64,
+    /// Time spent receiving packets.
+    pub receive_time_us: u64,
+    /// Time spent buffering packets.
+    pub buffer_time_us: u64,
+    /// Time spent filtering transactions during scheduling.
+    pub schedule_filter_time_us: u64,
+    /// Time spent scheduling transactions.
+    pub schedule_time_us: u64,
+    /// Time spent clearing transactions from the container.
+    pub clear_time_us: u64,
+    /// Time spent cleaning expired or processed transactions from the container.
+    pub clean_time_us: u64,
+    /// Time spent forwarding transactions.
+    pub forward_time_us: u64,
+    /// Time spent receiving completed transactions.
+    pub receive_completed_time_us: u64,
+}
+
+impl IntervalSchedulerTimingMetrics {
+    fn maybe_report_and_reset(&mut self, should_report: bool) {
+        const REPORT_INTERVAL_MS: u64 = 1000;
+        if self.interval.should_update(REPORT_INTERVAL_MS) {
+            if should_report {
+                self.metrics.report("banking_stage_scheduler_timing", None);
+            }
+            self.metrics.reset();
+        }
+    }
+}
+
+impl SlotSchedulerTimingMetrics {
+    fn maybe_report_and_reset(&mut self, slot: Option<Slot>) {
+        if self.slot != slot {
+            // Only report if there was an assigned slot.
+            if self.slot.is_some() {
+                self.metrics
+                    .report("banking_stage_scheduler_slot_timing", self.slot);
+            }
+            self.metrics.reset();
+            self.slot = slot;
+        }
+    }
+}
+
+impl SchedulerTimingMetricsInner {
+    fn report(&self, name: &'static str, slot: Option<Slot>) {
+        let mut datapoint = create_datapoint!(
+            @point name,
+            ("decision_time_us", self.decision_time_us, i64),
+            ("receive_time_us", self.receive_time_us, i64),
+            ("buffer_time_us", self.buffer_time_us, i64),
+            ("schedule_filter_time_us", self.schedule_filter_time_us, i64),
+            ("schedule_time_us", self.schedule_time_us, i64),
+            ("clear_time_us", self.clear_time_us, i64),
+            ("clean_time_us", self.clean_time_us, i64),
+            ("forward_time_us", self.forward_time_us, i64),
+            (
+                "receive_completed_time_us",
+                self.receive_completed_time_us,
+                i64
+            )
+        );
+        if let Some(slot) = slot {
+            datapoint.add_field_i64("slot", slot as i64);
+        }
+        solana_metrics::submit(datapoint, log::Level::Info);
+    }
+
+    fn reset(&mut self) {
+        self.decision_time_us = 0;
+        self.receive_time_us = 0;
+        self.buffer_time_us = 0;
+        self.schedule_filter_time_us = 0;
+        self.schedule_time_us = 0;
+        self.clear_time_us = 0;
+        self.clean_time_us = 0;
+        self.forward_time_us = 0;
+        self.receive_completed_time_us = 0;
+    }
+}
+
+#[derive(Default)]
+pub struct SchedulerLeaderDetectionMetrics {
+    inner: Option<SchedulerLeaderDetectionMetricsInner>,
+}
+
+struct SchedulerLeaderDetectionMetricsInner {
+    slot: Slot,
+    bank_creation_time: Instant,
+    bank_detected_time: Instant,
+}
+
+impl SchedulerLeaderDetectionMetrics {
+    pub fn update_and_maybe_report(&mut self, bank_start: Option<&BankStart>) {
+        match (&self.inner, bank_start) {
+            (None, Some(bank_start)) => self.initialize_inner(bank_start),
+            (Some(_inner), None) => self.report_and_reset(),
+            (Some(inner), Some(bank_start)) if inner.slot != bank_start.working_bank.slot() => {
+                self.report_and_reset();
+                self.initialize_inner(bank_start);
+            }
+            _ => {}
+        }
+    }
+
+    fn initialize_inner(&mut self, bank_start: &BankStart) {
+        let bank_detected_time = Instant::now();
+        self.inner = Some(SchedulerLeaderDetectionMetricsInner {
+            slot: bank_start.working_bank.slot(),
+            bank_creation_time: *bank_start.bank_creation_time,
+            bank_detected_time,
+        });
+    }
+
+    fn report_and_reset(&mut self) {
+        let SchedulerLeaderDetectionMetricsInner {
+            slot,
+            bank_creation_time,
+            bank_detected_time,
+        } = self.inner.take().expect("inner must be present");
+
+        let bank_detected_delay_us = bank_detected_time
+            .duration_since(bank_creation_time)
+            .as_micros()
+            .try_into()
+            .unwrap_or(i64::MAX);
+        let bank_detected_to_slot_end_detected_us = bank_detected_time
+            .elapsed()
+            .as_micros()
+            .try_into()
+            .unwrap_or(i64::MAX);
+        datapoint_info!(
+            "banking_stage_scheduler_leader_detection",
+            ("slot", slot, i64),
+            ("bank_detected_delay_us", bank_detected_delay_us, i64),
+            (
+                "bank_detected_to_slot_end_detected_us",
+                bank_detected_to_slot_end_detected_us,
+                i64
+            ),
+        );
+    }
+}
diff --git a/scheduler/src/impls/prio_graph_scheduler/thread_aware_account_locks.rs b/scheduler/src/impls/prio_graph_scheduler/thread_aware_account_locks.rs
new file mode 100644
index 0000000..d8563c6
--- /dev/null
+++ b/scheduler/src/impls/prio_graph_scheduler/thread_aware_account_locks.rs
@@ -0,0 +1,742 @@
+use {
+    ahash::AHashMap,
+    solana_sdk::pubkey::Pubkey,
+    std::{
+        collections::hash_map::Entry,
+        fmt::{Debug, Display},
+        ops::{BitAnd, BitAndAssign, Sub},
+    },
+};
+
+pub const MAX_THREADS: usize = u64::BITS as usize;
+
+/// Identifier for a thread
+pub type ThreadId = usize; // 0..MAX_THREADS-1
+
+type LockCount = u32;
+
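+// A rough illustration (assuming only the `ThreadSet` API defined below): the
+// set is a single u64 bitmask, which is why `MAX_THREADS` is `u64::BITS`, and
+// membership operations are single bit manipulations, e.g.:
+//
+//     let mut set = ThreadSet::none();
+//     set.insert(2); // internal mask becomes 0b100
+//     assert!(set.contains(2) && !set.is_empty());
+//     assert_eq!(set.only_one_contained(), Some(2));
+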
+/// A bit-set of threads an account is scheduled or can be scheduled for.
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub struct ThreadSet(u64);
+
+struct AccountWriteLocks {
+    thread_id: ThreadId,
+    lock_count: LockCount,
+}
+
+struct AccountReadLocks {
+    thread_set: ThreadSet,
+    lock_counts: [LockCount; MAX_THREADS],
+}
+
+/// Account locks.
+/// Write Locks - only one thread can hold a write lock at a time.
+///     Contains how many write locks are held by the thread.
+/// Read Locks - multiple threads can hold a read lock at a time.
+///     Contains thread-set for easily checking which threads are scheduled.
+#[derive(Default)]
+struct AccountLocks {
+    pub write_locks: Option<AccountWriteLocks>,
+    pub read_locks: Option<AccountReadLocks>,
+}
+
+/// Thread-aware account locks which allows for scheduling on threads
+/// that already hold locks on the account. This is useful for allowing
+/// queued transactions to be scheduled on a thread while the transaction
+/// is still being executed on the thread.
+pub struct ThreadAwareAccountLocks {
+    /// Number of threads.
+    num_threads: usize, // 0..MAX_THREADS
+    /// Locks for each account. An account should only have an entry if there
+    /// is at least one lock.
+    locks: AHashMap<Pubkey, AccountLocks>,
+}
+
+impl ThreadAwareAccountLocks {
+    /// Creates a new `ThreadAwareAccountLocks` with the given number of threads.
+    pub fn new(num_threads: usize) -> Self {
+        assert!(num_threads > 0, "num threads must be > 0");
+        assert!(
+            num_threads <= MAX_THREADS,
+            "num threads must be <= {MAX_THREADS}"
+        );
+
+        Self {
+            num_threads,
+            locks: AHashMap::new(),
+        }
+    }
+
+    /// Returns the `ThreadId` if the accounts are able to be locked
+    /// for the given thread, otherwise `None` is returned.
+    /// `allowed_threads` is a set of threads that the caller restricts locking to.
+    /// If accounts are schedulable, then they are locked for the thread
+    /// selected by the `thread_selector` function.
+    /// `thread_selector` is only called if all accounts are schedulable, meaning
+    /// that the `thread_set` passed to `thread_selector` is non-empty.
+    pub fn try_lock_accounts<'a>(
+        &mut self,
+        write_account_locks: impl Iterator<Item = &'a Pubkey> + Clone,
+        read_account_locks: impl Iterator<Item = &'a Pubkey> + Clone,
+        allowed_threads: ThreadSet,
+        thread_selector: impl FnOnce(ThreadSet) -> ThreadId,
+    ) -> Option<ThreadId> {
+        let schedulable_threads = self.accounts_schedulable_threads(
+            write_account_locks.clone(),
+            read_account_locks.clone(),
+        )? & allowed_threads;
+        (!schedulable_threads.is_empty()).then(|| {
+            let thread_id = thread_selector(schedulable_threads);
+            self.lock_accounts(write_account_locks, read_account_locks, thread_id);
+            thread_id
+        })
+    }
+
+    /// Unlocks the accounts for the given thread.
+    pub fn unlock_accounts<'a>(
+        &mut self,
+        write_account_locks: impl Iterator<Item = &'a Pubkey>,
+        read_account_locks: impl Iterator<Item = &'a Pubkey>,
+        thread_id: ThreadId,
+    ) {
+        for account in write_account_locks {
+            self.write_unlock_account(account, thread_id);
+        }
+
+        for account in read_account_locks {
+            self.read_unlock_account(account, thread_id);
+        }
+    }
+
+    /// Returns `ThreadSet` that the given accounts can be scheduled on.
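+    /// (Illustrative: the result is the bitwise AND of each account's
+    /// schedulable set, e.g. a write-locked account schedulable only on {2}
+    /// intersected with an unlocked account's {0, 1, 2, 3} yields {2};
+    /// `None` is returned as soon as the running intersection becomes empty.)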
+    fn accounts_schedulable_threads<'a>(
+        &self,
+        write_account_locks: impl Iterator<Item = &'a Pubkey>,
+        read_account_locks: impl Iterator<Item = &'a Pubkey>,
+    ) -> Option<ThreadSet> {
+        let mut schedulable_threads = ThreadSet::any(self.num_threads);
+
+        for account in write_account_locks {
+            schedulable_threads &= self.write_schedulable_threads(account);
+            if schedulable_threads.is_empty() {
+                return None;
+            }
+        }
+
+        for account in read_account_locks {
+            schedulable_threads &= self.read_schedulable_threads(account);
+            if schedulable_threads.is_empty() {
+                return None;
+            }
+        }
+
+        Some(schedulable_threads)
+    }
+
+    /// Returns `ThreadSet` of schedulable threads for the given readable account.
+    fn read_schedulable_threads(&self, account: &Pubkey) -> ThreadSet {
+        self.schedulable_threads::<false>(account)
+    }
+
+    /// Returns `ThreadSet` of schedulable threads for the given writable account.
+    fn write_schedulable_threads(&self, account: &Pubkey) -> ThreadSet {
+        self.schedulable_threads::<true>(account)
+    }
+
+    /// Returns `ThreadSet` of schedulable threads.
+    /// If there are no locks, then all threads are schedulable.
+    /// If only write-locked, then only the thread holding the write lock is schedulable.
+    /// If a mix of locks, then only the write thread is schedulable.
+    /// If only read-locked, an account is write-schedulable only on the single
+    /// thread that holds all of its read locks; otherwise, no threads are
+    /// write-schedulable. All threads remain read-schedulable.
+    fn schedulable_threads<const WRITE: bool>(&self, account: &Pubkey) -> ThreadSet {
+        match self.locks.get(account) {
+            None => ThreadSet::any(self.num_threads),
+            Some(AccountLocks {
+                write_locks: None,
+                read_locks: Some(read_locks),
+            }) => {
+                if WRITE {
+                    read_locks
+                        .thread_set
+                        .only_one_contained()
+                        .map(ThreadSet::only)
+                        .unwrap_or_else(ThreadSet::none)
+                } else {
+                    ThreadSet::any(self.num_threads)
+                }
+            }
+            Some(AccountLocks {
+                write_locks: Some(write_locks),
+                read_locks: None,
+            }) => ThreadSet::only(write_locks.thread_id),
+            Some(AccountLocks {
+                write_locks: Some(write_locks),
+                read_locks: Some(read_locks),
+            }) => {
+                assert_eq!(
+                    read_locks.thread_set.only_one_contained(),
+                    Some(write_locks.thread_id)
+                );
+                read_locks.thread_set
+            }
+            Some(AccountLocks {
+                write_locks: None,
+                read_locks: None,
+            }) => unreachable!(),
+        }
+    }
+
+    /// Add locks for all writable and readable accounts on `thread_id`.
+    fn lock_accounts<'a>(
+        &mut self,
+        write_account_locks: impl Iterator<Item = &'a Pubkey>,
+        read_account_locks: impl Iterator<Item = &'a Pubkey>,
+        thread_id: ThreadId,
+    ) {
+        assert!(
+            thread_id < self.num_threads,
+            "thread_id must be < num_threads"
+        );
+        for account in write_account_locks {
+            self.write_lock_account(account, thread_id);
+        }
+
+        for account in read_account_locks {
+            self.read_lock_account(account, thread_id);
+        }
+    }
+
+    /// Locks the given `account` for writing on `thread_id`.
+    /// Panics if the account is already locked for writing on another thread.
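+    /// (Locks are reference-counted per thread: the same thread may write-lock
+    /// an account repeatedly and must unlock it the same number of times, as
+    /// exercised by `test_write_locking` below.)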
+ fn write_lock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { + let entry = self.locks.entry(*account).or_default(); + + let AccountLocks { + write_locks, + read_locks, + } = entry; + + if let Some(read_locks) = read_locks { + assert_eq!( + read_locks.thread_set.only_one_contained(), + Some(thread_id), + "outstanding read lock must be on same thread" + ); + } + + if let Some(write_locks) = write_locks { + assert_eq!( + write_locks.thread_id, thread_id, + "outstanding write lock must be on same thread" + ); + write_locks.lock_count += 1; + } else { + *write_locks = Some(AccountWriteLocks { + thread_id, + lock_count: 1, + }); + } + } + + /// Unlocks the given `account` for writing on `thread_id`. + /// Panics if the account is not locked for writing on `thread_id`. + fn write_unlock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { + let Entry::Occupied(mut entry) = self.locks.entry(*account) else { + panic!("write lock must exist for account: {account}"); + }; + + let AccountLocks { + write_locks: maybe_write_locks, + read_locks, + } = entry.get_mut(); + + let Some(write_locks) = maybe_write_locks else { + panic!("write lock must exist for account: {account}"); + }; + + assert_eq!( + write_locks.thread_id, thread_id, + "outstanding write lock must be on same thread" + ); + + write_locks.lock_count -= 1; + if write_locks.lock_count == 0 { + *maybe_write_locks = None; + if read_locks.is_none() { + entry.remove(); + } + } + } + + /// Locks the given `account` for reading on `thread_id`. + /// Panics if the account is already locked for writing on another thread. + fn read_lock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { + let AccountLocks { + write_locks, + read_locks, + } = self.locks.entry(*account).or_default(); + + if let Some(write_locks) = write_locks { + assert_eq!( + write_locks.thread_id, thread_id, + "outstanding write lock must be on same thread" + ); + } + + match read_locks { + Some(read_locks) => { + read_locks.thread_set.insert(thread_id); + read_locks.lock_counts[thread_id] += 1; + } + None => { + let mut lock_counts = [0; MAX_THREADS]; + lock_counts[thread_id] = 1; + *read_locks = Some(AccountReadLocks { + thread_set: ThreadSet::only(thread_id), + lock_counts, + }); + } + } + } + + /// Unlocks the given `account` for reading on `thread_id`. + /// Panics if the account is not locked for reading on `thread_id`. 
+    fn read_unlock_account(&mut self, account: &Pubkey, thread_id: ThreadId) {
+        let Entry::Occupied(mut entry) = self.locks.entry(*account) else {
+            panic!("read lock must exist for account: {account}");
+        };
+
+        let AccountLocks {
+            write_locks,
+            read_locks: maybe_read_locks,
+        } = entry.get_mut();
+
+        let Some(read_locks) = maybe_read_locks else {
+            panic!("read lock must exist for account: {account}");
+        };
+
+        assert!(
+            read_locks.thread_set.contains(thread_id),
+            "outstanding read lock must be on same thread"
+        );
+
+        read_locks.lock_counts[thread_id] -= 1;
+        if read_locks.lock_counts[thread_id] == 0 {
+            read_locks.thread_set.remove(thread_id);
+            if read_locks.thread_set.is_empty() {
+                *maybe_read_locks = None;
+                if write_locks.is_none() {
+                    entry.remove();
+                }
+            }
+        }
+    }
+}
+
+impl BitAnd for ThreadSet {
+    type Output = Self;
+
+    fn bitand(self, rhs: Self) -> Self::Output {
+        Self(self.0 & rhs.0)
+    }
+}
+
+impl BitAndAssign for ThreadSet {
+    fn bitand_assign(&mut self, rhs: Self) {
+        self.0 &= rhs.0;
+    }
+}
+
+impl Sub for ThreadSet {
+    type Output = Self;
+
+    fn sub(self, rhs: Self) -> Self::Output {
+        Self(self.0 & !rhs.0)
+    }
+}
+
+impl Display for ThreadSet {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "ThreadSet({:#0width$b})", self.0, width = MAX_THREADS)
+    }
+}
+
+impl Debug for ThreadSet {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        Display::fmt(self, f)
+    }
+}
+
+impl ThreadSet {
+    #[inline(always)]
+    pub const fn none() -> Self {
+        Self(0b0)
+    }
+
+    #[inline(always)]
+    pub const fn any(num_threads: usize) -> Self {
+        if num_threads == MAX_THREADS {
+            Self(u64::MAX)
+        } else {
+            Self(Self::as_flag(num_threads) - 1)
+        }
+    }
+
+    #[inline(always)]
+    pub const fn only(thread_id: ThreadId) -> Self {
+        Self(Self::as_flag(thread_id))
+    }
+
+    #[inline(always)]
+    pub fn num_threads(&self) -> u32 {
+        self.0.count_ones()
+    }
+
+    #[inline(always)]
+    pub fn only_one_contained(&self) -> Option<ThreadId> {
+        (self.num_threads() == 1).then_some(self.0.trailing_zeros() as ThreadId)
+    }
+
+    #[inline(always)]
+    pub fn is_empty(&self) -> bool {
+        self == &Self::none()
+    }
+
+    #[inline(always)]
+    pub fn contains(&self, thread_id: ThreadId) -> bool {
+        self.0 & Self::as_flag(thread_id) != 0
+    }
+
+    #[inline(always)]
+    pub fn insert(&mut self, thread_id: ThreadId) {
+        self.0 |= Self::as_flag(thread_id);
+    }
+
+    #[inline(always)]
+    pub fn remove(&mut self, thread_id: ThreadId) {
+        self.0 &= !Self::as_flag(thread_id);
+    }
+
+    #[inline(always)]
+    pub fn contained_threads_iter(self) -> impl Iterator<Item = ThreadId> {
+        (0..MAX_THREADS).filter(move |thread_id| self.contains(*thread_id))
+    }
+
+    #[inline(always)]
+    const fn as_flag(thread_id: ThreadId) -> u64 {
+        0b1 << thread_id
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    const TEST_NUM_THREADS: usize = 4;
+    const TEST_ANY_THREADS: ThreadSet = ThreadSet::any(TEST_NUM_THREADS);
+
+    // Simple thread selector to select the first schedulable thread
+    fn test_thread_selector(thread_set: ThreadSet) -> ThreadId {
+        thread_set.contained_threads_iter().next().unwrap()
+    }
+
+    #[test]
+    #[should_panic(expected = "num threads must be > 0")]
+    fn test_too_few_num_threads() {
+        ThreadAwareAccountLocks::new(0);
+    }
+
+    #[test]
+    #[should_panic(expected = "num threads must be <=")]
+    fn test_too_many_num_threads() {
+        ThreadAwareAccountLocks::new(MAX_THREADS + 1);
+    }
+
+    #[test]
+    fn test_try_lock_accounts_none() {
+        let pk1 = Pubkey::new_unique();
+        let pk2 = Pubkey::new_unique();
+        let mut locks = 
ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.read_lock_account(&pk1, 2); + locks.read_lock_account(&pk1, 3); + assert_eq!( + locks.try_lock_accounts( + [&pk1].into_iter(), + [&pk2].into_iter(), + TEST_ANY_THREADS, + test_thread_selector + ), + None + ); + } + + #[test] + fn test_try_lock_accounts_one() { + let pk1 = Pubkey::new_unique(); + let pk2 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.write_lock_account(&pk2, 3); + + assert_eq!( + locks.try_lock_accounts( + [&pk1].into_iter(), + [&pk2].into_iter(), + TEST_ANY_THREADS, + test_thread_selector + ), + Some(3) + ); + } + + #[test] + fn test_try_lock_accounts_multiple() { + let pk1 = Pubkey::new_unique(); + let pk2 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.read_lock_account(&pk2, 0); + locks.read_lock_account(&pk2, 0); + + assert_eq!( + locks.try_lock_accounts( + [&pk1].into_iter(), + [&pk2].into_iter(), + TEST_ANY_THREADS - ThreadSet::only(0), // exclude 0 + test_thread_selector + ), + Some(1) + ); + } + + #[test] + fn test_try_lock_accounts_any() { + let pk1 = Pubkey::new_unique(); + let pk2 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + assert_eq!( + locks.try_lock_accounts( + [&pk1].into_iter(), + [&pk2].into_iter(), + TEST_ANY_THREADS, + test_thread_selector + ), + Some(0) + ); + } + + #[test] + fn test_accounts_schedulable_threads_no_outstanding_locks() { + let pk1 = Pubkey::new_unique(); + let locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + + assert_eq!( + locks.accounts_schedulable_threads([&pk1].into_iter(), std::iter::empty()), + Some(TEST_ANY_THREADS) + ); + assert_eq!( + locks.accounts_schedulable_threads(std::iter::empty(), [&pk1].into_iter()), + Some(TEST_ANY_THREADS) + ); + } + + #[test] + fn test_accounts_schedulable_threads_outstanding_write_only() { + let pk1 = Pubkey::new_unique(); + let pk2 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + + locks.write_lock_account(&pk1, 2); + assert_eq!( + locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()), + Some(ThreadSet::only(2)) + ); + assert_eq!( + locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()), + Some(ThreadSet::only(2)) + ); + } + + #[test] + fn test_accounts_schedulable_threads_outstanding_read_only() { + let pk1 = Pubkey::new_unique(); + let pk2 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + + locks.read_lock_account(&pk1, 2); + assert_eq!( + locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()), + Some(ThreadSet::only(2)) + ); + assert_eq!( + locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()), + Some(TEST_ANY_THREADS) + ); + + locks.read_lock_account(&pk1, 0); + assert_eq!( + locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()), + None + ); + assert_eq!( + locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()), + Some(TEST_ANY_THREADS) + ); + } + + #[test] + fn test_accounts_schedulable_threads_outstanding_mixed() { + let pk1 = Pubkey::new_unique(); + let pk2 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + + locks.read_lock_account(&pk1, 2); + locks.write_lock_account(&pk1, 2); + assert_eq!( + locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()), + Some(ThreadSet::only(2)) 
+ ); + assert_eq!( + locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()), + Some(ThreadSet::only(2)) + ); + } + + #[test] + #[should_panic(expected = "outstanding write lock must be on same thread")] + fn test_write_lock_account_write_conflict_panic() { + let pk1 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.write_lock_account(&pk1, 0); + locks.write_lock_account(&pk1, 1); + } + + #[test] + #[should_panic(expected = "outstanding read lock must be on same thread")] + fn test_write_lock_account_read_conflict_panic() { + let pk1 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.read_lock_account(&pk1, 0); + locks.write_lock_account(&pk1, 1); + } + + #[test] + #[should_panic(expected = "write lock must exist")] + fn test_write_unlock_account_not_locked() { + let pk1 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.write_unlock_account(&pk1, 0); + } + + #[test] + #[should_panic(expected = "outstanding write lock must be on same thread")] + fn test_write_unlock_account_thread_mismatch() { + let pk1 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.write_lock_account(&pk1, 1); + locks.write_unlock_account(&pk1, 0); + } + + #[test] + #[should_panic(expected = "outstanding write lock must be on same thread")] + fn test_read_lock_account_write_conflict_panic() { + let pk1 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.write_lock_account(&pk1, 0); + locks.read_lock_account(&pk1, 1); + } + + #[test] + #[should_panic(expected = "read lock must exist")] + fn test_read_unlock_account_not_locked() { + let pk1 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.read_unlock_account(&pk1, 1); + } + + #[test] + #[should_panic(expected = "outstanding read lock must be on same thread")] + fn test_read_unlock_account_thread_mismatch() { + let pk1 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.read_lock_account(&pk1, 0); + locks.read_unlock_account(&pk1, 1); + } + + #[test] + fn test_write_locking() { + let pk1 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.write_lock_account(&pk1, 1); + locks.write_lock_account(&pk1, 1); + locks.write_unlock_account(&pk1, 1); + locks.write_unlock_account(&pk1, 1); + assert!(locks.locks.is_empty()); + } + + #[test] + fn test_read_locking() { + let pk1 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.read_lock_account(&pk1, 1); + locks.read_lock_account(&pk1, 1); + locks.read_unlock_account(&pk1, 1); + locks.read_unlock_account(&pk1, 1); + assert!(locks.locks.is_empty()); + } + + #[test] + #[should_panic(expected = "thread_id must be < num_threads")] + fn test_lock_accounts_invalid_thread() { + let pk1 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.lock_accounts([&pk1].into_iter(), std::iter::empty(), TEST_NUM_THREADS); + } + + #[test] + fn test_thread_set() { + let mut thread_set = ThreadSet::none(); + assert!(thread_set.is_empty()); + assert_eq!(thread_set.num_threads(), 0); + assert_eq!(thread_set.only_one_contained(), None); + for idx in 0..MAX_THREADS { + assert!(!thread_set.contains(idx)); + } + + thread_set.insert(4); + assert!(!thread_set.is_empty()); + 
assert_eq!(thread_set.num_threads(), 1);
+        assert_eq!(thread_set.only_one_contained(), Some(4));
+        for idx in 0..MAX_THREADS {
+            assert_eq!(thread_set.contains(idx), idx == 4);
+        }
+
+        thread_set.insert(2);
+        assert!(!thread_set.is_empty());
+        assert_eq!(thread_set.num_threads(), 2);
+        assert_eq!(thread_set.only_one_contained(), None);
+        for idx in 0..MAX_THREADS {
+            assert_eq!(thread_set.contains(idx), idx == 2 || idx == 4);
+        }
+
+        thread_set.remove(4);
+        assert!(!thread_set.is_empty());
+        assert_eq!(thread_set.num_threads(), 1);
+        assert_eq!(thread_set.only_one_contained(), Some(2));
+        for idx in 0..MAX_THREADS {
+            assert_eq!(thread_set.contains(idx), idx == 2);
+        }
+    }
+
+    #[test]
+    fn test_thread_set_any_zero() {
+        let any_threads = ThreadSet::any(0);
+        assert_eq!(any_threads.num_threads(), 0);
+    }
+
+    #[test]
+    fn test_thread_set_any_max() {
+        let any_threads = ThreadSet::any(MAX_THREADS);
+        assert_eq!(any_threads.num_threads(), MAX_THREADS as u32);
+    }
+}
diff --git a/scheduler/src/impls/prio_graph_scheduler/transaction_priority_id.rs b/scheduler/src/impls/prio_graph_scheduler/transaction_priority_id.rs
new file mode 100644
index 0000000..05720c8
--- /dev/null
+++ b/scheduler/src/impls/prio_graph_scheduler/transaction_priority_id.rs
@@ -0,0 +1,69 @@
+use {
+    crate::scheduler_messages::TransactionId,
+    prio_graph::TopLevelId,
+    std::hash::{Hash, Hasher},
+};
+
+/// A unique identifier tied with priority ordering for a transaction/packet:
+#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub struct TransactionPriorityId {
+    pub priority: u64,
+    pub id: TransactionId,
+}
+
+impl TransactionPriorityId {
+    pub fn new(priority: u64, id: TransactionId) -> Self {
+        Self { priority, id }
+    }
+}
+
+impl Hash for TransactionPriorityId {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.id.hash(state)
+    }
+}
+
+impl TopLevelId<Self> for TransactionPriorityId {
+    fn id(&self) -> Self {
+        *self
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_transaction_priority_id_ordering() {
+        // Higher priority first
+        {
+            let id1 = TransactionPriorityId::new(1, TransactionId::new(1));
+            let id2 = TransactionPriorityId::new(2, TransactionId::new(1));
+            assert!(id1 < id2);
+            assert!(id1 <= id2);
+            assert!(id2 > id1);
+            assert!(id2 >= id1);
+        }
+
+        // Equal priority then compare by id
+        {
+            let id1 = TransactionPriorityId::new(1, TransactionId::new(1));
+            let id2 = TransactionPriorityId::new(1, TransactionId::new(2));
+            assert!(id1 < id2);
+            assert!(id1 <= id2);
+            assert!(id2 > id1);
+            assert!(id2 >= id1);
+        }
+
+        // Equal priority and id
+        {
+            let id1 = TransactionPriorityId::new(1, TransactionId::new(1));
+            let id2 = TransactionPriorityId::new(1, TransactionId::new(1));
+            assert_eq!(id1, id2);
+            assert!(id1 >= id2);
+            assert!(id1 <= id2);
+            assert!(id2 >= id1);
+            assert!(id2 <= id1);
+        }
+    }
+}
diff --git a/scheduler/src/impls/prio_graph_scheduler/transaction_state.rs b/scheduler/src/impls/prio_graph_scheduler/transaction_state.rs
new file mode 100644
index 0000000..be6bdc6
--- /dev/null
+++ b/scheduler/src/impls/prio_graph_scheduler/transaction_state.rs
@@ -0,0 +1,323 @@
+use crate::scheduler_messages::MaxAge;
+use {solana_sdk::transaction::SanitizedTransaction, std::sync::Arc};
+
+/// Simple wrapper type to tie a sanitized transaction to max age slot.
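+/// (Illustrative, mirroring the tests in this file: a transaction that should
+/// never expire carries `MaxAge { epoch_invalidation_slot: Slot::MAX,
+/// alt_invalidation_slot: Slot::MAX }`.)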
+#[derive(Clone, Debug)]
+pub struct SanitizedTransactionTTL {
+    pub transaction: SanitizedTransaction,
+    pub max_age: MaxAge,
+}
+
+/// TransactionState is used to track the state of a transaction in the transaction scheduler
+/// and banking stage as a whole.
+///
+/// There are two states a transaction can be in:
+///     1. `Unprocessed` - The transaction is available for scheduling.
+///     2. `Pending` - The transaction is currently scheduled or being processed.
+///
+/// Newly received transactions are initially in the `Unprocessed` state.
+/// When a transaction is scheduled, it is transitioned to the `Pending` state,
+/// using the `transition_to_pending` method.
+/// When a transaction finishes processing it may be retryable. If it is retryable,
+/// the transaction is transitioned back to the `Unprocessed` state using the
+/// `transition_to_unprocessed` method. If it is not retryable, the state should
+/// be dropped.
+///
+/// For performance, when a transaction is transitioned to the `Pending` state, the
+/// internal `SanitizedTransaction` is moved out of the `TransactionState` and sent
+/// to the appropriate thread for processing. This is done to avoid cloning the
+/// `SanitizedTransaction`.
+#[allow(clippy::large_enum_variant)]
+pub enum TransactionState {
+    /// The transaction is available for scheduling.
+    Unprocessed {
+        transaction_ttl: SanitizedTransactionTTL,
+        priority: u64,
+        cost: u64,
+    },
+    /// The transaction is currently scheduled or being processed.
+    Pending {
+        transaction_ttl: SanitizedTransactionTTL,
+        priority: u64,
+        cost: u64,
+    },
+    /// Only used during transition.
+    Transitioning,
+}
+
+impl TransactionState {
+    /// Creates a new `TransactionState` in the `Unprocessed` state.
+    pub fn new(transaction_ttl: SanitizedTransactionTTL, priority: u64, cost: u64) -> Self {
+        Self::Unprocessed {
+            transaction_ttl,
+            priority,
+            cost,
+        }
+    }
+
+    /// Return the priority of the transaction.
+    /// This is *not* the same as the `compute_unit_price` of the transaction.
+    /// The priority is used to order transactions for processing.
+    pub fn priority(&self) -> u64 {
+        match self {
+            Self::Unprocessed { priority, .. } => *priority,
+            Self::Pending { priority, .. } => *priority,
+            Self::Transitioning => unreachable!(),
+        }
+    }
+
+    /// Return the cost of the transaction.
+    pub fn cost(&self) -> u64 {
+        match self {
+            Self::Unprocessed { cost, .. } => *cost,
+            Self::Pending { cost, .. } => *cost,
+            Self::Transitioning => unreachable!(),
+        }
+    }
+
+    /// Intended to be called when a transaction is scheduled. This method will
+    /// transition the transaction from `Unprocessed` to `Pending` and return the
+    /// `SanitizedTransactionTTL` for processing.
+    ///
+    /// # Panics
+    /// This method will panic if the transaction is already in the `Pending` state,
+    /// as this is an invalid state transition.
+    pub fn transition_to_pending(&mut self) -> SanitizedTransactionTTL {
+        match self.take() {
+            TransactionState::Unprocessed {
+                transaction_ttl,
+                priority,
+                cost,
+            } => {
+                let transaction_ttl_c = transaction_ttl.clone();
+                *self = TransactionState::Pending {
+                    transaction_ttl,
+                    priority,
+                    cost,
+                };
+                // TODO: remove this clone by not retaining the TTL in `Pending`.
+                transaction_ttl_c
+            }
+            TransactionState::Pending { .. } => {
+                panic!("transaction already pending");
+            }
+            Self::Transitioning => unreachable!(),
+        }
+    }
+
+    /// Intended to be called when a transaction is retried. This method will
+    /// transition the transaction from `Pending` to `Unprocessed`.
+    ///
+    /// # Panics
+    /// This method will panic if the transaction is already in the `Unprocessed`
+    /// state, as this is an invalid state transition.
+    pub fn transition_to_unprocessed(&mut self, transaction_ttl: SanitizedTransactionTTL) {
+        match self.take() {
+            TransactionState::Unprocessed { .. } => panic!("already unprocessed"),
+            // Use the caller-provided TTL; the copy retained in `Pending` is stale.
+            TransactionState::Pending { priority, cost, .. } => {
+                *self = Self::Unprocessed {
+                    transaction_ttl,
+                    priority,
+                    cost,
+                }
+            }
+            Self::Transitioning => unreachable!(),
+        }
+    }
+
+    /// Get a reference to the `SanitizedTransactionTTL` for the transaction.
+    ///
+    /// # Panics
+    /// This method will panic if the transaction is in the `Pending` state.
+    pub fn transaction_ttl(&self) -> &SanitizedTransactionTTL {
+        match self {
+            Self::Unprocessed {
+                transaction_ttl, ..
+            } => transaction_ttl,
+            Self::Pending { .. } => panic!("transaction is pending"),
+            Self::Transitioning => unreachable!(),
+        }
+    }
+
+    /// Internal helper for transitioning between states.
+    /// Replaces `self` with a dummy state that will immediately be overwritten in transition.
+    fn take(&mut self) -> Self {
+        core::mem::replace(self, Self::Transitioning)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use {
+        super::*,
+        solana_sdk::{
+            clock::Slot, compute_budget::ComputeBudgetInstruction, hash::Hash, message::Message,
+            signature::Keypair, signer::Signer, system_instruction, transaction::Transaction,
+        },
+    };
+
+    fn create_transaction_state(compute_unit_price: u64) -> TransactionState {
+        let from_keypair = Keypair::new();
+        let ixs = vec![
+            system_instruction::transfer(
+                &from_keypair.pubkey(),
+                &solana_sdk::pubkey::new_rand(),
+                1,
+            ),
+            ComputeBudgetInstruction::set_compute_unit_price(compute_unit_price),
+        ];
+        let message = Message::new(&ixs, Some(&from_keypair.pubkey()));
+        let tx = Transaction::new(&[&from_keypair], message, Hash::default());
+
+        let transaction_ttl = SanitizedTransactionTTL {
+            transaction: SanitizedTransaction::from_transaction_for_tests(tx),
+            max_age: MaxAge {
+                epoch_invalidation_slot: Slot::MAX,
+                alt_invalidation_slot: Slot::MAX,
+            },
+        };
+        const TEST_TRANSACTION_COST: u64 = 5000;
+        TransactionState::new(transaction_ttl, compute_unit_price, TEST_TRANSACTION_COST)
+    }
+
+    #[test]
+    #[should_panic(expected = "already pending")]
+    fn test_transition_to_pending_panic() {
+        let mut transaction_state = create_transaction_state(0);
+        transaction_state.transition_to_pending();
+        transaction_state.transition_to_pending(); // invalid transition
+    }
+
+    #[test]
+    fn test_transition_to_pending() {
+        let mut transaction_state = create_transaction_state(0);
+        assert!(matches!(
+            transaction_state,
+            TransactionState::Unprocessed { .. }
+        ));
+        let _ = transaction_state.transition_to_pending();
+        assert!(matches!(
+            transaction_state,
+            TransactionState::Pending { ..
} + )); + } + + #[test] + #[should_panic(expected = "already unprocessed")] + fn test_transition_to_unprocessed_panic() { + let mut transaction_state = create_transaction_state(0); + + // Manually clone `SanitizedTransactionTTL` + let SanitizedTransactionTTL { + transaction, + max_age, + } = transaction_state.transaction_ttl(); + let transaction_ttl = SanitizedTransactionTTL { + transaction: transaction.clone(), + max_age: *max_age, + }; + transaction_state.transition_to_unprocessed(transaction_ttl); // invalid transition + } + + #[test] + fn test_transition_to_unprocessed() { + let mut transaction_state = create_transaction_state(0); + assert!(matches!( + transaction_state, + TransactionState::Unprocessed { .. } + )); + let transaction_ttl = transaction_state.transition_to_pending(); + assert!(matches!( + transaction_state, + TransactionState::Pending { .. } + )); + transaction_state.transition_to_unprocessed(transaction_ttl); + assert!(matches!( + transaction_state, + TransactionState::Unprocessed { .. } + )); + } + + #[test] + fn test_priority() { + let priority = 15; + let mut transaction_state = create_transaction_state(priority); + assert_eq!(transaction_state.priority(), priority); + + // ensure compute unit price is not lost through state transitions + let transaction_ttl = transaction_state.transition_to_pending(); + assert_eq!(transaction_state.priority(), priority); + transaction_state.transition_to_unprocessed(transaction_ttl); + assert_eq!(transaction_state.priority(), priority); + } + + #[test] + #[should_panic(expected = "transaction is pending")] + fn test_transaction_ttl_panic() { + let mut transaction_state = create_transaction_state(0); + let transaction_ttl = transaction_state.transaction_ttl(); + assert!(matches!( + transaction_state, + TransactionState::Unprocessed { .. } + )); + assert_eq!( + transaction_ttl.max_age, + MaxAge { + epoch_invalidation_slot: Slot::MAX, + alt_invalidation_slot: Slot::MAX, + } + ); + + let _ = transaction_state.transition_to_pending(); + assert!(matches!( + transaction_state, + TransactionState::Pending { .. } + )); + let _ = transaction_state.transaction_ttl(); // pending state, the transaction ttl is not available + } + + #[test] + fn test_transaction_ttl() { + let mut transaction_state = create_transaction_state(0); + let transaction_ttl = transaction_state.transaction_ttl(); + assert!(matches!( + transaction_state, + TransactionState::Unprocessed { .. } + )); + assert_eq!( + transaction_ttl.max_age, + MaxAge { + epoch_invalidation_slot: Slot::MAX, + alt_invalidation_slot: Slot::MAX, + } + ); + + // ensure transaction_ttl is not lost through state transitions + let transaction_ttl = transaction_state.transition_to_pending(); + assert!(matches!( + transaction_state, + TransactionState::Pending { .. } + )); + + transaction_state.transition_to_unprocessed(transaction_ttl); + let transaction_ttl = transaction_state.transaction_ttl(); + assert!(matches!( + transaction_state, + TransactionState::Unprocessed { .. 
}
+        ));
+        assert_eq!(
+            transaction_ttl.max_age,
+            MaxAge {
+                epoch_invalidation_slot: Slot::MAX,
+                alt_invalidation_slot: Slot::MAX,
+            }
+        );
+    }
+}
diff --git a/scheduler/src/impls/prio_graph_scheduler/transaction_state_container.rs b/scheduler/src/impls/prio_graph_scheduler/transaction_state_container.rs
new file mode 100644
index 0000000..b26feb6
--- /dev/null
+++ b/scheduler/src/impls/prio_graph_scheduler/transaction_state_container.rs
@@ -0,0 +1,263 @@
+// use {
+//     super::{
+//         transaction_priority_id::TransactionPriorityId,
+//         transaction_state::{SanitizedTransactionTTL, TransactionState},
+//     },
+//     crate::scheduler_messages::TransactionId,
+//     itertools::MinMaxResult,
+//     min_max_heap::MinMaxHeap,
+//     std::{collections::HashMap, sync::Arc},
+// };
+//
+// /// This structure will hold `TransactionState` for the entirety of a
+// /// transaction's lifetime in the scheduler and BankingStage as a whole.
+// ///
+// /// Transaction Lifetime:
+// /// 1. Received from `SigVerify` by `BankingStage`
+// /// 2. Inserted into `TransactionStateContainer` by `BankingStage`
+// /// 3. Popped in priority-order by scheduler, and transitioned to `Pending` state
+// /// 4. Processed by `ConsumeWorker`
+// ///     a. If consumed, remove `Pending` state from the `TransactionStateContainer`
+// ///     b. If retryable, transition back to `Unprocessed` state.
+// ///        Re-insert to the queue, and return to step 3.
+// ///
+// /// The structure is composed of two main components:
+// /// 1. A priority queue of wrapped `TransactionId`s, which are used to
+// ///    order transactions by priority for selection by the scheduler.
+// /// 2. A map of `TransactionId` to `TransactionState`, which is used to
+// ///    track the state of each transaction.
+// ///
+// /// When `Pending`, the associated `TransactionId` is not in the queue, but
+// /// is still in the map.
+// /// The entry in the map should exist before insertion into the queue, and
+// /// be removed only after the id is removed from the queue.
+// ///
+// /// The container maintains a fixed capacity. If the queue is full when pushing
+// /// a new transaction, the lowest priority transaction will be dropped.
+// pub struct TransactionStateContainer {
+//     priority_queue: MinMaxHeap,
+//     id_to_transaction_state: HashMap>,
+// }
+//
+// impl TransactionStateContainer

{ +// pub fn with_capacity(capacity: usize) -> Self { +// Self { +// priority_queue: MinMaxHeap::with_capacity(capacity), +// id_to_transaction_state: HashMap::with_capacity(capacity), +// } +// } +// +// /// Returns true if the queue is empty. +// pub fn is_empty(&self) -> bool { +// self.priority_queue.is_empty() +// } +// +// /// Returns the remaining capacity of the queue +// pub fn remaining_queue_capacity(&self) -> usize { +// self.priority_queue.capacity() - self.priority_queue.len() +// } +// +// /// Get the top transaction id in the priority queue. +// pub fn pop(&mut self) -> Option { +// self.priority_queue.pop_max() +// } +// +// /// Get mutable transaction state by id. +// pub fn get_mut_transaction_state( +// &mut self, +// id: &TransactionId, +// ) -> Option<&mut TransactionState

> { +// self.id_to_transaction_state.get_mut(id) +// } +// +// /// Get reference to `SanitizedTransactionTTL` by id. +// /// Panics if the transaction does not exist. +// pub fn get_transaction_ttl(&self, id: &TransactionId) -> Option<&SanitizedTransactionTTL> { +// self.id_to_transaction_state +// .get(id) +// .map(|state| state.transaction_ttl()) +// } +// +// /// Insert a new transaction into the container's queues and maps. +// /// Returns `true` if a packet was dropped due to capacity limits. +// pub fn insert_new_transaction( +// &mut self, +// transaction_id: TransactionId, +// transaction_ttl: SanitizedTransactionTTL, +// packet: Arc

, +// priority: u64, +// cost: u64, +// ) -> bool { +// let priority_id = TransactionPriorityId::new(priority, transaction_id); +// self.id_to_transaction_state.insert( +// transaction_id, +// TransactionState::new(transaction_ttl, packet, priority, cost), +// ); +// self.push_id_into_queue(priority_id) +// } +// +// /// Retries a transaction - inserts transaction back into map (but not packet). +// /// This transitions the transaction to `Unprocessed` state. +// pub fn retry_transaction( +// &mut self, +// transaction_id: TransactionId, +// transaction_ttl: SanitizedTransactionTTL, +// ) { +// let transaction_state = self +// .get_mut_transaction_state(&transaction_id) +// .expect("transaction must exist"); +// let priority_id = TransactionPriorityId::new(transaction_state.priority(), transaction_id); +// transaction_state.transition_to_unprocessed(transaction_ttl); +// self.push_id_into_queue(priority_id); +// } +// +// /// Pushes a transaction id into the priority queue. If the queue is full, the lowest priority +// /// transaction will be dropped (removed from the queue and map). +// /// Returns `true` if a packet was dropped due to capacity limits. +// pub fn push_id_into_queue(&mut self, priority_id: TransactionPriorityId) -> bool { +// if self.remaining_queue_capacity() == 0 { +// let popped_id = self.priority_queue.push_pop_min(priority_id); +// self.remove_by_id(&popped_id.id); +// true +// } else { +// self.priority_queue.push(priority_id); +// false +// } +// } +// +// /// Remove transaction by id. +// pub fn remove_by_id(&mut self, id: &TransactionId) { +// self.id_to_transaction_state +// .remove(id) +// .expect("transaction must exist"); +// } +// +// pub fn get_min_max_priority(&self) -> MinMaxResult { +// match self.priority_queue.peek_min() { +// Some(min) => match self.priority_queue.peek_max() { +// Some(max) => MinMaxResult::MinMax(min.priority, max.priority), +// None => MinMaxResult::OneElement(min.priority), +// }, +// None => MinMaxResult::NoElements, +// } +// } +// } +// +// #[cfg(test)] +// mod tests { +// use { +// super::*, +// crate::scheduler_messages::MaxAge, +// crate::tests::MockImmutableDeserializedPacket, +// solana_sdk::{ +// compute_budget::ComputeBudgetInstruction, +// hash::Hash, +// message::Message, +// packet::Packet, +// signature::Keypair, +// signer::Signer, +// slot_history::Slot, +// system_instruction, +// transaction::{SanitizedTransaction, Transaction}, +// }, +// }; +// +// /// Returns (transaction_ttl, priority, cost) +// fn test_transaction( +// priority: u64, +// ) -> ( +// SanitizedTransactionTTL, +// Arc, +// u64, +// u64, +// ) { +// let from_keypair = Keypair::new(); +// let ixs = vec![ +// system_instruction::transfer( +// &from_keypair.pubkey(), +// &solana_sdk::pubkey::new_rand(), +// 1, +// ), +// ComputeBudgetInstruction::set_compute_unit_price(priority), +// ]; +// let message = Message::new(&ixs, Some(&from_keypair.pubkey())); +// let tx = SanitizedTransaction::from_transaction_for_tests(Transaction::new( +// &[&from_keypair], +// message, +// Hash::default(), +// )); +// let packet = Arc::new( +// MockImmutableDeserializedPacket::new( +// Packet::from_data(None, tx.to_versioned_transaction()).unwrap(), +// ) +// .unwrap(), +// ); +// let transaction_ttl = SanitizedTransactionTTL { +// transaction: tx, +// max_age: MaxAge { +// epoch_invalidation_slot: Slot::MAX, +// alt_invalidation_slot: Slot::MAX, +// }, +// }; +// const TEST_TRANSACTION_COST: u64 = 5000; +// (transaction_ttl, packet, priority, TEST_TRANSACTION_COST) +// 
} +// +// fn push_to_container( +// container: &mut TransactionStateContainer, +// num: usize, +// ) { +// for id in 0..num as u64 { +// let priority = id; +// let (transaction_ttl, packet, priority, cost) = test_transaction(priority); +// container.insert_new_transaction( +// TransactionId::new(id), +// transaction_ttl, +// packet, +// priority, +// cost, +// ); +// } +// } +// +// #[test] +// fn test_is_empty() { +// let mut container = TransactionStateContainer::with_capacity(1); +// assert!(container.is_empty()); +// +// push_to_container(&mut container, 1); +// assert!(!container.is_empty()); +// } +// +// #[test] +// fn test_priority_queue_capacity() { +// let mut container = TransactionStateContainer::with_capacity(1); +// push_to_container(&mut container, 5); +// +// assert_eq!(container.priority_queue.len(), 1); +// assert_eq!(container.id_to_transaction_state.len(), 1); +// assert_eq!( +// container +// .id_to_transaction_state +// .iter() +// .map(|ts| ts.1.priority()) +// .next() +// .unwrap(), +// 4 +// ); +// } +// +// #[test] +// fn test_get_mut_transaction_state() { +// let mut container = TransactionStateContainer::with_capacity(5); +// push_to_container(&mut container, 5); +// +// let existing_id = TransactionId::new(3); +// let non_existing_id = TransactionId::new(7); +// assert!(container.get_mut_transaction_state(&existing_id).is_some()); +// assert!(container.get_mut_transaction_state(&existing_id).is_some()); +// assert!(container +// .get_mut_transaction_state(&non_existing_id) +// .is_none()); +// } +// } diff --git a/scheduler/src/scheduler_messages.rs b/scheduler/src/scheduler_messages.rs index c7870ed..8c2572e 100644 --- a/scheduler/src/scheduler_messages.rs +++ b/scheduler/src/scheduler_messages.rs @@ -1,3 +1,4 @@ +use solana_program::clock::Slot; use {solana_sdk::transaction::SanitizedTransaction, std::fmt::Display}; /// A unique identifier for a transaction batch. @@ -69,3 +70,11 @@ pub struct SchedulingBatchResult { // time slice status for this batch job. pub retryable_indexes: Vec, } + + +/// A TTL flag for a transaction. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct MaxAge { + pub epoch_invalidation_slot: Slot, + pub alt_invalidation_slot: Slot, +} \ No newline at end of file From 988e25645c603850a58dc648dea958aa653f1c5f Mon Sep 17 00:00:00 2001 From: lewis Date: Tue, 29 Oct 2024 14:30:07 +0800 Subject: [PATCH 3/4] feat: active pg-scheduler component --- .../transaction_state_container.rs | 515 +++++++++--------- 1 file changed, 252 insertions(+), 263 deletions(-) diff --git a/scheduler/src/impls/prio_graph_scheduler/transaction_state_container.rs b/scheduler/src/impls/prio_graph_scheduler/transaction_state_container.rs index b26feb6..17d80eb 100644 --- a/scheduler/src/impls/prio_graph_scheduler/transaction_state_container.rs +++ b/scheduler/src/impls/prio_graph_scheduler/transaction_state_container.rs @@ -1,263 +1,252 @@ -// use { -// super::{ -// transaction_priority_id::TransactionPriorityId, -// transaction_state::{SanitizedTransactionTTL, TransactionState}, -// }, -// crate::scheduler_messages::TransactionId, -// itertools::MinMaxResult, -// min_max_heap::MinMaxHeap, -// std::{collections::HashMap, sync::Arc}, -// }; -// -// /// This structure will hold `TransactionState` for the entirety of a -// /// transaction's lifetime in the scheduler and BankingStage as a whole. -// /// -// /// Transaction Lifetime: -// /// 1. Received from `SigVerify` by `BankingStage` -// /// 2. 
Inserted into `TransactionStateContainer` by `BankingStage` -// /// 3. Popped in priority-order by scheduler, and transitioned to `Pending` state -// /// 4. Processed by `ConsumeWorker` -// /// a. If consumed, remove `Pending` state from the `TransactionStateContainer` -// /// b. If retryable, transition back to `Unprocessed` state. -// /// Re-insert to the queue, and return to step 3. -// /// -// /// The structure is composed of two main components: -// /// 1. A priority queue of wrapped `TransactionId`s, which are used to -// /// order transactions by priority for selection by the scheduler. -// /// 2. A map of `TransactionId` to `TransactionState`, which is used to -// /// track the state of each transaction. -// /// -// /// When `Pending`, the associated `TransactionId` is not in the queue, but -// /// is still in the map. -// /// The entry in the map should exist before insertion into the queue, and be -// /// be removed only after the id is removed from the queue. -// /// -// /// The container maintains a fixed capacity. If the queue is full when pushing -// /// a new transaction, the lowest priority transaction will be dropped. -// pub struct TransactionStateContainer { -// priority_queue: MinMaxHeap, -// id_to_transaction_state: HashMap>, -// } -// -// impl TransactionStateContainer

{ -// pub fn with_capacity(capacity: usize) -> Self { -// Self { -// priority_queue: MinMaxHeap::with_capacity(capacity), -// id_to_transaction_state: HashMap::with_capacity(capacity), -// } -// } -// -// /// Returns true if the queue is empty. -// pub fn is_empty(&self) -> bool { -// self.priority_queue.is_empty() -// } -// -// /// Returns the remaining capacity of the queue -// pub fn remaining_queue_capacity(&self) -> usize { -// self.priority_queue.capacity() - self.priority_queue.len() -// } -// -// /// Get the top transaction id in the priority queue. -// pub fn pop(&mut self) -> Option { -// self.priority_queue.pop_max() -// } -// -// /// Get mutable transaction state by id. -// pub fn get_mut_transaction_state( -// &mut self, -// id: &TransactionId, -// ) -> Option<&mut TransactionState

> { -// self.id_to_transaction_state.get_mut(id) -// } -// -// /// Get reference to `SanitizedTransactionTTL` by id. -// /// Panics if the transaction does not exist. -// pub fn get_transaction_ttl(&self, id: &TransactionId) -> Option<&SanitizedTransactionTTL> { -// self.id_to_transaction_state -// .get(id) -// .map(|state| state.transaction_ttl()) -// } -// -// /// Insert a new transaction into the container's queues and maps. -// /// Returns `true` if a packet was dropped due to capacity limits. -// pub fn insert_new_transaction( -// &mut self, -// transaction_id: TransactionId, -// transaction_ttl: SanitizedTransactionTTL, -// packet: Arc

, -// priority: u64, -// cost: u64, -// ) -> bool { -// let priority_id = TransactionPriorityId::new(priority, transaction_id); -// self.id_to_transaction_state.insert( -// transaction_id, -// TransactionState::new(transaction_ttl, packet, priority, cost), -// ); -// self.push_id_into_queue(priority_id) -// } -// -// /// Retries a transaction - inserts transaction back into map (but not packet). -// /// This transitions the transaction to `Unprocessed` state. -// pub fn retry_transaction( -// &mut self, -// transaction_id: TransactionId, -// transaction_ttl: SanitizedTransactionTTL, -// ) { -// let transaction_state = self -// .get_mut_transaction_state(&transaction_id) -// .expect("transaction must exist"); -// let priority_id = TransactionPriorityId::new(transaction_state.priority(), transaction_id); -// transaction_state.transition_to_unprocessed(transaction_ttl); -// self.push_id_into_queue(priority_id); -// } -// -// /// Pushes a transaction id into the priority queue. If the queue is full, the lowest priority -// /// transaction will be dropped (removed from the queue and map). -// /// Returns `true` if a packet was dropped due to capacity limits. -// pub fn push_id_into_queue(&mut self, priority_id: TransactionPriorityId) -> bool { -// if self.remaining_queue_capacity() == 0 { -// let popped_id = self.priority_queue.push_pop_min(priority_id); -// self.remove_by_id(&popped_id.id); -// true -// } else { -// self.priority_queue.push(priority_id); -// false -// } -// } -// -// /// Remove transaction by id. -// pub fn remove_by_id(&mut self, id: &TransactionId) { -// self.id_to_transaction_state -// .remove(id) -// .expect("transaction must exist"); -// } -// -// pub fn get_min_max_priority(&self) -> MinMaxResult { -// match self.priority_queue.peek_min() { -// Some(min) => match self.priority_queue.peek_max() { -// Some(max) => MinMaxResult::MinMax(min.priority, max.priority), -// None => MinMaxResult::OneElement(min.priority), -// }, -// None => MinMaxResult::NoElements, -// } -// } -// } -// -// #[cfg(test)] -// mod tests { -// use { -// super::*, -// crate::scheduler_messages::MaxAge, -// crate::tests::MockImmutableDeserializedPacket, -// solana_sdk::{ -// compute_budget::ComputeBudgetInstruction, -// hash::Hash, -// message::Message, -// packet::Packet, -// signature::Keypair, -// signer::Signer, -// slot_history::Slot, -// system_instruction, -// transaction::{SanitizedTransaction, Transaction}, -// }, -// }; -// -// /// Returns (transaction_ttl, priority, cost) -// fn test_transaction( -// priority: u64, -// ) -> ( -// SanitizedTransactionTTL, -// Arc, -// u64, -// u64, -// ) { -// let from_keypair = Keypair::new(); -// let ixs = vec![ -// system_instruction::transfer( -// &from_keypair.pubkey(), -// &solana_sdk::pubkey::new_rand(), -// 1, -// ), -// ComputeBudgetInstruction::set_compute_unit_price(priority), -// ]; -// let message = Message::new(&ixs, Some(&from_keypair.pubkey())); -// let tx = SanitizedTransaction::from_transaction_for_tests(Transaction::new( -// &[&from_keypair], -// message, -// Hash::default(), -// )); -// let packet = Arc::new( -// MockImmutableDeserializedPacket::new( -// Packet::from_data(None, tx.to_versioned_transaction()).unwrap(), -// ) -// .unwrap(), -// ); -// let transaction_ttl = SanitizedTransactionTTL { -// transaction: tx, -// max_age: MaxAge { -// epoch_invalidation_slot: Slot::MAX, -// alt_invalidation_slot: Slot::MAX, -// }, -// }; -// const TEST_TRANSACTION_COST: u64 = 5000; -// (transaction_ttl, packet, priority, TEST_TRANSACTION_COST) -// 
}
-//
-//     fn push_to_container(
-//         container: &mut TransactionStateContainer,
-//         num: usize,
-//     ) {
-//         for id in 0..num as u64 {
-//             let priority = id;
-//             let (transaction_ttl, packet, priority, cost) = test_transaction(priority);
-//             container.insert_new_transaction(
-//                 TransactionId::new(id),
-//                 transaction_ttl,
-//                 packet,
-//                 priority,
-//                 cost,
-//             );
-//         }
-//     }
-//
-//     #[test]
-//     fn test_is_empty() {
-//         let mut container = TransactionStateContainer::with_capacity(1);
-//         assert!(container.is_empty());
-//
-//         push_to_container(&mut container, 1);
-//         assert!(!container.is_empty());
-//     }
-//
-//     #[test]
-//     fn test_priority_queue_capacity() {
-//         let mut container = TransactionStateContainer::with_capacity(1);
-//         push_to_container(&mut container, 5);
-//
-//         assert_eq!(container.priority_queue.len(), 1);
-//         assert_eq!(container.id_to_transaction_state.len(), 1);
-//         assert_eq!(
-//             container
-//                 .id_to_transaction_state
-//                 .iter()
-//                 .map(|ts| ts.1.priority())
-//                 .next()
-//                 .unwrap(),
-//             4
-//         );
-//     }
-//
-//     #[test]
-//     fn test_get_mut_transaction_state() {
-//         let mut container = TransactionStateContainer::with_capacity(5);
-//         push_to_container(&mut container, 5);
-//
-//         let existing_id = TransactionId::new(3);
-//         let non_existing_id = TransactionId::new(7);
-//         assert!(container.get_mut_transaction_state(&existing_id).is_some());
-//         assert!(container.get_mut_transaction_state(&existing_id).is_some());
-//         assert!(container
-//             .get_mut_transaction_state(&non_existing_id)
-//             .is_none());
-//     }
-// }
+use {
+    super::{
+        transaction_priority_id::TransactionPriorityId,
+        transaction_state::{SanitizedTransactionTTL, TransactionState},
+    },
+    crate::scheduler_messages::TransactionId,
+    itertools::MinMaxResult,
+    min_max_heap::MinMaxHeap,
+    std::{collections::HashMap, sync::Arc},
+};
+
+/// This structure will hold `TransactionState` for the entirety of a
+/// transaction's lifetime in the scheduler and BankingStage as a whole.
+///
+/// Transaction Lifetime:
+/// 1. Received from `SigVerify` by `BankingStage`
+/// 2. Inserted into `TransactionStateContainer` by `BankingStage`
+/// 3. Popped in priority-order by scheduler, and transitioned to `Pending` state
+/// 4. Processed by `ConsumeWorker`
+///     a. If consumed, remove `Pending` state from the `TransactionStateContainer`
+///     b. If retryable, transition back to `Unprocessed` state.
+///        Re-insert to the queue, and return to step 3.
+///
+/// The structure is composed of two main components:
+/// 1. A priority queue of wrapped `TransactionId`s, which are used to
+///    order transactions by priority for selection by the scheduler.
+/// 2. A map of `TransactionId` to `TransactionState`, which is used to
+///    track the state of each transaction.
+///
+/// When `Pending`, the associated `TransactionId` is not in the queue, but
+/// is still in the map.
+/// The entry in the map should exist before insertion into the queue, and
+/// be removed only after the id is removed from the queue.
+///
+/// The container maintains a fixed capacity. If the queue is full when pushing
+/// a new transaction, the lowest priority transaction will be dropped.
+pub struct TransactionStateContainer {
+    priority_queue: MinMaxHeap<TransactionPriorityId>,
+    id_to_transaction_state: HashMap<TransactionId, TransactionState>,
+}
+
+impl TransactionStateContainer {
+    pub fn with_capacity(capacity: usize) -> Self {
+        Self {
+            priority_queue: MinMaxHeap::with_capacity(capacity),
+            id_to_transaction_state: HashMap::with_capacity(capacity),
+        }
+    }
+
+    /// Returns true if the queue is empty.
+    pub fn is_empty(&self) -> bool {
+        self.priority_queue.is_empty()
+    }
+
+    /// Returns the remaining capacity of the queue
+    pub fn remaining_queue_capacity(&self) -> usize {
+        self.priority_queue.capacity() - self.priority_queue.len()
+    }
+
+    /// Get the top transaction id in the priority queue.
+    pub fn pop(&mut self) -> Option<TransactionPriorityId> {
+        self.priority_queue.pop_max()
+    }
+
+    /// Get mutable transaction state by id.
+    pub fn get_mut_transaction_state(
+        &mut self,
+        id: &TransactionId,
+    ) -> Option<&mut TransactionState> {
+        self.id_to_transaction_state.get_mut(id)
+    }
+
+    /// Get reference to `SanitizedTransactionTTL` by id.
+    /// Panics if the transaction does not exist.
+    pub fn get_transaction_ttl(&self, id: &TransactionId) -> Option<&SanitizedTransactionTTL> {
+        self.id_to_transaction_state
+            .get(id)
+            .map(|state| state.transaction_ttl())
+    }
+
+    /// Insert a new transaction into the container's queues and maps.
+    /// Returns `true` if a transaction was dropped due to capacity limits.
+    pub fn insert_new_transaction(
+        &mut self,
+        transaction_id: TransactionId,
+        transaction_ttl: SanitizedTransactionTTL,
+        priority: u64,
+        cost: u64,
+    ) -> bool {
+        let priority_id = TransactionPriorityId::new(priority, transaction_id);
+        self.id_to_transaction_state.insert(
+            transaction_id,
+            TransactionState::new(transaction_ttl, priority, cost),
+        );
+        self.push_id_into_queue(priority_id)
+    }
+
+    /// Retries a transaction - inserts the transaction back into the map.
+    /// This transitions the transaction to `Unprocessed` state.
+    pub fn retry_transaction(
+        &mut self,
+        transaction_id: TransactionId,
+        transaction_ttl: SanitizedTransactionTTL,
+    ) {
+        let transaction_state = self
+            .get_mut_transaction_state(&transaction_id)
+            .expect("transaction must exist");
+        let priority_id = TransactionPriorityId::new(transaction_state.priority(), transaction_id);
+        transaction_state.transition_to_unprocessed(transaction_ttl);
+        self.push_id_into_queue(priority_id);
+    }
+
+    /// Pushes a transaction id into the priority queue. If the queue is full, the lowest priority
+    /// transaction will be dropped (removed from the queue and map).
+    /// Returns `true` if a transaction was dropped due to capacity limits.
+    pub fn push_id_into_queue(&mut self, priority_id: TransactionPriorityId) -> bool {
+        if self.remaining_queue_capacity() == 0 {
+            let popped_id = self.priority_queue.push_pop_min(priority_id);
+            self.remove_by_id(&popped_id.id);
+            true
+        } else {
+            self.priority_queue.push(priority_id);
+            false
+        }
+    }
+
+    /// Remove transaction by id.
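+    /// Panics (via the `expect` below) if no transaction with `id` exists.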
+    /// Remove transaction by id.
+    pub fn remove_by_id(&mut self, id: &TransactionId) {
+        self.id_to_transaction_state
+            .remove(id)
+            .expect("transaction must exist");
+    }
+
+    pub fn get_min_max_priority(&self) -> MinMaxResult<u64> {
+        match self.priority_queue.peek_min() {
+            Some(min) => match self.priority_queue.peek_max() {
+                Some(max) => MinMaxResult::MinMax(min.priority, max.priority),
+                None => MinMaxResult::OneElement(min.priority),
+            },
+            None => MinMaxResult::NoElements,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use {
+        super::*,
+        crate::scheduler_messages::MaxAge,
+        solana_sdk::{
+            compute_budget::ComputeBudgetInstruction,
+            hash::Hash,
+            message::Message,
+            signature::Keypair,
+            signer::Signer,
+            slot_history::Slot,
+            system_instruction,
+            transaction::{SanitizedTransaction, Transaction},
+        },
+    };
+
+    /// Returns (transaction_ttl, priority, cost)
+    fn test_transaction(
+        priority: u64,
+    ) -> (
+        SanitizedTransactionTTL,
+        u64,
+        u64,
+    ) {
+        let from_keypair = Keypair::new();
+        let ixs = vec![
+            system_instruction::transfer(
+                &from_keypair.pubkey(),
+                &solana_sdk::pubkey::new_rand(),
+                1,
+            ),
+            ComputeBudgetInstruction::set_compute_unit_price(priority),
+        ];
+        let message = Message::new(&ixs, Some(&from_keypair.pubkey()));
+        let tx = SanitizedTransaction::from_transaction_for_tests(Transaction::new(
+            &[&from_keypair],
+            message,
+            Hash::default(),
+        ));
+        let transaction_ttl = SanitizedTransactionTTL {
+            transaction: tx,
+            max_age: MaxAge {
+                epoch_invalidation_slot: Slot::MAX,
+                alt_invalidation_slot: Slot::MAX,
+            },
+        };
+        const TEST_TRANSACTION_COST: u64 = 5000;
+        (transaction_ttl, priority, TEST_TRANSACTION_COST)
+    }
+
+    fn push_to_container(
+        container: &mut TransactionStateContainer,
+        num: usize,
+    ) {
+        for id in 0..num as u64 {
+            let priority = id;
+            let (transaction_ttl, priority, cost) = test_transaction(priority);
+            container.insert_new_transaction(
+                TransactionId::new(id),
+                transaction_ttl,
+                priority,
+                cost,
+            );
+        }
+    }
+
+    #[test]
+    fn test_is_empty() {
+        let mut container = TransactionStateContainer::with_capacity(1);
+        assert!(container.is_empty());
+
+        push_to_container(&mut container, 1);
+        assert!(!container.is_empty());
+    }
+
+    #[test]
+    fn test_priority_queue_capacity() {
+        let mut container = TransactionStateContainer::with_capacity(1);
+        push_to_container(&mut container, 5);
+
+        assert_eq!(container.priority_queue.len(), 1);
+        assert_eq!(container.id_to_transaction_state.len(), 1);
+        assert_eq!(
+            container
+                .id_to_transaction_state
+                .iter()
+                .map(|ts| ts.1.priority())
+                .next()
+                .unwrap(),
+            4
+        );
+    }
+
+    #[test]
+    fn test_get_mut_transaction_state() {
+        let mut container = TransactionStateContainer::with_capacity(5);
+        push_to_container(&mut container, 5);
+
+        let existing_id = TransactionId::new(3);
+        let non_existing_id = TransactionId::new(7);
+        // Looked up twice on purpose: `get_mut_transaction_state` must not
+        // consume or remove the entry.
+        assert!(container.get_mut_transaction_state(&existing_id).is_some());
+        assert!(container.get_mut_transaction_state(&existing_id).is_some());
+        assert!(container
+            .get_mut_transaction_state(&non_existing_id)
+            .is_none());
+    }
+}

From a67511ad1d846e0daab866a3d44765638c35f32a Mon Sep 17 00:00:00 2001
From: lewis
Date: Wed, 30 Oct 2024 11:19:12 +0800
Subject: [PATCH 4/4] feat: pg-scheduler with some bugfixes

---
 Cargo.lock                                         |  30 -
 Cargo.toml                                         |   1 -
 prio-graph-scheduler/Cargo.toml                    |  50 -
 .../src/deserializable_packet.rs                   |  23 -
 prio-graph-scheduler/src/id_generator.rs           |  19 -
 prio-graph-scheduler/src/in_flight_tracker.rs      | 122 ---
 prio-graph-scheduler/src/lib.rs                    | 131 ---
 .../src/read_write_account_set.rs                  | 287 ------
prio-graph-scheduler/src/scheduler_error.rs | 9 - .../src/scheduler_messages.rs | 70 -- prio-graph-scheduler/src/scheduler_metrics.rs | 408 -------- .../src/thread_aware_account_locks.rs | 742 -------------- .../src/transaction_priority_id.rs | 69 -- prio-graph-scheduler/src/transaction_state.rs | 384 -------- .../src/transaction_state_container.rs | 263 ----- rpc/src/lib.rs | 1 + scheduler/bin/scheduling_simulation.rs | 24 +- scheduler/src/id_generator.rs | 2 +- scheduler/src/impls/no_lock_scheduler/mod.rs | 9 +- .../prio_graph_scheduler/in_flight_tracker.rs | 2 +- .../src/impls/prio_graph_scheduler/mod.rs | 54 +- .../prio_graph_scheduler.rs | 910 ------------------ .../read_write_account_set.rs | 2 +- .../impls/prio_graph_scheduler/scheduler.rs | 98 +- .../thread_aware_account_locks.rs | 24 +- .../prio_graph_scheduler/transaction_state.rs | 8 +- .../transaction_state_container.rs | 19 +- scheduler/src/lib.rs | 2 +- scheduler/src/scheduler.rs | 8 +- scheduler/src/scheduler_messages.rs | 12 +- 30 files changed, 150 insertions(+), 3633 deletions(-) delete mode 100644 prio-graph-scheduler/Cargo.toml delete mode 100644 prio-graph-scheduler/src/deserializable_packet.rs delete mode 100644 prio-graph-scheduler/src/id_generator.rs delete mode 100644 prio-graph-scheduler/src/in_flight_tracker.rs delete mode 100644 prio-graph-scheduler/src/lib.rs delete mode 100644 prio-graph-scheduler/src/read_write_account_set.rs delete mode 100644 prio-graph-scheduler/src/scheduler_error.rs delete mode 100644 prio-graph-scheduler/src/scheduler_messages.rs delete mode 100644 prio-graph-scheduler/src/scheduler_metrics.rs delete mode 100644 prio-graph-scheduler/src/thread_aware_account_locks.rs delete mode 100644 prio-graph-scheduler/src/transaction_priority_id.rs delete mode 100644 prio-graph-scheduler/src/transaction_state.rs delete mode 100644 prio-graph-scheduler/src/transaction_state_container.rs delete mode 100644 scheduler/src/impls/prio_graph_scheduler/prio_graph_scheduler.rs rename prio-graph-scheduler/src/prio_graph_scheduler.rs => scheduler/src/impls/prio_graph_scheduler/scheduler.rs (92%) diff --git a/Cargo.lock b/Cargo.lock index 71142bd..f15e4da 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5375,36 +5375,6 @@ dependencies = [ "thiserror", ] -[[package]] -name = "solana-prio-graph-scheduler" -version = "0.1.0" -dependencies = [ - "ahash 0.8.11", - "arrayvec", - "assert_matches", - "bincode", - "crossbeam-channel", - "itertools 0.13.0", - "log", - "min-max-heap", - "prio-graph", - "solana-compute-budget", - "solana-cost-model", - "solana-gossip", - "solana-ledger", - "solana-measure", - "solana-metrics", - "solana-perf", - "solana-poh", - "solana-prio-graph-scheduler", - "solana-runtime", - "solana-runtime-transaction", - "solana-sanitize", - "solana-sdk", - "solana-short-vec", - "thiserror", -] - [[package]] name = "solana-program" version = "2.0.13" diff --git a/Cargo.toml b/Cargo.toml index b3e2b9e..32db3e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,6 @@ members = [ "svm/cli", "svm/executor", "verifier", - "prio-graph-scheduler", "rpc", "scheduler", ] diff --git a/prio-graph-scheduler/Cargo.toml b/prio-graph-scheduler/Cargo.toml deleted file mode 100644 index 8015547..0000000 --- a/prio-graph-scheduler/Cargo.toml +++ /dev/null @@ -1,50 +0,0 @@ -[package] -name = "solana-prio-graph-scheduler" -description = "Solana Priority Graph Scheduler" -documentation = "https://docs.rs/solana-prio-graph-scheduler" -version.workspace = true -authors.workspace = true -repository.workspace = true 
-homepage.workspace = true -license.workspace = true -edition.workspace = true - -[dependencies] -solana-sdk = { workspace = true } -solana-poh = { workspace = true } -solana-metrics = { workspace = true } -solana-ledger = { workspace = true } -solana-runtime = { workspace = true } -solana-gossip = { workspace = true } -solana-cost-model = { workspace = true } -solana-measure = { workspace = true } - -ahash = { workspace = true } -prio-graph = { workspace = true } -thiserror = { workspace = true } -itertools = { workspace = true } -log = { workspace = true } -crossbeam-channel = { workspace = true } -arrayvec = { workspace = true } -min-max-heap = { workspace = true } - -[dev-dependencies] -assert_matches = { workspace = true } -solana-compute-budget = { workspace = true } -solana-perf = { workspace = true } -solana-runtime-transaction = { workspace = true } -solana-sanitize = { workspace = true } -solana-short-vec = { workspace = true } -# let dev-context-only-utils works when running this crate. -solana-prio-graph-scheduler = { path = ".", features = [ - "dev-context-only-utils", -] } -solana-sdk = { workspace = true, features = ["dev-context-only-utils"] } - -bincode = { workspace = true } - -[package.metadata.docs.rs] -targets = ["x86_64-unknown-linux-gnu"] - -[features] -dev-context-only-utils = ["solana-runtime/dev-context-only-utils"] diff --git a/prio-graph-scheduler/src/deserializable_packet.rs b/prio-graph-scheduler/src/deserializable_packet.rs deleted file mode 100644 index 71a04e3..0000000 --- a/prio-graph-scheduler/src/deserializable_packet.rs +++ /dev/null @@ -1,23 +0,0 @@ -use solana_sdk::hash::Hash; -use solana_sdk::packet::Packet; -use solana_sdk::transaction::SanitizedVersionedTransaction; -use std::error::Error; - -/// DeserializablePacket can be deserialized from a Packet. -/// -/// DeserializablePacket will be deserialized as a SanitizedTransaction -/// to be scheduled in transaction stream and scheduler. -pub trait DeserializableTxPacket: PartialEq + PartialOrd + Eq + Sized { - type DeserializeError: Error; - - fn new(packet: Packet) -> Result; - - fn original_packet(&self) -> &Packet; - - /// deserialized into versionedTx, and then to SanitizedTransaction. - fn transaction(&self) -> &SanitizedVersionedTransaction; - - fn message_hash(&self) -> &Hash; - - fn is_simple_vote(&self) -> bool; -} diff --git a/prio-graph-scheduler/src/id_generator.rs b/prio-graph-scheduler/src/id_generator.rs deleted file mode 100644 index 3090e4e..0000000 --- a/prio-graph-scheduler/src/id_generator.rs +++ /dev/null @@ -1,19 +0,0 @@ -/// Simple reverse-sequential ID generator for `TransactionId`s. -/// These IDs uniquely identify transactions during the scheduling process. -pub struct IdGenerator { - next_id: u64, -} - -impl Default for IdGenerator { - fn default() -> Self { - Self { next_id: u64::MAX } - } -} - -impl IdGenerator { - pub fn next>(&mut self) -> T { - let id = self.next_id; - self.next_id = self.next_id.wrapping_sub(1); - T::from(id) - } -} diff --git a/prio-graph-scheduler/src/in_flight_tracker.rs b/prio-graph-scheduler/src/in_flight_tracker.rs deleted file mode 100644 index 4e650a3..0000000 --- a/prio-graph-scheduler/src/in_flight_tracker.rs +++ /dev/null @@ -1,122 +0,0 @@ -use { - crate::id_generator::IdGenerator, crate::scheduler_messages::TransactionBatchId, - crate::thread_aware_account_locks::ThreadId, std::collections::HashMap, -}; - -/// Tracks the number of transactions that are in flight for each thread. 
-pub struct InFlightTracker { - num_in_flight_per_thread: Vec, - cus_in_flight_per_thread: Vec, - batches: HashMap, - batch_id_generator: IdGenerator, -} - -struct BatchEntry { - thread_id: ThreadId, - num_transactions: usize, - total_cus: u64, -} - -impl InFlightTracker { - pub fn new(num_threads: usize) -> Self { - Self { - num_in_flight_per_thread: vec![0; num_threads], - cus_in_flight_per_thread: vec![0; num_threads], - batches: HashMap::new(), - batch_id_generator: IdGenerator::default(), - } - } - - /// Returns the number of transactions that are in flight for each thread. - pub fn num_in_flight_per_thread(&self) -> &[usize] { - &self.num_in_flight_per_thread - } - - /// Returns the number of cus that are in flight for each thread. - pub fn cus_in_flight_per_thread(&self) -> &[u64] { - &self.cus_in_flight_per_thread - } - - /// Tracks number of transactions and CUs in-flight for the `thread_id`. - /// Returns a `TransactionBatchId` that can be used to stop tracking the batch - /// when it is complete. - pub fn track_batch( - &mut self, - num_transactions: usize, - total_cus: u64, - thread_id: ThreadId, - ) -> TransactionBatchId { - let batch_id = self.batch_id_generator.next(); - self.num_in_flight_per_thread[thread_id] += num_transactions; - self.cus_in_flight_per_thread[thread_id] += total_cus; - self.batches.insert( - batch_id, - BatchEntry { - thread_id, - num_transactions, - total_cus, - }, - ); - - batch_id - } - - /// Stop tracking the batch with given `batch_id`. - /// Removes the number of transactions for the scheduled thread. - /// Returns the thread id that the batch was scheduled on. - /// - /// # Panics - /// Panics if the batch id does not exist in the tracker. - pub fn complete_batch(&mut self, batch_id: TransactionBatchId) -> ThreadId { - let Some(BatchEntry { - thread_id, - num_transactions, - total_cus, - }) = self.batches.remove(&batch_id) - else { - panic!("batch id {batch_id} is not being tracked"); - }; - self.num_in_flight_per_thread[thread_id] -= num_transactions; - self.cus_in_flight_per_thread[thread_id] -= total_cus; - - thread_id - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - #[should_panic(expected = "is not being tracked")] - fn test_in_flight_tracker_untracked_batch() { - let mut in_flight_tracker = InFlightTracker::new(2); - in_flight_tracker.complete_batch(TransactionBatchId::new(5)); - } - - #[test] - fn test_in_flight_tracker() { - let mut in_flight_tracker = InFlightTracker::new(2); - - // Add a batch with 2 transactions, 10 kCUs to thread 0. - let batch_id_0 = in_flight_tracker.track_batch(2, 10_000, 0); - assert_eq!(in_flight_tracker.num_in_flight_per_thread(), &[2, 0]); - assert_eq!(in_flight_tracker.cus_in_flight_per_thread(), &[10_000, 0]); - - // Add a batch with 1 transaction, 15 kCUs to thread 1. 
- let batch_id_1 = in_flight_tracker.track_batch(1, 15_000, 1); - assert_eq!(in_flight_tracker.num_in_flight_per_thread(), &[2, 1]); - assert_eq!( - in_flight_tracker.cus_in_flight_per_thread(), - &[10_000, 15_000] - ); - - in_flight_tracker.complete_batch(batch_id_0); - assert_eq!(in_flight_tracker.num_in_flight_per_thread(), &[0, 1]); - assert_eq!(in_flight_tracker.cus_in_flight_per_thread(), &[0, 15_000]); - - in_flight_tracker.complete_batch(batch_id_1); - assert_eq!(in_flight_tracker.num_in_flight_per_thread(), &[0, 0]); - assert_eq!(in_flight_tracker.cus_in_flight_per_thread(), &[0, 0]); - } -} diff --git a/prio-graph-scheduler/src/lib.rs b/prio-graph-scheduler/src/lib.rs deleted file mode 100644 index 6c28df5..0000000 --- a/prio-graph-scheduler/src/lib.rs +++ /dev/null @@ -1,131 +0,0 @@ -//! Solana Priority Graph Scheduler. -pub mod id_generator; -pub mod in_flight_tracker; -pub mod scheduler_error; -pub mod scheduler_messages; -pub mod scheduler_metrics; -pub mod thread_aware_account_locks; -pub mod transaction_priority_id; -pub mod transaction_state; -// pub mod scheduler_controller; -pub mod deserializable_packet; -pub mod prio_graph_scheduler; -pub mod transaction_state_container; - -#[macro_use] -extern crate solana_metrics; - -#[cfg(test)] -#[macro_use] -extern crate assert_matches; - -/// Consumer will create chunks of transactions from buffer with up to this size. -pub const TARGET_NUM_TRANSACTIONS_PER_BATCH: usize = 64; - -mod read_write_account_set; - -#[cfg(test)] -mod tests { - use { - crate::deserializable_packet::DeserializableTxPacket, - solana_perf::packet::Packet, - solana_sdk::{ - hash::Hash, - message::Message, - sanitize::SanitizeError, - signature::Signature, - transaction::{SanitizedVersionedTransaction, VersionedTransaction}, - }, - solana_short_vec::decode_shortu16_len, - std::{cmp::Ordering, mem::size_of}, - thiserror::Error, - }; - - #[derive(Debug, Error)] - pub enum MockDeserializedPacketError { - #[error("ShortVec Failed to Deserialize")] - // short_vec::decode_shortu16_len() currently returns () on error - ShortVecError(()), - #[error("Deserialization Error: {0}")] - DeserializationError(#[from] bincode::Error), - #[error("overflowed on signature size {0}")] - SignatureOverflowed(usize), - #[error("packet failed sanitization {0}")] - SanitizeError(#[from] SanitizeError), - } - - #[derive(Debug, Eq)] - pub struct MockImmutableDeserializedPacket { - pub original_packet: Packet, - pub transaction: SanitizedVersionedTransaction, - pub message_hash: Hash, - pub is_simple_vote: bool, - } - - impl DeserializableTxPacket for MockImmutableDeserializedPacket { - type DeserializeError = MockDeserializedPacketError; - fn new(packet: Packet) -> Result { - let versioned_transaction: VersionedTransaction = packet.deserialize_slice(..)?; - let sanitized_transaction = - SanitizedVersionedTransaction::try_from(versioned_transaction)?; - let message_bytes = packet_message(&packet)?; - let message_hash = Message::hash_raw_message(message_bytes); - let is_simple_vote = packet.meta().is_simple_vote_tx(); - - Ok(Self { - original_packet: packet, - transaction: sanitized_transaction, - message_hash, - is_simple_vote, - }) - } - - fn original_packet(&self) -> &Packet { - &self.original_packet - } - - fn transaction(&self) -> &SanitizedVersionedTransaction { - &self.transaction - } - - fn message_hash(&self) -> &Hash { - &self.message_hash - } - - fn is_simple_vote(&self) -> bool { - self.is_simple_vote - } - } - - // PartialEq MUST be consistent with PartialOrd and Ord - impl 
PartialEq for MockImmutableDeserializedPacket { - fn eq(&self, other: &Self) -> bool { - self.message_hash == other.message_hash - } - } - - impl PartialOrd for MockImmutableDeserializedPacket { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } - } - - impl Ord for MockImmutableDeserializedPacket { - fn cmp(&self, other: &Self) -> Ordering { - self.message_hash().cmp(other.message_hash()) - } - } - - /// Read the transaction message from packet data - fn packet_message(packet: &Packet) -> Result<&[u8], MockDeserializedPacketError> { - let (sig_len, sig_size) = packet - .data(..) - .and_then(|bytes| decode_shortu16_len(bytes).ok()) - .ok_or(MockDeserializedPacketError::ShortVecError(()))?; - sig_len - .checked_mul(size_of::()) - .and_then(|v| v.checked_add(sig_size)) - .and_then(|msg_start| packet.data(msg_start..)) - .ok_or(MockDeserializedPacketError::SignatureOverflowed(sig_size)) - } -} diff --git a/prio-graph-scheduler/src/read_write_account_set.rs b/prio-graph-scheduler/src/read_write_account_set.rs deleted file mode 100644 index 0e70837..0000000 --- a/prio-graph-scheduler/src/read_write_account_set.rs +++ /dev/null @@ -1,287 +0,0 @@ -use { - ahash::AHashSet, - solana_sdk::{message::SanitizedMessage, pubkey::Pubkey}, -}; - -/// Wrapper struct to accumulate locks for a batch of transactions. -#[derive(Debug, Default)] -pub struct ReadWriteAccountSet { - /// Set of accounts that are locked for read - read_set: AHashSet, - /// Set of accounts that are locked for write - write_set: AHashSet, -} - -impl ReadWriteAccountSet { - /// Returns true if all account locks were available and false otherwise. - pub fn check_locks(&self, message: &SanitizedMessage) -> bool { - message - .account_keys() - .iter() - .enumerate() - .all(|(index, pubkey)| { - if message.is_writable(index) { - self.can_write(pubkey) - } else { - self.can_read(pubkey) - } - }) - } - - /// Add all account locks. - /// Returns true if all account locks were available and false otherwise. - pub fn take_locks(&mut self, message: &SanitizedMessage) -> bool { - message - .account_keys() - .iter() - .enumerate() - .fold(true, |all_available, (index, pubkey)| { - if message.is_writable(index) { - all_available & self.add_write(pubkey) - } else { - all_available & self.add_read(pubkey) - } - }) - } - - /// Clears the read and write sets - #[allow(dead_code)] - pub fn clear(&mut self) { - self.read_set.clear(); - self.write_set.clear(); - } - - /// Check if an account can be read-locked - fn can_read(&self, pubkey: &Pubkey) -> bool { - !self.write_set.contains(pubkey) - } - - /// Check if an account can be write-locked - fn can_write(&self, pubkey: &Pubkey) -> bool { - !self.write_set.contains(pubkey) && !self.read_set.contains(pubkey) - } - - /// Add an account to the read-set. - /// Returns true if the lock was available. - fn add_read(&mut self, pubkey: &Pubkey) -> bool { - let can_read = self.can_read(pubkey); - self.read_set.insert(*pubkey); - - can_read - } - - /// Add an account to the write-set. - /// Returns true if the lock was available. 
- fn add_write(&mut self, pubkey: &Pubkey) -> bool { - let can_write = self.can_write(pubkey); - self.write_set.insert(*pubkey); - - can_write - } -} - -#[cfg(test)] -mod tests { - use { - super::ReadWriteAccountSet, - solana_ledger::genesis_utils::GenesisConfigInfo, - solana_runtime::{bank::Bank, bank_forks::BankForks, genesis_utils::create_genesis_config}, - solana_sdk::{ - account::AccountSharedData, - address_lookup_table::{ - self, - state::{AddressLookupTable, LookupTableMeta}, - }, - hash::Hash, - message::{ - v0::{self, MessageAddressTableLookup}, - MessageHeader, VersionedMessage, - }, - pubkey::Pubkey, - signature::Keypair, - signer::Signer, - transaction::{MessageHash, SanitizedTransaction, VersionedTransaction}, - }, - std::{ - borrow::Cow, - sync::{Arc, RwLock}, - }, - }; - - fn create_test_versioned_message( - write_keys: &[Pubkey], - read_keys: &[Pubkey], - address_table_lookups: Vec, - ) -> VersionedMessage { - VersionedMessage::V0(v0::Message { - header: MessageHeader { - num_required_signatures: write_keys.len() as u8, - num_readonly_signed_accounts: 0, - num_readonly_unsigned_accounts: read_keys.len() as u8, - }, - recent_blockhash: Hash::default(), - account_keys: write_keys.iter().chain(read_keys.iter()).copied().collect(), - address_table_lookups, - instructions: vec![], - }) - } - - fn create_test_sanitized_transaction( - write_keypair: &Keypair, - read_keys: &[Pubkey], - address_table_lookups: Vec, - bank: &Bank, - ) -> SanitizedTransaction { - let message = create_test_versioned_message( - &[write_keypair.pubkey()], - read_keys, - address_table_lookups, - ); - SanitizedTransaction::try_create( - VersionedTransaction::try_new(message, &[write_keypair]).unwrap(), - MessageHash::Compute, - Some(false), - bank, - bank.get_reserved_account_keys(), - ) - .unwrap() - } - - fn create_test_address_lookup_table( - bank: Arc, - num_addresses: usize, - ) -> (Arc, Pubkey) { - let mut addresses = Vec::with_capacity(num_addresses); - addresses.resize_with(num_addresses, Pubkey::new_unique); - let address_lookup_table = AddressLookupTable { - meta: LookupTableMeta { - authority: None, - ..LookupTableMeta::default() - }, - addresses: Cow::Owned(addresses), - }; - - let address_table_key = Pubkey::new_unique(); - let data = address_lookup_table.serialize_for_tests().unwrap(); - let mut account = - AccountSharedData::new(1, data.len(), &address_lookup_table::program::id()); - account.set_data(data); - bank.store_account(&address_table_key, &account); - - let slot = bank.slot() + 1; - ( - Arc::new(Bank::new_from_parent(bank, &Pubkey::new_unique(), slot)), - address_table_key, - ) - } - - fn create_test_bank() -> (Arc, Arc>) { - let GenesisConfigInfo { genesis_config, .. } = create_genesis_config(10_000); - Bank::new_no_wallclock_throttle_for_tests(&genesis_config) - } - - // Helper function (could potentially use test_case in future). 
- // conflict_index = 0 means write lock conflict with static key - // conflict_index = 1 means read lock conflict with static key - // conflict_index = 2 means write lock conflict with address table key - // conflict_index = 3 means read lock conflict with address table key - fn test_check_and_take_locks(conflict_index: usize, add_write: bool, expectation: bool) { - let (bank, _bank_forks) = create_test_bank(); - let (bank, table_address) = create_test_address_lookup_table(bank, 2); - let tx = create_test_sanitized_transaction( - &Keypair::new(), - &[Pubkey::new_unique()], - vec![MessageAddressTableLookup { - account_key: table_address, - writable_indexes: vec![0], - readonly_indexes: vec![1], - }], - &bank, - ); - let message = tx.message(); - - let mut account_locks = ReadWriteAccountSet::default(); - - let conflict_key = message.account_keys().get(conflict_index).unwrap(); - if add_write { - account_locks.add_write(conflict_key); - } else { - account_locks.add_read(conflict_key); - } - assert_eq!(expectation, account_locks.check_locks(message)); - assert_eq!(expectation, account_locks.take_locks(message)); - } - - #[test] - fn test_check_and_take_locks_write_write_conflict() { - test_check_and_take_locks(0, true, false); // static key conflict - test_check_and_take_locks(2, true, false); // lookup key conflict - } - - #[test] - fn test_check_and_take_locks_read_write_conflict() { - test_check_and_take_locks(0, false, false); // static key conflict - test_check_and_take_locks(2, false, false); // lookup key conflict - } - - #[test] - fn test_check_and_take_locks_write_read_conflict() { - test_check_and_take_locks(1, true, false); // static key conflict - test_check_and_take_locks(3, true, false); // lookup key conflict - } - - #[test] - fn test_check_and_take_locks_read_read_non_conflict() { - test_check_and_take_locks(1, false, true); // static key conflict - test_check_and_take_locks(3, false, true); // lookup key conflict - } - - #[test] - pub fn test_write_write_conflict() { - let mut account_locks = ReadWriteAccountSet::default(); - let account = Pubkey::new_unique(); - assert!(account_locks.can_write(&account)); - account_locks.add_write(&account); - assert!(!account_locks.can_write(&account)); - } - - #[test] - pub fn test_read_write_conflict() { - let mut account_locks = ReadWriteAccountSet::default(); - let account = Pubkey::new_unique(); - assert!(account_locks.can_read(&account)); - account_locks.add_read(&account); - assert!(!account_locks.can_write(&account)); - assert!(account_locks.can_read(&account)); - } - - #[test] - pub fn test_write_read_conflict() { - let mut account_locks = ReadWriteAccountSet::default(); - let account = Pubkey::new_unique(); - assert!(account_locks.can_write(&account)); - account_locks.add_write(&account); - assert!(!account_locks.can_write(&account)); - assert!(!account_locks.can_read(&account)); - } - - #[test] - pub fn test_read_read_non_conflict() { - let mut account_locks = ReadWriteAccountSet::default(); - let account = Pubkey::new_unique(); - assert!(account_locks.can_read(&account)); - account_locks.add_read(&account); - assert!(account_locks.can_read(&account)); - } - - #[test] - pub fn test_write_write_different_keys() { - let mut account_locks = ReadWriteAccountSet::default(); - let account1 = Pubkey::new_unique(); - let account2 = Pubkey::new_unique(); - assert!(account_locks.can_write(&account1)); - account_locks.add_write(&account1); - assert!(account_locks.can_write(&account2)); - assert!(account_locks.can_read(&account2)); - } -} 
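The ReadWriteAccountSet deleted above survives the move: per the diff stat, an
essentially identical copy now lives at
scheduler/src/impls/prio_graph_scheduler/read_write_account_set.rs. As a quick
reference for the rules it encodes (mirroring the unit tests above; these
helpers are private, so this sketch only compiles from inside that module):

    let mut set = ReadWriteAccountSet::default();
    let account = Pubkey::new_unique();
    set.add_read(&account);            // first reader takes a shared lock
    assert!(set.can_read(&account));   // read-read: compatible
    assert!(!set.can_write(&account)); // read-write: conflict
    set.clear();                       // batch done: release everything
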
diff --git a/prio-graph-scheduler/src/scheduler_error.rs b/prio-graph-scheduler/src/scheduler_error.rs deleted file mode 100644 index 9b8d401..0000000 --- a/prio-graph-scheduler/src/scheduler_error.rs +++ /dev/null @@ -1,9 +0,0 @@ -use thiserror::Error; - -#[derive(Debug, Error)] -pub enum SchedulerError { - #[error("Sending channel disconnected: {0}")] - DisconnectedSendChannel(&'static str), - #[error("Recv channel disconnected: {0}")] - DisconnectedRecvChannel(&'static str), -} diff --git a/prio-graph-scheduler/src/scheduler_messages.rs b/prio-graph-scheduler/src/scheduler_messages.rs deleted file mode 100644 index b7ecf19..0000000 --- a/prio-graph-scheduler/src/scheduler_messages.rs +++ /dev/null @@ -1,70 +0,0 @@ -use { - solana_sdk::{clock::Slot, transaction::SanitizedTransaction}, - std::fmt::Display, -}; - -/// A unique identifier for a transaction batch. -#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] -pub struct TransactionBatchId(u64); - -impl TransactionBatchId { - pub fn new(index: u64) -> Self { - Self(index) - } -} - -impl Display for TransactionBatchId { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -impl From for TransactionBatchId { - fn from(id: u64) -> Self { - Self(id) - } -} - -/// A unique identifier for a transaction. -#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] -pub struct TransactionId(u64); - -impl TransactionId { - pub fn new(index: u64) -> Self { - Self(index) - } -} - -impl Display for TransactionId { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -impl From for TransactionId { - fn from(id: u64) -> Self { - Self(id) - } -} - -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub struct MaxAge { - pub epoch_invalidation_slot: Slot, - pub alt_invalidation_slot: Slot, -} - -/// Message: [Scheduler -> Worker] -/// Transactions to be consumed (i.e. executed, recorded, and committed) -pub struct ConsumeWork { - pub batch_id: TransactionBatchId, - pub ids: Vec, - pub transactions: Vec, - pub max_ages: Vec, -} - -/// Message: [Worker -> Scheduler] -/// Processed transactions. 
-pub struct FinishedConsumeWork { - pub work: ConsumeWork, - pub retryable_indexes: Vec, -} diff --git a/prio-graph-scheduler/src/scheduler_metrics.rs b/prio-graph-scheduler/src/scheduler_metrics.rs deleted file mode 100644 index bb8cbbe..0000000 --- a/prio-graph-scheduler/src/scheduler_metrics.rs +++ /dev/null @@ -1,408 +0,0 @@ -use { - itertools::MinMaxResult, - solana_poh::poh_recorder::BankStart, - solana_sdk::{clock::Slot, timing::AtomicInterval}, - std::time::Instant, -}; - -#[derive(Default)] -pub struct SchedulerCountMetrics { - interval: IntervalSchedulerCountMetrics, - slot: SlotSchedulerCountMetrics, -} - -impl SchedulerCountMetrics { - pub fn update(&mut self, update: impl Fn(&mut SchedulerCountMetricsInner)) { - update(&mut self.interval.metrics); - update(&mut self.slot.metrics); - } - - pub fn maybe_report_and_reset_slot(&mut self, slot: Option) { - self.slot.maybe_report_and_reset(slot); - } - - pub fn maybe_report_and_reset_interval(&mut self, should_report: bool) { - self.interval.maybe_report_and_reset(should_report); - } - - pub fn interval_has_data(&self) -> bool { - self.interval.metrics.has_data() - } -} - -#[derive(Default)] -struct IntervalSchedulerCountMetrics { - interval: AtomicInterval, - metrics: SchedulerCountMetricsInner, -} - -#[derive(Default)] -struct SlotSchedulerCountMetrics { - slot: Option, - metrics: SchedulerCountMetricsInner, -} - -#[derive(Default)] -pub struct SchedulerCountMetricsInner { - /// Number of packets received. - pub num_received: usize, - /// Number of packets buffered. - pub num_buffered: usize, - - /// Number of transactions scheduled. - pub num_scheduled: usize, - /// Number of transactions that were unschedulable. - pub num_unschedulable: usize, - /// Number of transactions that were filtered out during scheduling. - pub num_schedule_filtered_out: usize, - /// Number of completed transactions received from workers. - pub num_finished: usize, - /// Number of transactions that were retryable. - pub num_retryable: usize, - /// Number of transactions that were scheduled to be forwarded. - pub num_forwarded: usize, - - /// Number of transactions that were immediately dropped on receive. - pub num_dropped_on_receive: usize, - /// Number of transactions that were dropped due to sanitization failure. - pub num_dropped_on_sanitization: usize, - /// Number of transactions that were dropped due to failed lock validation. - pub num_dropped_on_validate_locks: usize, - /// Number of transactions that were dropped due to failed transaction - /// checks during receive. - pub num_dropped_on_receive_transaction_checks: usize, - /// Number of transactions that were dropped due to clearing. - pub num_dropped_on_clear: usize, - /// Number of transactions that were dropped due to age and status checks. - pub num_dropped_on_age_and_status: usize, - /// Number of transactions that were dropped due to exceeded capacity. 
- pub num_dropped_on_capacity: usize, - /// Min prioritization fees in the transaction container - pub min_prioritization_fees: u64, - /// Max prioritization fees in the transaction container - pub max_prioritization_fees: u64, -} - -impl IntervalSchedulerCountMetrics { - fn maybe_report_and_reset(&mut self, should_report: bool) { - const REPORT_INTERVAL_MS: u64 = 1000; - if self.interval.should_update(REPORT_INTERVAL_MS) { - if should_report { - self.metrics.report("banking_stage_scheduler_counts", None); - } - self.metrics.reset(); - } - } -} - -impl SlotSchedulerCountMetrics { - fn maybe_report_and_reset(&mut self, slot: Option) { - if self.slot != slot { - // Only report if there was an assigned slot. - if self.slot.is_some() { - self.metrics - .report("banking_stage_scheduler_slot_counts", self.slot); - } - self.metrics.reset(); - self.slot = slot; - } - } -} - -impl SchedulerCountMetricsInner { - fn report(&self, name: &'static str, slot: Option) { - let mut datapoint = create_datapoint!( - @point name, - ("num_received", self.num_received, i64), - ("num_buffered", self.num_buffered, i64), - ("num_scheduled", self.num_scheduled, i64), - ("num_unschedulable", self.num_unschedulable, i64), - ( - "num_schedule_filtered_out", - self.num_schedule_filtered_out, - i64 - ), - ("num_finished", self.num_finished, i64), - ("num_retryable", self.num_retryable, i64), - ("num_forwarded", self.num_forwarded, i64), - ("num_dropped_on_receive", self.num_dropped_on_receive, i64), - ( - "num_dropped_on_sanitization", - self.num_dropped_on_sanitization, - i64 - ), - ( - "num_dropped_on_validate_locks", - self.num_dropped_on_validate_locks, - i64 - ), - ( - "num_dropped_on_receive_transaction_checks", - self.num_dropped_on_receive_transaction_checks, - i64 - ), - ("num_dropped_on_clear", self.num_dropped_on_clear, i64), - ( - "num_dropped_on_age_and_status", - self.num_dropped_on_age_and_status, - i64 - ), - ("num_dropped_on_capacity", self.num_dropped_on_capacity, i64), - ("min_priority", self.get_min_priority(), i64), - ("max_priority", self.get_max_priority(), i64) - ); - if let Some(slot) = slot { - datapoint.add_field_i64("slot", slot as i64); - } - solana_metrics::submit(datapoint, log::Level::Info); - } - - pub fn has_data(&self) -> bool { - self.num_received != 0 - || self.num_buffered != 0 - || self.num_scheduled != 0 - || self.num_unschedulable != 0 - || self.num_schedule_filtered_out != 0 - || self.num_finished != 0 - || self.num_retryable != 0 - || self.num_forwarded != 0 - || self.num_dropped_on_receive != 0 - || self.num_dropped_on_sanitization != 0 - || self.num_dropped_on_validate_locks != 0 - || self.num_dropped_on_receive_transaction_checks != 0 - || self.num_dropped_on_clear != 0 - || self.num_dropped_on_age_and_status != 0 - || self.num_dropped_on_capacity != 0 - } - - fn reset(&mut self) { - self.num_received = 0; - self.num_buffered = 0; - self.num_scheduled = 0; - self.num_unschedulable = 0; - self.num_schedule_filtered_out = 0; - self.num_finished = 0; - self.num_retryable = 0; - self.num_forwarded = 0; - self.num_dropped_on_receive = 0; - self.num_dropped_on_sanitization = 0; - self.num_dropped_on_validate_locks = 0; - self.num_dropped_on_receive_transaction_checks = 0; - self.num_dropped_on_clear = 0; - self.num_dropped_on_age_and_status = 0; - self.num_dropped_on_capacity = 0; - self.min_prioritization_fees = u64::MAX; - self.max_prioritization_fees = 0; - } - - pub fn update_priority_stats(&mut self, min_max_fees: MinMaxResult) { - // update min/max priority - match 
min_max_fees { - itertools::MinMaxResult::NoElements => { - // do nothing - } - itertools::MinMaxResult::OneElement(e) => { - self.min_prioritization_fees = e; - self.max_prioritization_fees = e; - } - itertools::MinMaxResult::MinMax(min, max) => { - self.min_prioritization_fees = min; - self.max_prioritization_fees = max; - } - } - } - - pub fn get_min_priority(&self) -> u64 { - // to avoid getting u64::max recorded by metrics / in case of edge cases - if self.min_prioritization_fees != u64::MAX { - self.min_prioritization_fees - } else { - 0 - } - } - - pub fn get_max_priority(&self) -> u64 { - self.max_prioritization_fees - } -} - -#[derive(Default)] -pub struct SchedulerTimingMetrics { - interval: IntervalSchedulerTimingMetrics, - slot: SlotSchedulerTimingMetrics, -} - -impl SchedulerTimingMetrics { - pub fn update(&mut self, update: impl Fn(&mut SchedulerTimingMetricsInner)) { - update(&mut self.interval.metrics); - update(&mut self.slot.metrics); - } - - pub fn maybe_report_and_reset_slot(&mut self, slot: Option) { - self.slot.maybe_report_and_reset(slot); - } - - pub fn maybe_report_and_reset_interval(&mut self, should_report: bool) { - self.interval.maybe_report_and_reset(should_report); - } -} - -#[derive(Default)] -struct IntervalSchedulerTimingMetrics { - interval: AtomicInterval, - metrics: SchedulerTimingMetricsInner, -} - -#[derive(Default)] -struct SlotSchedulerTimingMetrics { - slot: Option, - metrics: SchedulerTimingMetricsInner, -} - -#[derive(Default)] -pub struct SchedulerTimingMetricsInner { - /// Time spent making processing decisions. - pub decision_time_us: u64, - /// Time spent receiving packets. - pub receive_time_us: u64, - /// Time spent buffering packets. - pub buffer_time_us: u64, - /// Time spent filtering transactions during scheduling. - pub schedule_filter_time_us: u64, - /// Time spent scheduling transactions. - pub schedule_time_us: u64, - /// Time spent clearing transactions from the container. - pub clear_time_us: u64, - /// Time spent cleaning expired or processed transactions from the container. - pub clean_time_us: u64, - /// Time spent forwarding transactions. - pub forward_time_us: u64, - /// Time spent receiving completed transactions. - pub receive_completed_time_us: u64, -} - -impl IntervalSchedulerTimingMetrics { - fn maybe_report_and_reset(&mut self, should_report: bool) { - const REPORT_INTERVAL_MS: u64 = 1000; - if self.interval.should_update(REPORT_INTERVAL_MS) { - if should_report { - self.metrics.report("banking_stage_scheduler_timing", None); - } - self.metrics.reset(); - } - } -} - -impl SlotSchedulerTimingMetrics { - fn maybe_report_and_reset(&mut self, slot: Option) { - if self.slot != slot { - // Only report if there was an assigned slot. 
- if self.slot.is_some() { - self.metrics - .report("banking_stage_scheduler_slot_timing", self.slot); - } - self.metrics.reset(); - self.slot = slot; - } - } -} - -impl SchedulerTimingMetricsInner { - fn report(&self, name: &'static str, slot: Option) { - let mut datapoint = create_datapoint!( - @point name, - ("decision_time_us", self.decision_time_us, i64), - ("receive_time_us", self.receive_time_us, i64), - ("buffer_time_us", self.buffer_time_us, i64), - ("schedule_filter_time_us", self.schedule_filter_time_us, i64), - ("schedule_time_us", self.schedule_time_us, i64), - ("clear_time_us", self.clear_time_us, i64), - ("clean_time_us", self.clean_time_us, i64), - ("forward_time_us", self.forward_time_us, i64), - ( - "receive_completed_time_us", - self.receive_completed_time_us, - i64 - ) - ); - if let Some(slot) = slot { - datapoint.add_field_i64("slot", slot as i64); - } - solana_metrics::submit(datapoint, log::Level::Info); - } - - fn reset(&mut self) { - self.decision_time_us = 0; - self.receive_time_us = 0; - self.buffer_time_us = 0; - self.schedule_filter_time_us = 0; - self.schedule_time_us = 0; - self.clear_time_us = 0; - self.clean_time_us = 0; - self.forward_time_us = 0; - self.receive_completed_time_us = 0; - } -} - -#[derive(Default)] -pub struct SchedulerLeaderDetectionMetrics { - inner: Option, -} - -struct SchedulerLeaderDetectionMetricsInner { - slot: Slot, - bank_creation_time: Instant, - bank_detected_time: Instant, -} - -impl SchedulerLeaderDetectionMetrics { - pub fn update_and_maybe_report(&mut self, bank_start: Option<&BankStart>) { - match (&self.inner, bank_start) { - (None, Some(bank_start)) => self.initialize_inner(bank_start), - (Some(_inner), None) => self.report_and_reset(), - (Some(inner), Some(bank_start)) if inner.slot != bank_start.working_bank.slot() => { - self.report_and_reset(); - self.initialize_inner(bank_start); - } - _ => {} - } - } - - fn initialize_inner(&mut self, bank_start: &BankStart) { - let bank_detected_time = Instant::now(); - self.inner = Some(SchedulerLeaderDetectionMetricsInner { - slot: bank_start.working_bank.slot(), - bank_creation_time: *bank_start.bank_creation_time, - bank_detected_time, - }); - } - - fn report_and_reset(&mut self) { - let SchedulerLeaderDetectionMetricsInner { - slot, - bank_creation_time, - bank_detected_time, - } = self.inner.take().expect("inner must be present"); - - let bank_detected_delay_us = bank_detected_time - .duration_since(bank_creation_time) - .as_micros() - .try_into() - .unwrap_or(i64::MAX); - let bank_detected_to_slot_end_detected_us = bank_detected_time - .elapsed() - .as_micros() - .try_into() - .unwrap_or(i64::MAX); - datapoint_info!( - "banking_stage_scheduler_leader_detection", - ("slot", slot, i64), - ("bank_detected_delay_us", bank_detected_delay_us, i64), - ( - "bank_detected_to_slot_end_detected_us", - bank_detected_to_slot_end_detected_us, - i64 - ), - ); - } -} diff --git a/prio-graph-scheduler/src/thread_aware_account_locks.rs b/prio-graph-scheduler/src/thread_aware_account_locks.rs deleted file mode 100644 index 6d9d4c2..0000000 --- a/prio-graph-scheduler/src/thread_aware_account_locks.rs +++ /dev/null @@ -1,742 +0,0 @@ -use { - ahash::AHashMap, - solana_sdk::pubkey::Pubkey, - std::{ - collections::hash_map::Entry, - fmt::{Debug, Display}, - ops::{BitAnd, BitAndAssign, Sub}, - }, -}; - -pub const MAX_THREADS: usize = u64::BITS as usize; - -/// Identifier for a thread -pub type ThreadId = usize; // 0..MAX_THREADS-1 - -type LockCount = u32; - -/// A bit-set of threads an account is 
scheduled or can be scheduled for. -#[derive(Copy, Clone, PartialEq, Eq)] -pub struct ThreadSet(u64); - -struct AccountWriteLocks { - thread_id: ThreadId, - lock_count: LockCount, -} - -struct AccountReadLocks { - thread_set: ThreadSet, - lock_counts: [LockCount; MAX_THREADS], -} - -/// Account locks. -/// Write Locks - only one thread can hold a write lock at a time. -/// Contains how many write locks are held by the thread. -/// Read Locks - multiple threads can hold a read lock at a time. -/// Contains thread-set for easily checking which threads are scheduled. -#[derive(Default)] -struct AccountLocks { - pub write_locks: Option, - pub read_locks: Option, -} - -/// Thread-aware account locks which allows for scheduling on threads -/// that already hold locks on the account. This is useful for allowing -/// queued transactions to be scheduled on a thread while the transaction -/// is still being executed on the thread. -pub struct ThreadAwareAccountLocks { - /// Number of threads. - num_threads: usize, // 0..MAX_THREADS - /// Locks for each account. An account should only have an entry if there - /// is at least one lock. - locks: AHashMap, -} - -impl ThreadAwareAccountLocks { - /// Creates a new `ThreadAwareAccountLocks` with the given number of threads. - pub fn new(num_threads: usize) -> Self { - assert!(num_threads > 0, "num threads must be > 0"); - assert!( - num_threads <= MAX_THREADS, - "num threads must be <= {MAX_THREADS}" - ); - - Self { - num_threads, - locks: AHashMap::new(), - } - } - - /// Returns the `ThreadId` if the accounts are able to be locked - /// for the given thread, otherwise `None` is returned. - /// `allowed_threads` is a set of threads that the caller restricts locking to. - /// If accounts are schedulable, then they are locked for the thread - /// selected by the `thread_selector` function. - /// `thread_selector` is only called if all accounts are schdulable, meaning - /// that the `thread_set` passed to `thread_selector` is non-empty. - pub fn try_lock_accounts<'a>( - &mut self, - write_account_locks: impl Iterator + Clone, - read_account_locks: impl Iterator + Clone, - allowed_threads: ThreadSet, - thread_selector: impl FnOnce(ThreadSet) -> ThreadId, - ) -> Option { - let schedulable_threads = self.accounts_schedulable_threads( - write_account_locks.clone(), - read_account_locks.clone(), - )? & allowed_threads; - (!schedulable_threads.is_empty()).then(|| { - let thread_id = thread_selector(schedulable_threads); - self.lock_accounts(write_account_locks, read_account_locks, thread_id); - thread_id - }) - } - - /// Unlocks the accounts for the given thread. - pub fn unlock_accounts<'a>( - &mut self, - write_account_locks: impl Iterator, - read_account_locks: impl Iterator, - thread_id: ThreadId, - ) { - for account in write_account_locks { - self.write_unlock_account(account, thread_id); - } - - for account in read_account_locks { - self.read_unlock_account(account, thread_id); - } - } - - /// Returns `ThreadSet` that the given accounts can be scheduled on. 
- fn accounts_schedulable_threads<'a>( - &self, - write_account_locks: impl Iterator, - read_account_locks: impl Iterator, - ) -> Option { - let mut schedulable_threads = ThreadSet::any(self.num_threads); - - for account in write_account_locks { - schedulable_threads &= self.write_schedulable_threads(account); - if schedulable_threads.is_empty() { - return None; - } - } - - for account in read_account_locks { - schedulable_threads &= self.read_schedulable_threads(account); - if schedulable_threads.is_empty() { - return None; - } - } - - Some(schedulable_threads) - } - - /// Returns `ThreadSet` of schedulable threads for the given readable account. - fn read_schedulable_threads(&self, account: &Pubkey) -> ThreadSet { - self.schedulable_threads::(account) - } - - /// Returns `ThreadSet` of schedulable threads for the given writable account. - fn write_schedulable_threads(&self, account: &Pubkey) -> ThreadSet { - self.schedulable_threads::(account) - } - - /// Returns `ThreadSet` of schedulable threads. - /// If there are no locks, then all threads are schedulable. - /// If only write-locked, then only the thread holding the write lock is schedulable. - /// If a mix of locks, then only the write thread is schedulable. - /// If only read-locked, the only write-schedulable thread is if a single thread - /// holds all read locks. Otherwise, no threads are write-schedulable. - /// If only read-locked, all threads are read-schedulable. - fn schedulable_threads(&self, account: &Pubkey) -> ThreadSet { - match self.locks.get(account) { - None => ThreadSet::any(self.num_threads), - Some(AccountLocks { - write_locks: None, - read_locks: Some(read_locks), - }) => { - if WRITE { - read_locks - .thread_set - .only_one_contained() - .map(ThreadSet::only) - .unwrap_or_else(ThreadSet::none) - } else { - ThreadSet::any(self.num_threads) - } - } - Some(AccountLocks { - write_locks: Some(write_locks), - read_locks: None, - }) => ThreadSet::only(write_locks.thread_id), - Some(AccountLocks { - write_locks: Some(write_locks), - read_locks: Some(read_locks), - }) => { - assert_eq!( - read_locks.thread_set.only_one_contained(), - Some(write_locks.thread_id) - ); - read_locks.thread_set - } - Some(AccountLocks { - write_locks: None, - read_locks: None, - }) => unreachable!(), - } - } - - /// Add locks for all writable and readable accounts on `thread_id`. - fn lock_accounts<'a>( - &mut self, - write_account_locks: impl Iterator, - read_account_locks: impl Iterator, - thread_id: ThreadId, - ) { - assert!( - thread_id < self.num_threads, - "thread_id must be < num_threads" - ); - for account in write_account_locks { - self.write_lock_account(account, thread_id); - } - - for account in read_account_locks { - self.read_lock_account(account, thread_id); - } - } - - /// Locks the given `account` for writing on `thread_id`. - /// Panics if the account is already locked for writing on another thread. 
- fn write_lock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { - let entry = self.locks.entry(*account).or_default(); - - let AccountLocks { - write_locks, - read_locks, - } = entry; - - if let Some(read_locks) = read_locks { - assert_eq!( - read_locks.thread_set.only_one_contained(), - Some(thread_id), - "outstanding read lock must be on same thread" - ); - } - - if let Some(write_locks) = write_locks { - assert_eq!( - write_locks.thread_id, thread_id, - "outstanding write lock must be on same thread" - ); - write_locks.lock_count += 1; - } else { - *write_locks = Some(AccountWriteLocks { - thread_id, - lock_count: 1, - }); - } - } - - /// Unlocks the given `account` for writing on `thread_id`. - /// Panics if the account is not locked for writing on `thread_id`. - fn write_unlock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { - let Entry::Occupied(mut entry) = self.locks.entry(*account) else { - panic!("write lock must exist for account: {account}"); - }; - - let AccountLocks { - write_locks: maybe_write_locks, - read_locks, - } = entry.get_mut(); - - let Some(write_locks) = maybe_write_locks else { - panic!("write lock must exist for account: {account}"); - }; - - assert_eq!( - write_locks.thread_id, thread_id, - "outstanding write lock must be on same thread" - ); - - write_locks.lock_count -= 1; - if write_locks.lock_count == 0 { - *maybe_write_locks = None; - if read_locks.is_none() { - entry.remove(); - } - } - } - - /// Locks the given `account` for reading on `thread_id`. - /// Panics if the account is already locked for writing on another thread. - fn read_lock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { - let AccountLocks { - write_locks, - read_locks, - } = self.locks.entry(*account).or_default(); - - if let Some(write_locks) = write_locks { - assert_eq!( - write_locks.thread_id, thread_id, - "outstanding write lock must be on same thread" - ); - } - - match read_locks { - Some(read_locks) => { - read_locks.thread_set.insert(thread_id); - read_locks.lock_counts[thread_id] += 1; - } - None => { - let mut lock_counts = [0; MAX_THREADS]; - lock_counts[thread_id] = 1; - *read_locks = Some(AccountReadLocks { - thread_set: ThreadSet::only(thread_id), - lock_counts, - }); - } - } - } - - /// Unlocks the given `account` for reading on `thread_id`. - /// Panics if the account is not locked for reading on `thread_id`. 
- fn read_unlock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { - let Entry::Occupied(mut entry) = self.locks.entry(*account) else { - panic!("read lock must exist for account: {account}"); - }; - - let AccountLocks { - write_locks, - read_locks: maybe_read_locks, - } = entry.get_mut(); - - let Some(read_locks) = maybe_read_locks else { - panic!("read lock must exist for account: {account}"); - }; - - assert!( - read_locks.thread_set.contains(thread_id), - "outstanding read lock must be on same thread" - ); - - read_locks.lock_counts[thread_id] -= 1; - if read_locks.lock_counts[thread_id] == 0 { - read_locks.thread_set.remove(thread_id); - if read_locks.thread_set.is_empty() { - *maybe_read_locks = None; - if write_locks.is_none() { - entry.remove(); - } - } - } - } -} - -impl BitAnd for ThreadSet { - type Output = Self; - - fn bitand(self, rhs: Self) -> Self::Output { - Self(self.0 & rhs.0) - } -} - -impl BitAndAssign for ThreadSet { - fn bitand_assign(&mut self, rhs: Self) { - self.0 &= rhs.0; - } -} - -impl Sub for ThreadSet { - type Output = Self; - - fn sub(self, rhs: Self) -> Self::Output { - Self(self.0 & !rhs.0) - } -} - -impl Display for ThreadSet { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "ThreadSet({:#0width$b})", self.0, width = MAX_THREADS) - } -} - -impl Debug for ThreadSet { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Display::fmt(self, f) - } -} - -impl ThreadSet { - #[inline(always)] - pub const fn none() -> Self { - Self(0b0) - } - - #[inline(always)] - pub const fn any(num_threads: usize) -> Self { - if num_threads == MAX_THREADS { - Self(u64::MAX) - } else { - Self(Self::as_flag(num_threads) - 1) - } - } - - #[inline(always)] - pub const fn only(thread_id: ThreadId) -> Self { - Self(Self::as_flag(thread_id)) - } - - #[inline(always)] - pub fn num_threads(&self) -> u32 { - self.0.count_ones() - } - - #[inline(always)] - pub fn only_one_contained(&self) -> Option { - (self.num_threads() == 1).then_some(self.0.trailing_zeros() as ThreadId) - } - - #[inline(always)] - pub fn is_empty(&self) -> bool { - self == &Self::none() - } - - #[inline(always)] - pub fn contains(&self, thread_id: ThreadId) -> bool { - self.0 & Self::as_flag(thread_id) != 0 - } - - #[inline(always)] - pub fn insert(&mut self, thread_id: ThreadId) { - self.0 |= Self::as_flag(thread_id); - } - - #[inline(always)] - pub fn remove(&mut self, thread_id: ThreadId) { - self.0 &= !Self::as_flag(thread_id); - } - - #[inline(always)] - pub fn contained_threads_iter(self) -> impl Iterator { - (0..MAX_THREADS).filter(move |thread_id| self.contains(*thread_id)) - } - - #[inline(always)] - const fn as_flag(thread_id: ThreadId) -> u64 { - 0b1 << thread_id - } -} - -#[cfg(test)] -mod tests { - use super::*; - - const TEST_NUM_THREADS: usize = 4; - const TEST_ANY_THREADS: ThreadSet = ThreadSet::any(TEST_NUM_THREADS); - - // Simple thread selector to select the first schedulable thread - fn test_thread_selector(thread_set: ThreadSet) -> ThreadId { - thread_set.contained_threads_iter().next().unwrap() - } - - #[test] - #[should_panic(expected = "num threads must be > 0")] - fn test_too_few_num_threads() { - ThreadAwareAccountLocks::new(0); - } - - #[test] - #[should_panic(expected = "num threads must be <=")] - fn test_too_many_num_threads() { - ThreadAwareAccountLocks::new(MAX_THREADS + 1); - } - - #[test] - fn test_try_lock_accounts_none() { - let pk1 = Pubkey::new_unique(); - let pk2 = Pubkey::new_unique(); - let mut locks = 
ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.read_lock_account(&pk1, 2); - locks.read_lock_account(&pk1, 3); - assert_eq!( - locks.try_lock_accounts( - [&pk1].into_iter(), - [&pk2].into_iter(), - TEST_ANY_THREADS, - test_thread_selector - ), - None - ); - } - - #[test] - fn test_try_lock_accounts_one() { - let pk1 = Pubkey::new_unique(); - let pk2 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.write_lock_account(&pk2, 3); - - assert_eq!( - locks.try_lock_accounts( - [&pk1].into_iter(), - [&pk2].into_iter(), - TEST_ANY_THREADS, - test_thread_selector - ), - Some(3) - ); - } - - #[test] - fn test_try_lock_accounts_multiple() { - let pk1 = Pubkey::new_unique(); - let pk2 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.read_lock_account(&pk2, 0); - locks.read_lock_account(&pk2, 0); - - assert_eq!( - locks.try_lock_accounts( - [&pk1].into_iter(), - [&pk2].into_iter(), - TEST_ANY_THREADS - ThreadSet::only(0), // exclude 0 - test_thread_selector - ), - Some(1) - ); - } - - #[test] - fn test_try_lock_accounts_any() { - let pk1 = Pubkey::new_unique(); - let pk2 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - assert_eq!( - locks.try_lock_accounts( - [&pk1].into_iter(), - [&pk2].into_iter(), - TEST_ANY_THREADS, - test_thread_selector - ), - Some(0) - ); - } - - #[test] - fn test_accounts_schedulable_threads_no_outstanding_locks() { - let pk1 = Pubkey::new_unique(); - let locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - - assert_eq!( - locks.accounts_schedulable_threads([&pk1].into_iter(), std::iter::empty()), - Some(TEST_ANY_THREADS) - ); - assert_eq!( - locks.accounts_schedulable_threads(std::iter::empty(), [&pk1].into_iter()), - Some(TEST_ANY_THREADS) - ); - } - - #[test] - fn test_accounts_schedulable_threads_outstanding_write_only() { - let pk1 = Pubkey::new_unique(); - let pk2 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - - locks.write_lock_account(&pk1, 2); - assert_eq!( - locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()), - Some(ThreadSet::only(2)) - ); - assert_eq!( - locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()), - Some(ThreadSet::only(2)) - ); - } - - #[test] - fn test_accounts_schedulable_threads_outstanding_read_only() { - let pk1 = Pubkey::new_unique(); - let pk2 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - - locks.read_lock_account(&pk1, 2); - assert_eq!( - locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()), - Some(ThreadSet::only(2)) - ); - assert_eq!( - locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()), - Some(TEST_ANY_THREADS) - ); - - locks.read_lock_account(&pk1, 0); - assert_eq!( - locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()), - None - ); - assert_eq!( - locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()), - Some(TEST_ANY_THREADS) - ); - } - - #[test] - fn test_accounts_schedulable_threads_outstanding_mixed() { - let pk1 = Pubkey::new_unique(); - let pk2 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - - locks.read_lock_account(&pk1, 2); - locks.write_lock_account(&pk1, 2); - assert_eq!( - locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()), - Some(ThreadSet::only(2)) 
- ); - assert_eq!( - locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()), - Some(ThreadSet::only(2)) - ); - } - - #[test] - #[should_panic(expected = "outstanding write lock must be on same thread")] - fn test_write_lock_account_write_conflict_panic() { - let pk1 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.write_lock_account(&pk1, 0); - locks.write_lock_account(&pk1, 1); - } - - #[test] - #[should_panic(expected = "outstanding read lock must be on same thread")] - fn test_write_lock_account_read_conflict_panic() { - let pk1 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.read_lock_account(&pk1, 0); - locks.write_lock_account(&pk1, 1); - } - - #[test] - #[should_panic(expected = "write lock must exist")] - fn test_write_unlock_account_not_locked() { - let pk1 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.write_unlock_account(&pk1, 0); - } - - #[test] - #[should_panic(expected = "outstanding write lock must be on same thread")] - fn test_write_unlock_account_thread_mismatch() { - let pk1 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.write_lock_account(&pk1, 1); - locks.write_unlock_account(&pk1, 0); - } - - #[test] - #[should_panic(expected = "outstanding write lock must be on same thread")] - fn test_read_lock_account_write_conflict_panic() { - let pk1 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.write_lock_account(&pk1, 0); - locks.read_lock_account(&pk1, 1); - } - - #[test] - #[should_panic(expected = "read lock must exist")] - fn test_read_unlock_account_not_locked() { - let pk1 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.read_unlock_account(&pk1, 1); - } - - #[test] - #[should_panic(expected = "outstanding read lock must be on same thread")] - fn test_read_unlock_account_thread_mismatch() { - let pk1 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.read_lock_account(&pk1, 0); - locks.read_unlock_account(&pk1, 1); - } - - #[test] - fn test_write_locking() { - let pk1 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.write_lock_account(&pk1, 1); - locks.write_lock_account(&pk1, 1); - locks.write_unlock_account(&pk1, 1); - locks.write_unlock_account(&pk1, 1); - assert!(locks.locks.is_empty()); - } - - #[test] - fn test_read_locking() { - let pk1 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.read_lock_account(&pk1, 1); - locks.read_lock_account(&pk1, 1); - locks.read_unlock_account(&pk1, 1); - locks.read_unlock_account(&pk1, 1); - assert!(locks.locks.is_empty()); - } - - #[test] - #[should_panic(expected = "thread_id must be < num_threads")] - fn test_lock_accounts_invalid_thread() { - let pk1 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.lock_accounts([&pk1].into_iter(), std::iter::empty(), TEST_NUM_THREADS); - } - - #[test] - fn test_thread_set() { - let mut thread_set = ThreadSet::none(); - assert!(thread_set.is_empty()); - assert_eq!(thread_set.num_threads(), 0); - assert_eq!(thread_set.only_one_contained(), None); - for idx in 0..MAX_THREADS { - assert!(!thread_set.contains(idx)); - } - - thread_set.insert(4); - assert!(!thread_set.is_empty()); - 
assert_eq!(thread_set.num_threads(), 1); - assert_eq!(thread_set.only_one_contained(), Some(4)); - for idx in 0..MAX_THREADS { - assert_eq!(thread_set.contains(idx), idx == 4); - } - - thread_set.insert(2); - assert!(!thread_set.is_empty()); - assert_eq!(thread_set.num_threads(), 2); - assert_eq!(thread_set.only_one_contained(), None); - for idx in 0..MAX_THREADS { - assert_eq!(thread_set.contains(idx), idx == 2 || idx == 4); - } - - thread_set.remove(4); - assert!(!thread_set.is_empty()); - assert_eq!(thread_set.num_threads(), 1); - assert_eq!(thread_set.only_one_contained(), Some(2)); - for idx in 0..MAX_THREADS { - assert_eq!(thread_set.contains(idx), idx == 2); - } - } - - #[test] - fn test_thread_set_any_zero() { - let any_threads = ThreadSet::any(0); - assert_eq!(any_threads.num_threads(), 0); - } - - #[test] - fn test_thread_set_any_max() { - let any_threads = ThreadSet::any(MAX_THREADS); - assert_eq!(any_threads.num_threads(), MAX_THREADS as u32); - } -} diff --git a/prio-graph-scheduler/src/transaction_priority_id.rs b/prio-graph-scheduler/src/transaction_priority_id.rs deleted file mode 100644 index 05720c8..0000000 --- a/prio-graph-scheduler/src/transaction_priority_id.rs +++ /dev/null @@ -1,69 +0,0 @@ -use { - crate::scheduler_messages::TransactionId, - prio_graph::TopLevelId, - std::hash::{Hash, Hasher}, -}; - -/// A unique identifier tied with priority ordering for a transaction/packet: -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct TransactionPriorityId { - pub priority: u64, - pub id: TransactionId, -} - -impl TransactionPriorityId { - pub fn new(priority: u64, id: TransactionId) -> Self { - Self { priority, id } - } -} - -impl Hash for TransactionPriorityId { - fn hash(&self, state: &mut H) { - self.id.hash(state) - } -} - -impl TopLevelId for TransactionPriorityId { - fn id(&self) -> Self { - *self - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_transaction_priority_id_ordering() { - // Higher priority first - { - let id1 = TransactionPriorityId::new(1, TransactionId::new(1)); - let id2 = TransactionPriorityId::new(2, TransactionId::new(1)); - assert!(id1 < id2); - assert!(id1 <= id2); - assert!(id2 > id1); - assert!(id2 >= id1); - } - - // Equal priority then compare by id - { - let id1 = TransactionPriorityId::new(1, TransactionId::new(1)); - let id2 = TransactionPriorityId::new(1, TransactionId::new(2)); - assert!(id1 < id2); - assert!(id1 <= id2); - assert!(id2 > id1); - assert!(id2 >= id1); - } - - // Equal priority and id - { - let id1 = TransactionPriorityId::new(1, TransactionId::new(1)); - let id2 = TransactionPriorityId::new(1, TransactionId::new(1)); - assert_eq!(id1, id2); - assert!(id1 >= id2); - assert!(id1 <= id2); - assert!(id2 >= id1); - assert!(id2 <= id1); - } - } -} diff --git a/prio-graph-scheduler/src/transaction_state.rs b/prio-graph-scheduler/src/transaction_state.rs deleted file mode 100644 index a4985c3..0000000 --- a/prio-graph-scheduler/src/transaction_state.rs +++ /dev/null @@ -1,384 +0,0 @@ -use { - crate::deserializable_packet::DeserializableTxPacket, crate::scheduler_messages::MaxAge, - solana_sdk::transaction::SanitizedTransaction, std::sync::Arc, -}; - -/// Simple wrapper type to tie a sanitized transaction to max age slot. -pub struct SanitizedTransactionTTL { - pub transaction: SanitizedTransaction, - pub max_age: MaxAge, -} - -/// TransactionState is used to track the state of a transaction in the transaction scheduler -/// and banking stage as a whole. 
-/// -/// There are two states a transaction can be in: -/// 1. `Unprocessed` - The transaction is available for scheduling. -/// 2. `Pending` - The transaction is currently scheduled or being processed. -/// -/// Newly received transactions are initially in the `Unprocessed` state. -/// When a transaction is scheduled, it is transitioned to the `Pending` state, -/// using the `transition_to_pending` method. -/// When a transaction finishes processing it may be retryable. If it is retryable, -/// the transaction is transitioned back to the `Unprocessed` state using the -/// `transition_to_unprocessed` method. If it is not retryable, the state should -/// be dropped. -/// -/// For performance, when a transaction is transitioned to the `Pending` state, the -/// internal `SanitizedTransaction` is moved out of the `TransactionState` and sent -/// to the appropriate thread for processing. This is done to avoid cloning the -/// `SanitizedTransaction`. -#[allow(clippy::large_enum_variant)] -pub enum TransactionState { - /// The transaction is available for scheduling. - Unprocessed { - transaction_ttl: SanitizedTransactionTTL, - packet: Arc
<P>
, - priority: u64, - cost: u64, - should_forward: bool, - }, - /// The transaction is currently scheduled or being processed. - Pending { - packet: Arc
<P>
, - priority: u64, - cost: u64, - should_forward: bool, - }, - /// Only used during transition. - Transitioning, -} - -impl<P: DeserializableTxPacket> TransactionState<P>
{ - /// Creates a new `TransactionState` in the `Unprocessed` state. - pub fn new( - transaction_ttl: SanitizedTransactionTTL, - packet: Arc
<P>
, - priority: u64, - cost: u64, - ) -> Self { - let should_forward = !packet.original_packet().meta().forwarded() - && packet.original_packet().meta().is_from_staked_node(); - Self::Unprocessed { - transaction_ttl, - packet, - priority, - cost, - should_forward, - } - } - - /// Return the priority of the transaction. - /// This is *not* the same as the `compute_unit_price` of the transaction. - /// The priority is used to order transactions for processing. - pub fn priority(&self) -> u64 { - match self { - Self::Unprocessed { priority, .. } => *priority, - Self::Pending { priority, .. } => *priority, - Self::Transitioning => unreachable!(), - } - } - - /// Return the cost of the transaction. - pub fn cost(&self) -> u64 { - match self { - Self::Unprocessed { cost, .. } => *cost, - Self::Pending { cost, .. } => *cost, - Self::Transitioning => unreachable!(), - } - } - - /// Return whether packet should be attempted to be forwarded. - pub fn should_forward(&self) -> bool { - match self { - Self::Unprocessed { - should_forward: forwarded, - .. - } => *forwarded, - Self::Pending { - should_forward: forwarded, - .. - } => *forwarded, - Self::Transitioning => unreachable!(), - } - } - - /// Mark the packet as forwarded. - /// This is used to prevent the packet from being forwarded multiple times. - pub fn mark_forwarded(&mut self) { - match self { - Self::Unprocessed { should_forward, .. } => *should_forward = false, - Self::Pending { should_forward, .. } => *should_forward = false, - Self::Transitioning => unreachable!(), - } - } - - /// Return the packet of the transaction. - pub fn packet(&self) -> &Arc
<P>
{ - match self { - Self::Unprocessed { packet, .. } => packet, - Self::Pending { packet, .. } => packet, - Self::Transitioning => unreachable!(), - } - } - - /// Intended to be called when a transaction is scheduled. This method will - /// transition the transaction from `Unprocessed` to `Pending` and return the - /// `SanitizedTransactionTTL` for processing. - /// - /// # Panics - /// This method will panic if the transaction is already in the `Pending` state, - /// as this is an invalid state transition. - pub fn transition_to_pending(&mut self) -> SanitizedTransactionTTL { - match self.take() { - TransactionState::Unprocessed { - transaction_ttl, - packet, - priority, - cost, - should_forward: forwarded, - } => { - *self = TransactionState::Pending { - packet, - priority, - cost, - should_forward: forwarded, - }; - transaction_ttl - } - TransactionState::Pending { .. } => { - panic!("transaction already pending"); - } - Self::Transitioning => unreachable!(), - } - } - - /// Intended to be called when a transaction is retried. This method will - /// transition the transaction from `Pending` to `Unprocessed`. - /// - /// # Panics - /// This method will panic if the transaction is already in the `Unprocessed` - /// state, as this is an invalid state transition. - pub fn transition_to_unprocessed(&mut self, transaction_ttl: SanitizedTransactionTTL) { - match self.take() { - TransactionState::Unprocessed { .. } => panic!("already unprocessed"), - TransactionState::Pending { - packet, - priority, - cost, - should_forward: forwarded, - } => { - *self = Self::Unprocessed { - transaction_ttl, - packet, - priority, - cost, - should_forward: forwarded, - } - } - Self::Transitioning => unreachable!(), - } - } - - /// Get a reference to the `SanitizedTransactionTTL` for the transaction. - /// - /// # Panics - /// This method will panic if the transaction is in the `Pending` state. - pub fn transaction_ttl(&self) -> &SanitizedTransactionTTL { - match self { - Self::Unprocessed { - transaction_ttl, .. - } => transaction_ttl, - Self::Pending { .. } => panic!("transaction is pending"), - Self::Transitioning => unreachable!(), - } - } - - /// Internal helper to transitioning between states. - /// Replaces `self` with a dummy state that will immediately be overwritten in transition. 
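For readers unfamiliar with the idiom, the `take` helper below is an instance of the replace-with-placeholder pattern; a minimal self-contained sketch (the `State` enum here is invented for illustration):

    use core::mem;

    // Swap a cheap placeholder variant into `&mut self` so the current
    // value can be moved out; the caller immediately writes a real state
    // back, so `Transitioning` is never observed outside the transition.
    enum State {
        Ready(String),
        Transitioning,
    }

    impl State {
        fn take(&mut self) -> State {
            mem::replace(self, State::Transitioning)
        }
    }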
- fn take(&mut self) -> Self { - core::mem::replace(self, Self::Transitioning) - } -} - -#[cfg(test)] -mod tests { - use { - super::*, - crate::tests::MockImmutableDeserializedPacket, - solana_sdk::{ - clock::Slot, compute_budget::ComputeBudgetInstruction, hash::Hash, message::Message, - packet::Packet, signature::Keypair, signer::Signer, system_instruction, - transaction::Transaction, - }, - }; - - fn create_transaction_state( - compute_unit_price: u64, - ) -> TransactionState { - let from_keypair = Keypair::new(); - let ixs = vec![ - system_instruction::transfer( - &from_keypair.pubkey(), - &solana_sdk::pubkey::new_rand(), - 1, - ), - ComputeBudgetInstruction::set_compute_unit_price(compute_unit_price), - ]; - let message = Message::new(&ixs, Some(&from_keypair.pubkey())); - let tx = Transaction::new(&[&from_keypair], message, Hash::default()); - - let packet = Arc::new( - MockImmutableDeserializedPacket::new(Packet::from_data(None, tx.clone()).unwrap()) - .unwrap(), - ); - let transaction_ttl = SanitizedTransactionTTL { - transaction: SanitizedTransaction::from_transaction_for_tests(tx), - max_age: MaxAge { - epoch_invalidation_slot: Slot::MAX, - alt_invalidation_slot: Slot::MAX, - }, - }; - const TEST_TRANSACTION_COST: u64 = 5000; - TransactionState::new( - transaction_ttl, - packet, - compute_unit_price, - TEST_TRANSACTION_COST, - ) - } - - #[test] - #[should_panic(expected = "already pending")] - fn test_transition_to_pending_panic() { - let mut transaction_state = create_transaction_state(0); - transaction_state.transition_to_pending(); - transaction_state.transition_to_pending(); // invalid transition - } - - #[test] - fn test_transition_to_pending() { - let mut transaction_state = create_transaction_state(0); - assert!(matches!( - transaction_state, - TransactionState::Unprocessed { .. } - )); - let _ = transaction_state.transition_to_pending(); - assert!(matches!( - transaction_state, - TransactionState::Pending { .. } - )); - } - - #[test] - #[should_panic(expected = "already unprocessed")] - fn test_transition_to_unprocessed_panic() { - let mut transaction_state = create_transaction_state(0); - - // Manually clone `SanitizedTransactionTTL` - let SanitizedTransactionTTL { - transaction, - max_age, - } = transaction_state.transaction_ttl(); - let transaction_ttl = SanitizedTransactionTTL { - transaction: transaction.clone(), - max_age: *max_age, - }; - transaction_state.transition_to_unprocessed(transaction_ttl); // invalid transition - } - - #[test] - fn test_transition_to_unprocessed() { - let mut transaction_state = create_transaction_state(0); - assert!(matches!( - transaction_state, - TransactionState::Unprocessed { .. } - )); - let transaction_ttl = transaction_state.transition_to_pending(); - assert!(matches!( - transaction_state, - TransactionState::Pending { .. } - )); - transaction_state.transition_to_unprocessed(transaction_ttl); - assert!(matches!( - transaction_state, - TransactionState::Unprocessed { .. 
} - )); - } - - #[test] - fn test_priority() { - let priority = 15; - let mut transaction_state = create_transaction_state(priority); - assert_eq!(transaction_state.priority(), priority); - - // ensure compute unit price is not lost through state transitions - let transaction_ttl = transaction_state.transition_to_pending(); - assert_eq!(transaction_state.priority(), priority); - transaction_state.transition_to_unprocessed(transaction_ttl); - assert_eq!(transaction_state.priority(), priority); - } - - #[test] - #[should_panic(expected = "transaction is pending")] - fn test_transaction_ttl_panic() { - let mut transaction_state = create_transaction_state(0); - let transaction_ttl = transaction_state.transaction_ttl(); - assert!(matches!( - transaction_state, - TransactionState::Unprocessed { .. } - )); - assert_eq!( - transaction_ttl.max_age, - MaxAge { - epoch_invalidation_slot: Slot::MAX, - alt_invalidation_slot: Slot::MAX, - } - ); - - let _ = transaction_state.transition_to_pending(); - assert!(matches!( - transaction_state, - TransactionState::Pending { .. } - )); - let _ = transaction_state.transaction_ttl(); // pending state, the transaction ttl is not available - } - - #[test] - fn test_transaction_ttl() { - let mut transaction_state = create_transaction_state(0); - let transaction_ttl = transaction_state.transaction_ttl(); - assert!(matches!( - transaction_state, - TransactionState::Unprocessed { .. } - )); - assert_eq!( - transaction_ttl.max_age, - MaxAge { - epoch_invalidation_slot: Slot::MAX, - alt_invalidation_slot: Slot::MAX, - } - ); - - // ensure transaction_ttl is not lost through state transitions - let transaction_ttl = transaction_state.transition_to_pending(); - assert!(matches!( - transaction_state, - TransactionState::Pending { .. } - )); - - transaction_state.transition_to_unprocessed(transaction_ttl); - let transaction_ttl = transaction_state.transaction_ttl(); - assert!(matches!( - transaction_state, - TransactionState::Unprocessed { .. } - )); - assert_eq!( - transaction_ttl.max_age, - MaxAge { - epoch_invalidation_slot: Slot::MAX, - alt_invalidation_slot: Slot::MAX, - } - ); - } -} diff --git a/prio-graph-scheduler/src/transaction_state_container.rs b/prio-graph-scheduler/src/transaction_state_container.rs deleted file mode 100644 index 2f11629..0000000 --- a/prio-graph-scheduler/src/transaction_state_container.rs +++ /dev/null @@ -1,263 +0,0 @@ -use { - super::{ - transaction_priority_id::TransactionPriorityId, - transaction_state::{SanitizedTransactionTTL, TransactionState}, - }, - crate::{deserializable_packet::DeserializableTxPacket, scheduler_messages::TransactionId}, - itertools::MinMaxResult, - min_max_heap::MinMaxHeap, - std::{collections::HashMap, sync::Arc}, -}; - -/// This structure will hold `TransactionState` for the entirety of a -/// transaction's lifetime in the scheduler and BankingStage as a whole. -/// -/// Transaction Lifetime: -/// 1. Received from `SigVerify` by `BankingStage` -/// 2. Inserted into `TransactionStateContainer` by `BankingStage` -/// 3. Popped in priority-order by scheduler, and transitioned to `Pending` state -/// 4. Processed by `ConsumeWorker` -/// a. If consumed, remove `Pending` state from the `TransactionStateContainer` -/// b. If retryable, transition back to `Unprocessed` state. -/// Re-insert to the queue, and return to step 3. -/// -/// The structure is composed of two main components: -/// 1. 
A priority queue of wrapped `TransactionId`s, which are used to -/// order transactions by priority for selection by the scheduler. -/// 2. A map of `TransactionId` to `TransactionState`, which is used to -/// track the state of each transaction. -/// -/// When `Pending`, the associated `TransactionId` is not in the queue, but -/// is still in the map. -/// The entry in the map should exist before insertion into the queue, and -/// be removed only after the id is removed from the queue. -/// -/// The container maintains a fixed capacity. If the queue is full when pushing -/// a new transaction, the lowest priority transaction will be dropped. -pub struct TransactionStateContainer<P: DeserializableTxPacket> { - priority_queue: MinMaxHeap<TransactionPriorityId>, - id_to_transaction_state: HashMap<TransactionId, TransactionState<P>>, -} - -impl<P: DeserializableTxPacket> TransactionStateContainer<P>
{ - pub fn with_capacity(capacity: usize) -> Self { - Self { - priority_queue: MinMaxHeap::with_capacity(capacity), - id_to_transaction_state: HashMap::with_capacity(capacity), - } - } - - /// Returns true if the queue is empty. - pub fn is_empty(&self) -> bool { - self.priority_queue.is_empty() - } - - /// Returns the remaining capacity of the queue - pub fn remaining_queue_capacity(&self) -> usize { - self.priority_queue.capacity() - self.priority_queue.len() - } - - /// Get the top transaction id in the priority queue. - pub fn pop(&mut self) -> Option<TransactionPriorityId> { - self.priority_queue.pop_max() - } - - /// Get mutable transaction state by id. - pub fn get_mut_transaction_state( - &mut self, - id: &TransactionId, - ) -> Option<&mut TransactionState<P>
> { - self.id_to_transaction_state.get_mut(id) - } - - /// Get reference to `SanitizedTransactionTTL` by id. - /// Panics if the transaction does not exist. - pub fn get_transaction_ttl(&self, id: &TransactionId) -> Option<&SanitizedTransactionTTL> { - self.id_to_transaction_state - .get(id) - .map(|state| state.transaction_ttl()) - } - - /// Insert a new transaction into the container's queues and maps. - /// Returns `true` if a packet was dropped due to capacity limits. - pub fn insert_new_transaction( - &mut self, - transaction_id: TransactionId, - transaction_ttl: SanitizedTransactionTTL, - packet: Arc
<P>
, - priority: u64, - cost: u64, - ) -> bool { - let priority_id = TransactionPriorityId::new(priority, transaction_id); - self.id_to_transaction_state.insert( - transaction_id, - TransactionState::new(transaction_ttl, packet, priority, cost), - ); - self.push_id_into_queue(priority_id) - } - - /// Retries a transaction - inserts transaction back into map (but not packet). - /// This transitions the transaction to `Unprocessed` state. - pub fn retry_transaction( - &mut self, - transaction_id: TransactionId, - transaction_ttl: SanitizedTransactionTTL, - ) { - let transaction_state = self - .get_mut_transaction_state(&transaction_id) - .expect("transaction must exist"); - let priority_id = TransactionPriorityId::new(transaction_state.priority(), transaction_id); - transaction_state.transition_to_unprocessed(transaction_ttl); - self.push_id_into_queue(priority_id); - } - - /// Pushes a transaction id into the priority queue. If the queue is full, the lowest priority - /// transaction will be dropped (removed from the queue and map). - /// Returns `true` if a packet was dropped due to capacity limits. - pub fn push_id_into_queue(&mut self, priority_id: TransactionPriorityId) -> bool { - if self.remaining_queue_capacity() == 0 { - let popped_id = self.priority_queue.push_pop_min(priority_id); - self.remove_by_id(&popped_id.id); - true - } else { - self.priority_queue.push(priority_id); - false - } - } - - /// Remove transaction by id. - pub fn remove_by_id(&mut self, id: &TransactionId) { - self.id_to_transaction_state - .remove(id) - .expect("transaction must exist"); - } - - pub fn get_min_max_priority(&self) -> MinMaxResult { - match self.priority_queue.peek_min() { - Some(min) => match self.priority_queue.peek_max() { - Some(max) => MinMaxResult::MinMax(min.priority, max.priority), - None => MinMaxResult::OneElement(min.priority), - }, - None => MinMaxResult::NoElements, - } - } -} - -#[cfg(test)] -mod tests { - use { - super::*, - crate::scheduler_messages::MaxAge, - crate::tests::MockImmutableDeserializedPacket, - solana_sdk::{ - compute_budget::ComputeBudgetInstruction, - hash::Hash, - message::Message, - packet::Packet, - signature::Keypair, - signer::Signer, - slot_history::Slot, - system_instruction, - transaction::{SanitizedTransaction, Transaction}, - }, - }; - - /// Returns (transaction_ttl, priority, cost) - fn test_transaction( - priority: u64, - ) -> ( - SanitizedTransactionTTL, - Arc, - u64, - u64, - ) { - let from_keypair = Keypair::new(); - let ixs = vec![ - system_instruction::transfer( - &from_keypair.pubkey(), - &solana_sdk::pubkey::new_rand(), - 1, - ), - ComputeBudgetInstruction::set_compute_unit_price(priority), - ]; - let message = Message::new(&ixs, Some(&from_keypair.pubkey())); - let tx = SanitizedTransaction::from_transaction_for_tests(Transaction::new( - &[&from_keypair], - message, - Hash::default(), - )); - let packet = Arc::new( - MockImmutableDeserializedPacket::new( - Packet::from_data(None, tx.to_versioned_transaction()).unwrap(), - ) - .unwrap(), - ); - let transaction_ttl = SanitizedTransactionTTL { - transaction: tx, - max_age: MaxAge { - epoch_invalidation_slot: Slot::MAX, - alt_invalidation_slot: Slot::MAX, - }, - }; - const TEST_TRANSACTION_COST: u64 = 5000; - (transaction_ttl, packet, priority, TEST_TRANSACTION_COST) - } - - fn push_to_container( - container: &mut TransactionStateContainer, - num: usize, - ) { - for id in 0..num as u64 { - let priority = id; - let (transaction_ttl, packet, priority, cost) = test_transaction(priority); - 
container.insert_new_transaction( - TransactionId::new(id), - transaction_ttl, - packet, - priority, - cost, - ); - } - } - - #[test] - fn test_is_empty() { - let mut container = TransactionStateContainer::with_capacity(1); - assert!(container.is_empty()); - - push_to_container(&mut container, 1); - assert!(!container.is_empty()); - } - - #[test] - fn test_priority_queue_capacity() { - let mut container = TransactionStateContainer::with_capacity(1); - push_to_container(&mut container, 5); - - assert_eq!(container.priority_queue.len(), 1); - assert_eq!(container.id_to_transaction_state.len(), 1); - assert_eq!( - container - .id_to_transaction_state - .iter() - .map(|ts| ts.1.priority()) - .next() - .unwrap(), - 4 - ); - } - - #[test] - fn test_get_mut_transaction_state() { - let mut container = TransactionStateContainer::with_capacity(5); - push_to_container(&mut container, 5); - - let existing_id = TransactionId::new(3); - let non_existing_id = TransactionId::new(7); - assert!(container.get_mut_transaction_state(&existing_id).is_some()); - assert!(container.get_mut_transaction_state(&existing_id).is_some()); - assert!(container - .get_mut_transaction_state(&non_existing_id) - .is_none()); - } -} diff --git a/rpc/src/lib.rs b/rpc/src/lib.rs index e80da9f..bf0051e 100644 --- a/rpc/src/lib.rs +++ b/rpc/src/lib.rs @@ -1,4 +1,5 @@ #![allow(dead_code)] +#![allow(clippy::result_large_err)] #[macro_use] extern crate log; diff --git a/scheduler/bin/scheduling_simulation.rs b/scheduler/bin/scheduling_simulation.rs index 4ba71ac..12cfacb 100644 --- a/scheduler/bin/scheduling_simulation.rs +++ b/scheduler/bin/scheduling_simulation.rs @@ -1,11 +1,9 @@ use crossbeam_channel::{unbounded, Receiver, Sender}; use igloo_executor::processor::TransactionProcessor; use igloo_scheduler::id_generator::IdGenerator; -use igloo_scheduler::impls::no_lock_scheduler::NoLockScheduler; +use igloo_scheduler::impls::prio_graph_scheduler::PrioGraphSchedulerWrapper; use igloo_scheduler::scheduler::Scheduler; -use igloo_scheduler::scheduler_messages::{ - SchedulingBatch, SchedulingBatchResult, -}; +use igloo_scheduler::scheduler_messages::{MaxAge, SchedulingBatch, SchedulingBatchResult}; use igloo_scheduler::status_slicing::{ calculate_thread_load_summary, SvmWorkerSlicingStatus, WorkerStatusUpdate, }; @@ -48,7 +46,7 @@ fn mocking_transfer_tx( )) } -const TOTAL_TX_NUM: usize = 1024 * 16; +const TOTAL_TX_NUM: usize = 1024 * 4; const TOTAL_WORKER_NUM: usize = 4; // each tx need 2 unique accounts. const NUM_ACCOUNTS: usize = TOTAL_TX_NUM * 2; @@ -121,7 +119,7 @@ fn worker_process( // it's ok to ignore send error. // because error is handled by the scheduler. // if scheduler exits, means all task is scheduled. - // no need to maintain lock now. + // no need to maintain channel now. 
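(As an aside, a hedged sketch of the worker loop this hunk belongs to; only `SchedulingBatch`, `SchedulingBatchResult`, and the crossbeam channel types come from this codebase, the loop shape and `execute` closure are assumptions:)

    use crossbeam_channel::{Receiver, Sender};

    fn worker_loop(
        tasks: Receiver<SchedulingBatch>,
        completed: Sender<SchedulingBatchResult>,
        execute: impl Fn(SchedulingBatch) -> SchedulingBatchResult,
    ) {
        while let Ok(batch) = tasks.recv() {
            let result = execute(batch);
            // A send error means the scheduler side has already exited,
            // so the completion receipt can be dropped safely.
            let _ = completed.send(result);
        }
    }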
let _ = completed_sender.send(result); // Update idle_start for next iteration @@ -170,8 +168,7 @@ fn main() -> Result<(), E> { let recent_hash = store.current_bank().last_blockhash(); let transfer_txs = accounts .chunks(2) - .enumerate() - .map(|(_, chunk)| { + .map(|chunk| { mocking_transfer_tx(&chunk[0].0, &chunk[1].0.pubkey(), 1e9 as u64, recent_hash) }) .collect::, _>>()?; @@ -203,26 +200,29 @@ fn main() -> Result<(), E> { let mut batch_id_gen = IdGenerator::default(); let mut tx_id_gen = IdGenerator::default(); - let mut scheduler = NoLockScheduler::new(senders.clone(), completed_receiver); + let mut scheduler = PrioGraphSchedulerWrapper::new(senders.clone(), completed_receiver); for chunk in transfer_txs .into_iter() .chunks(SCHEDULER_BATCH_SIZE) .into_iter() .map(|chunk| chunk.collect()) .map(|transactions: Vec<_>| { + let len = transactions.len(); let ids = transactions .iter() - .map(|_| tx_id_gen.next()) + .map(|_| tx_id_gen.gen()) .collect::>(); SchedulingBatch { - batch_id: batch_id_gen.next(), + batch_id: batch_id_gen.gen(), ids, transactions, + max_ages: vec![MaxAge::default(); len], } }) .collect::>() { - scheduler.schedule_batch(chunk); + scheduler.schedule_batch(chunk)?; + scheduler.receive_complete()?; } // Close senders to signal workers to finish diff --git a/scheduler/src/id_generator.rs b/scheduler/src/id_generator.rs index 3090e4e..b0429fa 100644 --- a/scheduler/src/id_generator.rs +++ b/scheduler/src/id_generator.rs @@ -11,7 +11,7 @@ impl Default for IdGenerator { } impl IdGenerator { - pub fn next>(&mut self) -> T { + pub fn gen>(&mut self) -> T { let id = self.next_id; self.next_id = self.next_id.wrapping_sub(1); T::from(id) diff --git a/scheduler/src/impls/no_lock_scheduler/mod.rs b/scheduler/src/impls/no_lock_scheduler/mod.rs index c90f36a..ff96ea3 100644 --- a/scheduler/src/impls/no_lock_scheduler/mod.rs +++ b/scheduler/src/impls/no_lock_scheduler/mod.rs @@ -1,3 +1,4 @@ +use crate::impls::prio_graph_scheduler::scheduler_error::SchedulerError; use crate::scheduler::Scheduler; use crate::scheduler_messages::{SchedulingBatch, SchedulingBatchResult}; use crossbeam_channel::{Receiver, Sender}; @@ -19,7 +20,7 @@ impl Scheduler for NoLockScheduler { } } - fn schedule_batch(&mut self, txs: SchedulingBatch) { + fn schedule_batch(&mut self, txs: SchedulingBatch) -> Result<(), SchedulerError> { let exec_batch = 64; txs.transactions .chunks(exec_batch) @@ -31,10 +32,14 @@ impl Scheduler for NoLockScheduler { batch_id: txs.batch_id, ids: txs.ids[i * exec_batch..(i + 1) * exec_batch].to_vec(), transactions: chunk.to_vec(), + max_ages: vec![], }; self.task_senders[worker_id].send(batch).unwrap(); }); + Ok(()) } - fn receive_complete(&mut self, receipt: SchedulingBatchResult) {} + fn receive_complete(&mut self) -> Result<(), SchedulerError> { + Ok(()) + } } diff --git a/scheduler/src/impls/prio_graph_scheduler/in_flight_tracker.rs b/scheduler/src/impls/prio_graph_scheduler/in_flight_tracker.rs index 13575b4..329ba90 100644 --- a/scheduler/src/impls/prio_graph_scheduler/in_flight_tracker.rs +++ b/scheduler/src/impls/prio_graph_scheduler/in_flight_tracker.rs @@ -47,7 +47,7 @@ impl InFlightTracker { total_cus: u64, thread_id: ThreadId, ) -> TransactionBatchId { - let batch_id = self.batch_id_generator.next(); + let batch_id = self.batch_id_generator.gen(); self.num_in_flight_per_thread[thread_id] += num_transactions; self.cus_in_flight_per_thread[thread_id] += total_cus; self.batches.insert( diff --git a/scheduler/src/impls/prio_graph_scheduler/mod.rs 
b/scheduler/src/impls/prio_graph_scheduler/mod.rs index 55e0179..dbc9a54 100644 --- a/scheduler/src/impls/prio_graph_scheduler/mod.rs +++ b/scheduler/src/impls/prio_graph_scheduler/mod.rs @@ -1,34 +1,70 @@ pub mod in_flight_tracker; -pub mod prio_graph_scheduler; pub mod read_write_account_set; +pub mod scheduler; pub mod scheduler_error; pub mod scheduler_metrics; pub mod thread_aware_account_locks; pub mod transaction_priority_id; -pub mod transaction_state_container; pub mod transaction_state; +pub mod transaction_state_container; +use crate::impls::prio_graph_scheduler::scheduler::PrioGraphScheduler; +use crate::impls::prio_graph_scheduler::scheduler_error::SchedulerError; +use crate::impls::prio_graph_scheduler::transaction_state::SanitizedTransactionTTL; +use crate::impls::prio_graph_scheduler::transaction_state_container::TransactionStateContainer; use crate::scheduler::Scheduler; use crate::scheduler_messages::{SchedulingBatch, SchedulingBatchResult}; use crossbeam_channel::{Receiver, Sender}; -pub const TARGET_NUM_TRANSACTIONS_PER_BATCH: i32 = 128; +pub const TARGET_NUM_TRANSACTIONS_PER_BATCH: usize = 128; -pub struct PrioGraphSchedulerWrapper {} +pub struct PrioGraphSchedulerWrapper { + inner: PrioGraphScheduler, + container: TransactionStateContainer, +} impl Scheduler for PrioGraphSchedulerWrapper { fn new( schedule_task_senders: Vec>, task_finished_receivers: Receiver, ) -> Self { - todo!() + Self { + inner: PrioGraphScheduler::new(schedule_task_senders, task_finished_receivers), + container: TransactionStateContainer::with_capacity(10240), + } } - fn schedule_batch(&mut self, txs: SchedulingBatch) { - todo!() + fn schedule_batch(&mut self, mut txs: SchedulingBatch) -> Result<(), SchedulerError> { + for ((tx, tx_id), max_age) in txs + .transactions + .drain(..) 
+ .zip(txs.ids.drain(..)) + .zip(txs.max_ages.drain(..)) + { + self.container.insert_new_transaction( + tx_id, + SanitizedTransactionTTL { + transaction: tx, + max_age, + }, + // TODO migrate priority + 0, + 100, + ); + } + + self.inner.schedule( + &mut self.container, + // TODO: migrate pre-filter transactions + |_, result| result.fill(true), + |_| true, + )?; + Ok(()) } - fn receive_complete(&mut self, receipt: SchedulingBatchResult) { - todo!() + fn receive_complete(&mut self) -> Result<(), SchedulerError> { + // TODO metrics logic + self.inner.receive_completed(&mut self.container)?; + Ok(()) } } diff --git a/scheduler/src/impls/prio_graph_scheduler/prio_graph_scheduler.rs b/scheduler/src/impls/prio_graph_scheduler/prio_graph_scheduler.rs deleted file mode 100644 index c4991ed..0000000 --- a/scheduler/src/impls/prio_graph_scheduler/prio_graph_scheduler.rs +++ /dev/null @@ -1,910 +0,0 @@ -// use { -// crate::{ -// impls::prio_graph_scheduler::in_flight_tracker::InFlightTracker, -// impls::prio_graph_scheduler::read_write_account_set::ReadWriteAccountSet, -// impls::prio_graph_scheduler::scheduler_error::SchedulerError, -// impls::prio_graph_scheduler::thread_aware_account_locks::{ -// ThreadAwareAccountLocks, ThreadId, ThreadSet, -// }, -// impls::prio_graph_scheduler::transaction_priority_id::TransactionPriorityId, -// impls::prio_graph_scheduler::TARGET_NUM_TRANSACTIONS_PER_BATCH, -// scheduler_messages::{ -// SchedulingBatch, SchedulingBatchResult, TransactionBatchId, TransactionId, -// }, -// transaction_state::{SanitizedTransactionTTL, TransactionState}, -// transaction_state_container::TransactionStateContainer, -// }, -// crossbeam_channel::{Receiver, Sender, TryRecvError}, -// itertools::izip, -// prio_graph::{AccessKind, PrioGraph}, -// solana_cost_model::block_cost_limits::MAX_BLOCK_UNITS, -// solana_measure::measure_us, -// solana_sdk::{pubkey::Pubkey, saturating_add_assign, transaction::SanitizedTransaction}, -// }; -// -// pub struct PrioGraphScheduler { -// in_flight_tracker: InFlightTracker, -// account_locks: ThreadAwareAccountLocks, -// consume_work_senders: Vec>, -// finished_consume_work_receiver: Receiver, -// look_ahead_window_size: usize, -// } -// -// impl PrioGraphScheduler { -// pub fn new( -// consume_work_senders: Vec>, -// finished_consume_work_receiver: Receiver, -// ) -> Self { -// let num_threads = consume_work_senders.len(); -// Self { -// in_flight_tracker: InFlightTracker::new(num_threads), -// account_locks: ThreadAwareAccountLocks::new(num_threads), -// consume_work_senders, -// finished_consume_work_receiver, -// look_ahead_window_size: 2048, -// } -// } -// -// /// Schedule transactions from the given `TransactionStateContainer` to be -// /// consumed by the worker threads. Returns summary of scheduling, or an -// /// error. -// /// `pre_graph_filter` is used to filter out transactions that should be -// /// skipped and dropped before insertion to the prio-graph. This fn should -// /// set `false` for transactions that should be dropped, and `true` -// /// otherwise. -// /// `pre_lock_filter` is used to filter out transactions after they have -// /// made it to the top of the prio-graph, and immediately before locks are -// /// checked and taken. This fn should return `true` for transactions that -// /// should be scheduled, and `false` otherwise. -// /// -// /// Uses a `PrioGraph` to perform look-ahead during the scheduling of transactions. 
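To make the look-ahead concrete, a minimal sketch of the `prio_graph` calls as they are used in this file (the two-transaction conflict scenario is invented):

    use prio_graph::{AccessKind, PrioGraph};
    use solana_sdk::pubkey::Pubkey;

    fn look_ahead_demo() {
        let mut graph = PrioGraph::new(|id: &TransactionPriorityId, _node| *id);
        let account = Pubkey::new_unique();
        let hi = TransactionPriorityId::new(2, TransactionId::new(0));
        let lo = TransactionPriorityId::new(1, TransactionId::new(1));
        // Both write-lock the same account, so `lo` is chained behind `hi`.
        graph.insert_transaction(hi, std::iter::once((account, AccessKind::Write)));
        graph.insert_transaction(lo, std::iter::once((account, AccessKind::Write)));
        assert_eq!(graph.pop(), Some(hi));
        assert_eq!(graph.pop(), None); // `lo` is still blocked
        graph.unblock(&hi); // e.g. after `hi`'s batch has been sent
        assert_eq!(graph.pop(), Some(lo));
    }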
-// /// This, combined with internal tracking of threads' in-flight transactions, allows -// /// for load-balancing while prioritizing scheduling transactions onto threads that will -// /// not cause conflicts in the near future. -// pub fn schedule( -// &mut self, -// container: &mut TransactionStateContainer, -// pre_graph_filter: impl Fn(&[&SanitizedTransaction], &mut [bool]), -// pre_lock_filter: impl Fn(&SanitizedTransaction) -> bool, -// ) -> Result { -// let num_threads = self.consume_work_senders.len(); -// let max_cu_per_thread = MAX_BLOCK_UNITS / num_threads as u64; -// -// let mut schedulable_threads = ThreadSet::any(num_threads); -// for thread_id in 0..num_threads { -// if self.in_flight_tracker.cus_in_flight_per_thread()[thread_id] >= max_cu_per_thread { -// schedulable_threads.remove(thread_id); -// } -// } -// if schedulable_threads.is_empty() { -// return Ok(SchedulingSummary { -// num_scheduled: 0, -// num_unschedulable: 0, -// num_filtered_out: 0, -// filter_time_us: 0, -// }); -// } -// -// let mut batches = Batches::new(num_threads); -// // Some transactions may be unschedulable due to multi-thread conflicts. -// // These transactions cannot be scheduled until some conflicting work is completed. -// // However, the scheduler should not allow other transactions that conflict with -// // these transactions to be scheduled before them. -// let mut unschedulable_ids = Vec::new(); -// let mut blocking_locks = ReadWriteAccountSet::default(); -// let mut prio_graph = PrioGraph::new(|id: &TransactionPriorityId, _graph_node| *id); -// -// // Track metrics on filter. -// let mut num_filtered_out: usize = 0; -// let mut total_filter_time_us: u64 = 0; -// -// let mut window_budget = self.look_ahead_window_size; -// let mut chunked_pops = |container: &mut TransactionStateContainer
<P>
, -// prio_graph: &mut PrioGraph<_, _, _, _>, -// window_budget: &mut usize| { -// while *window_budget > 0 { -// const MAX_FILTER_CHUNK_SIZE: usize = 128; -// let mut filter_array = [true; MAX_FILTER_CHUNK_SIZE]; -// let mut ids = Vec::with_capacity(MAX_FILTER_CHUNK_SIZE); -// let mut txs = Vec::with_capacity(MAX_FILTER_CHUNK_SIZE); -// -// let chunk_size = (*window_budget).min(MAX_FILTER_CHUNK_SIZE); -// for _ in 0..chunk_size { -// if let Some(id) = container.pop() { -// ids.push(id); -// } else { -// break; -// } -// } -// *window_budget = window_budget.saturating_sub(chunk_size); -// -// ids.iter().for_each(|id| { -// let transaction = container.get_transaction_ttl(&id.id).unwrap(); -// txs.push(&transaction.transaction); -// }); -// -// let (_, filter_us) = -// measure_us!(pre_graph_filter(&txs, &mut filter_array[..chunk_size])); -// saturating_add_assign!(total_filter_time_us, filter_us); -// -// for (id, filter_result) in ids.iter().zip(&filter_array[..chunk_size]) { -// if *filter_result { -// let transaction = container.get_transaction_ttl(&id.id).unwrap(); -// prio_graph.insert_transaction( -// *id, -// Self::get_transaction_account_access(transaction), -// ); -// } else { -// saturating_add_assign!(num_filtered_out, 1); -// container.remove_by_id(&id.id); -// } -// } -// -// if ids.len() != chunk_size { -// break; -// } -// } -// }; -// -// // Create the initial look-ahead window. -// // Check transactions against filter, remove from container if it fails. -// chunked_pops(container, &mut prio_graph, &mut window_budget); -// -// let mut unblock_this_batch = -// Vec::with_capacity(self.consume_work_senders.len() * TARGET_NUM_TRANSACTIONS_PER_BATCH); -// const MAX_TRANSACTIONS_PER_SCHEDULING_PASS: usize = 100_000; -// let mut num_scheduled: usize = 0; -// let mut num_sent: usize = 0; -// let mut num_unschedulable: usize = 0; -// while num_scheduled < MAX_TRANSACTIONS_PER_SCHEDULING_PASS { -// // If nothing is in the main-queue of the `PrioGraph` then there's nothing left to schedule. -// if prio_graph.is_empty() { -// break; -// } -// -// while let Some(id) = prio_graph.pop() { -// unblock_this_batch.push(id); -// -// // Should always be in the container, during initial testing phase panic. -// // Later, we can replace with a continue in case this does happen. 
-// let Some(transaction_state) = container.get_mut_transaction_state(&id.id) else { -// panic!("transaction state must exist") -// }; -// -// let maybe_schedule_info = try_schedule_transaction( -// transaction_state, -// &pre_lock_filter, -// &mut blocking_locks, -// &mut self.account_locks, -// num_threads, -// |thread_set| { -// Self::select_thread( -// thread_set, -// &batches.total_cus, -// self.in_flight_tracker.cus_in_flight_per_thread(), -// &batches.transactions, -// self.in_flight_tracker.num_in_flight_per_thread(), -// ) -// }, -// ); -// -// match maybe_schedule_info { -// Err(TransactionSchedulingError::Filtered) => { -// container.remove_by_id(&id.id); -// } -// Err(TransactionSchedulingError::UnschedulableConflicts) => { -// unschedulable_ids.push(id); -// saturating_add_assign!(num_unschedulable, 1); -// } -// Ok(TransactionSchedulingInfo { -// thread_id, -// transaction, -// max_age, -// cost, -// }) => { -// saturating_add_assign!(num_scheduled, 1); -// batches.transactions[thread_id].push(transaction); -// batches.ids[thread_id].push(id.id); -// batches.max_ages[thread_id].push(max_age); -// saturating_add_assign!(batches.total_cus[thread_id], cost); -// -// // If target batch size is reached, send only this batch. -// if batches.ids[thread_id].len() >= TARGET_NUM_TRANSACTIONS_PER_BATCH { -// saturating_add_assign!( -// num_sent, -// self.send_batch(&mut batches, thread_id)? -// ); -// } -// -// // if the thread is at max_cu_per_thread, remove it from the schedulable threads -// // if there are no more schedulable threads, stop scheduling. -// if self.in_flight_tracker.cus_in_flight_per_thread()[thread_id] -// + batches.total_cus[thread_id] -// >= max_cu_per_thread -// { -// schedulable_threads.remove(thread_id); -// if schedulable_threads.is_empty() { -// break; -// } -// } -// -// if num_scheduled >= MAX_TRANSACTIONS_PER_SCHEDULING_PASS { -// break; -// } -// } -// } -// } -// -// // Send all non-empty batches -// saturating_add_assign!(num_sent, self.send_batches(&mut batches)?); -// -// // Refresh window budget and do chunked pops -// saturating_add_assign!(window_budget, unblock_this_batch.len()); -// chunked_pops(container, &mut prio_graph, &mut window_budget); -// -// // Unblock all transactions that were blocked by the transactions that were just sent. -// for id in unblock_this_batch.drain(..) { -// prio_graph.unblock(&id); -// } -// } -// -// // Send batches for any remaining transactions -// saturating_add_assign!(num_sent, self.send_batches(&mut batches)?); -// -// // Push unschedulable ids back into the container -// for id in unschedulable_ids { -// container.push_id_into_queue(id); -// } -// -// // Push remaining transactions back into the container -// while let Some((id, _)) = prio_graph.pop_and_unblock() { -// container.push_id_into_queue(id); -// } -// -// assert_eq!( -// num_scheduled, num_sent, -// "number of scheduled and sent transactions must match" -// ); -// -// Ok(SchedulingSummary { -// num_scheduled, -// num_unschedulable, -// num_filtered_out, -// filter_time_us: total_filter_time_us, -// }) -// } -// -// /// Receive completed batches of transactions without blocking. -// /// Returns (num_transactions, num_retryable_transactions) on success. -// pub fn receive_completed( -// &mut self, -// container: &mut TransactionStateContainer
<P>
, -// ) -> Result<(usize, usize), SchedulerError> { -// let mut total_num_transactions: usize = 0; -// let mut total_num_retryable: usize = 0; -// loop { -// let (num_transactions, num_retryable) = self.try_receive_completed(container)?; -// if num_transactions == 0 { -// break; -// } -// saturating_add_assign!(total_num_transactions, num_transactions); -// saturating_add_assign!(total_num_retryable, num_retryable); -// } -// Ok((total_num_transactions, total_num_retryable)) -// } -// -// /// Receive completed batches of transactions. -// /// Returns `Ok((num_transactions, num_retryable))` if a batch was received, `Ok((0, 0))` if no batch was received. -// fn try_receive_completed( -// &mut self, -// container: &mut TransactionStateContainer
<P>
, -// ) -> Result<(usize, usize), SchedulerError> { -// match self.finished_consume_work_receiver.try_recv() { -// Ok(FinishedConsumeWork { -// work: -// ConsumeWork { -// batch_id, -// ids, -// transactions, -// max_ages, -// }, -// retryable_indexes, -// }) => { -// let num_transactions = ids.len(); -// let num_retryable = retryable_indexes.len(); -// -// // Free the locks -// self.complete_batch(batch_id, &transactions); -// -// // Retryable transactions should be inserted back into the container -// let mut retryable_iter = retryable_indexes.into_iter().peekable(); -// for (index, (id, transaction, max_age)) in -// izip!(ids, transactions, max_ages).enumerate() -// { -// if let Some(retryable_index) = retryable_iter.peek() { -// if *retryable_index == index { -// container.retry_transaction( -// id, -// SanitizedTransactionTTL { -// transaction, -// max_age, -// }, -// ); -// retryable_iter.next(); -// continue; -// } -// } -// container.remove_by_id(&id); -// } -// -// Ok((num_transactions, num_retryable)) -// } -// Err(TryRecvError::Empty) => Ok((0, 0)), -// Err(TryRecvError::Disconnected) => Err(SchedulerError::DisconnectedRecvChannel( -// "finished consume work", -// )), -// } -// } -// -// /// Mark a given `TransactionBatchId` as completed. -// /// This will update the internal tracking, including account locks. -// fn complete_batch( -// &mut self, -// batch_id: TransactionBatchId, -// transactions: &[SanitizedTransaction], -// ) { -// let thread_id = self.in_flight_tracker.complete_batch(batch_id); -// for transaction in transactions { -// let message = transaction.message(); -// let account_keys = message.account_keys(); -// let write_account_locks = account_keys -// .iter() -// .enumerate() -// .filter_map(|(index, key)| message.is_writable(index).then_some(key)); -// let read_account_locks = account_keys -// .iter() -// .enumerate() -// .filter_map(|(index, key)| (!message.is_writable(index)).then_some(key)); -// self.account_locks -// .unlock_accounts(write_account_locks, read_account_locks, thread_id); -// } -// } -// -// /// Send all batches of transactions to the worker threads. -// /// Returns the number of transactions sent. -// fn send_batches(&mut self, batches: &mut Batches) -> Result { -// (0..self.consume_work_senders.len()) -// .map(|thread_index| self.send_batch(batches, thread_index)) -// .sum() -// } -// -// /// Send a batch of transactions to the given thread's `ConsumeWork` channel. -// /// Returns the number of transactions sent. -// fn send_batch( -// &mut self, -// batches: &mut Batches, -// thread_index: usize, -// ) -> Result { -// if batches.ids[thread_index].is_empty() { -// return Ok(0); -// } -// -// let (ids, transactions, max_ages, total_cus) = batches.take_batch(thread_index); -// -// let batch_id = self -// .in_flight_tracker -// .track_batch(ids.len(), total_cus, thread_index); -// -// let num_scheduled = ids.len(); -// let work = ConsumeWork { -// batch_id, -// ids, -// transactions, -// max_ages, -// }; -// self.consume_work_senders[thread_index] -// .send(work) -// .map_err(|_| SchedulerError::DisconnectedSendChannel("consume work sender"))?; -// -// Ok(num_scheduled) -// } -// -// /// Given the schedulable `thread_set`, select the thread with the least amount -// /// of work queued up. -// /// Currently, "work" is just defined as the number of transactions. -// /// -// /// If the `chain_thread` is available, this thread will be selected, regardless of -// /// load-balancing. -// /// -// /// Panics if the `thread_set` is empty. 
This should never happen, see comment -// /// on `ThreadAwareAccountLocks::try_lock_accounts`. -// fn select_thread( -// thread_set: ThreadSet, -// batch_cus_per_thread: &[u64], -// in_flight_cus_per_thread: &[u64], -// batches_per_thread: &[Vec], -// in_flight_per_thread: &[usize], -// ) -> ThreadId { -// thread_set -// .contained_threads_iter() -// .map(|thread_id| { -// ( -// thread_id, -// batch_cus_per_thread[thread_id] + in_flight_cus_per_thread[thread_id], -// batches_per_thread[thread_id].len() + in_flight_per_thread[thread_id], -// ) -// }) -// .min_by(|a, b| a.1.cmp(&b.1).then_with(|| a.2.cmp(&b.2))) -// .map(|(thread_id, _, _)| thread_id) -// .unwrap() -// } -// -// /// Gets accessed accounts (resources) for use in `PrioGraph`. -// fn get_transaction_account_access( -// transaction: &SanitizedTransactionTTL, -// ) -> impl Iterator + '_ { -// let message = transaction.transaction.message(); -// message -// .account_keys() -// .iter() -// .enumerate() -// .map(|(index, key)| { -// if message.is_writable(index) { -// (*key, AccessKind::Write) -// } else { -// (*key, AccessKind::Read) -// } -// }) -// } -// } -// -// /// Metrics from scheduling transactions. -// #[derive(Debug, PartialEq, Eq)] -// pub struct SchedulingSummary { -// /// Number of transactions scheduled. -// pub num_scheduled: usize, -// /// Number of transactions that were not scheduled due to conflicts. -// pub num_unschedulable: usize, -// /// Number of transactions that were dropped due to filter. -// pub num_filtered_out: usize, -// /// Time spent filtering transactions -// pub filter_time_us: u64, -// } -// -// struct Batches { -// ids: Vec>, -// transactions: Vec>, -// max_ages: Vec>, -// total_cus: Vec, -// } -// -// impl Batches { -// fn new(num_threads: usize) -> Self { -// Self { -// ids: vec![Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH); num_threads], -// transactions: vec![Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH); num_threads], -// max_ages: vec![Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH); num_threads], -// total_cus: vec![0; num_threads], -// } -// } -// -// fn take_batch( -// &mut self, -// thread_id: ThreadId, -// ) -> ( -// Vec, -// Vec, -// Vec, -// u64, -// ) { -// ( -// core::mem::replace( -// &mut self.ids[thread_id], -// Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH), -// ), -// core::mem::replace( -// &mut self.transactions[thread_id], -// Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH), -// ), -// core::mem::replace( -// &mut self.max_ages[thread_id], -// Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH), -// ), -// core::mem::replace(&mut self.total_cus[thread_id], 0), -// ) -// } -// } -// -// /// A transaction has been scheduled to a thread. -// struct TransactionSchedulingInfo { -// thread_id: ThreadId, -// transaction: SanitizedTransaction, -// max_age: MaxAge, -// cost: u64, -// } -// -// /// Error type for reasons a transaction could not be scheduled. -// enum TransactionSchedulingError { -// /// Transaction was filtered out before locking. -// Filtered, -// /// Transaction cannot be scheduled due to conflicts, or -// /// higher priority conflicting transactions are unschedulable. -// UnschedulableConflicts, -// } -// -// fn try_schedule_transaction( -// transaction_state: &mut TransactionState
<P>
, -// pre_lock_filter: impl Fn(&SanitizedTransaction) -> bool, -// blocking_locks: &mut ReadWriteAccountSet, -// account_locks: &mut ThreadAwareAccountLocks, -// num_threads: usize, -// thread_selector: impl Fn(ThreadSet) -> ThreadId, -// ) -> Result { -// let transaction = &transaction_state.transaction_ttl().transaction; -// if !pre_lock_filter(transaction) { -// return Err(TransactionSchedulingError::Filtered); -// } -// -// // Check if this transaction conflicts with any blocked transactions -// let message = transaction.message(); -// if !blocking_locks.check_locks(message) { -// blocking_locks.take_locks(message); -// return Err(TransactionSchedulingError::UnschedulableConflicts); -// } -// -// // Schedule the transaction if it can be. -// let message = transaction.message(); -// let account_keys = message.account_keys(); -// let write_account_locks = account_keys -// .iter() -// .enumerate() -// .filter_map(|(index, key)| message.is_writable(index).then_some(key)) -// .collect::>(); -// let read_account_locks = account_keys -// .iter() -// .enumerate() -// .filter_map(|(index, key)| (!message.is_writable(index)).then_some(key)) -// .collect::>(); -// -// let Some(thread_id) = account_locks.try_lock_accounts( -// write_account_locks.into_iter(), -// read_account_locks.into_iter(), -// ThreadSet::any(num_threads), -// thread_selector, -// ) else { -// blocking_locks.take_locks(message); -// return Err(TransactionSchedulingError::UnschedulableConflicts); -// }; -// -// let sanitized_transaction_ttl = transaction_state.transition_to_pending(); -// let cost = transaction_state.cost(); -// -// Ok(TransactionSchedulingInfo { -// thread_id, -// transaction: sanitized_transaction_ttl.transaction, -// max_age: sanitized_transaction_ttl.max_age, -// cost, -// }) -// } -// -// #[cfg(test)] -// mod tests { -// use { -// super::*, -// crate::tests::MockImmutableDeserializedPacket, -// crate::TARGET_NUM_TRANSACTIONS_PER_BATCH, -// crossbeam_channel::{unbounded, Receiver}, -// itertools::Itertools, -// solana_sdk::{ -// clock::Slot, compute_budget::ComputeBudgetInstruction, hash::Hash, message::Message, -// packet::Packet, pubkey::Pubkey, signature::Keypair, signer::Signer, system_instruction, -// transaction::Transaction, -// }, -// std::{borrow::Borrow, sync::Arc}, -// }; -// -// macro_rules! txid { -// ($value:expr) => { -// TransactionId::new($value) -// }; -// } -// -// macro_rules! 
txids {
-//     ([$($element:expr),*]) => {
-//         vec![ $(txid!($element)),* ]
-//     };
-// }
-//
-// fn create_test_frame(
-//     num_threads: usize,
-// ) -> (
-//     PrioGraphScheduler<MockImmutableDeserializedPacket>,
-//     Vec<Receiver<ConsumeWork<MockImmutableDeserializedPacket>>>,
-//     Sender<FinishedConsumeWork<MockImmutableDeserializedPacket>>,
-// ) {
-//     let (consume_work_senders, consume_work_receivers) =
-//         (0..num_threads).map(|_| unbounded()).unzip();
-//     let (finished_consume_work_sender, finished_consume_work_receiver) = unbounded();
-//     let scheduler = PrioGraphScheduler::<MockImmutableDeserializedPacket>::new(
-//         consume_work_senders,
-//         finished_consume_work_receiver,
-//     );
-//     (
-//         scheduler,
-//         consume_work_receivers,
-//         finished_consume_work_sender,
-//     )
-// }
-//
-// fn prioritized_tranfers(
-//     from_keypair: &Keypair,
-//     to_pubkeys: impl IntoIterator<Item = impl Borrow<Pubkey>>,
-//     lamports: u64,
-//     priority: u64,
-// ) -> SanitizedTransaction {
-//     let to_pubkeys_lamports = to_pubkeys
-//         .into_iter()
-//         .map(|pubkey| *pubkey.borrow())
-//         .zip(std::iter::repeat(lamports))
-//         .collect_vec();
-//     let mut ixs =
-//         system_instruction::transfer_many(&from_keypair.pubkey(), &to_pubkeys_lamports);
-//     let prioritization = ComputeBudgetInstruction::set_compute_unit_price(priority);
-//     ixs.push(prioritization);
-//     let message = Message::new(&ixs, Some(&from_keypair.pubkey()));
-//     let tx = Transaction::new(&[from_keypair], message, Hash::default());
-//     SanitizedTransaction::from_transaction_for_tests(tx)
-// }
-//
-// fn create_container(
-//     tx_infos: impl IntoIterator<
-//         Item = (
-//             impl Borrow<Keypair>,
-//             impl IntoIterator<Item = impl Borrow<Pubkey>>,
-//             u64,
-//             u64,
-//         ),
-//     >,
-// ) -> TransactionStateContainer<MockImmutableDeserializedPacket> {
-//     let mut container =
-//         TransactionStateContainer::<MockImmutableDeserializedPacket>::with_capacity(10 * 1024);
-//     for (index, (from_keypair, to_pubkeys, lamports, compute_unit_price)) in
-//         tx_infos.into_iter().enumerate()
-//     {
-//         let id = TransactionId::new(index as u64);
-//         let transaction = prioritized_tranfers(
-//             from_keypair.borrow(),
-//             to_pubkeys,
-//             lamports,
-//             compute_unit_price,
-//         );
-//         let packet = Arc::new(
-//             MockImmutableDeserializedPacket::new(
-//                 Packet::from_data(None, transaction.to_versioned_transaction()).unwrap(),
-//             )
-//             .unwrap(),
-//         );
-//         let transaction_ttl = SanitizedTransactionTTL {
-//             transaction,
-//             max_age: MaxAge {
-//                 epoch_invalidation_slot: Slot::MAX,
-//                 alt_invalidation_slot: Slot::MAX,
-//             },
-//         };
-//         const TEST_TRANSACTION_COST: u64 = 5000;
-//         container.insert_new_transaction(
-//             id,
-//             transaction_ttl,
-//             packet,
-//             compute_unit_price,
-//             TEST_TRANSACTION_COST,
-//         );
-//     }
-//
-//     container
-// }
-//
-// fn collect_work(
-//     receiver: &Receiver<ConsumeWork<MockImmutableDeserializedPacket>>,
-// ) -> (Vec<ConsumeWork<MockImmutableDeserializedPacket>>, Vec<Vec<TransactionId>>) {
-//     receiver
-//         .try_iter()
-//         .map(|work| {
-//             let ids = work.ids.clone();
-//             (work, ids)
-//         })
-//         .unzip()
-// }
-//
-// fn test_pre_graph_filter(_txs: &[&SanitizedTransaction], results: &mut [bool]) {
-//     results.fill(true);
-// }
-//
-// fn test_pre_lock_filter(_tx: &SanitizedTransaction) -> bool {
-//     true
-// }
-//
-// #[test]
-// fn test_schedule_disconnected_channel() {
-//     let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(1);
-//     let mut container = create_container([(&Keypair::new(), &[Pubkey::new_unique()], 1, 1)]);
-//
-//     drop(work_receivers); // explicitly drop receivers
-//     assert_matches!(
-//         scheduler.schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter),
-//         Err(SchedulerError::DisconnectedSendChannel(_))
-//     );
-// }
-//
-// #[test]
-// fn test_schedule_single_threaded_no_conflicts() {
-//     let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(1);
-//     let mut container = create_container([
-//         (&Keypair::new(), &[Pubkey::new_unique()], 1, 1),
-//         (&Keypair::new(), &[Pubkey::new_unique()], 2, 2),
-//     ]);
-//
-//     let scheduling_summary = scheduler
-//         .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter)
-//         .unwrap();
-//     assert_eq!(scheduling_summary.num_scheduled, 2);
-//     assert_eq!(scheduling_summary.num_unschedulable, 0);
-//     assert_eq!(collect_work(&work_receivers[0]).1, vec![txids!([1, 0])]);
-// }
-//
-// #[test]
-// fn test_schedule_single_threaded_conflict() {
-//     let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(1);
-//     let pubkey = Pubkey::new_unique();
-//     let mut container = create_container([
-//         (&Keypair::new(), &[pubkey], 1, 1),
-//         (&Keypair::new(), &[pubkey], 1, 2),
-//     ]);
-//
-//     let scheduling_summary = scheduler
-//         .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter)
-//         .unwrap();
-//     assert_eq!(scheduling_summary.num_scheduled, 2);
-//     assert_eq!(scheduling_summary.num_unschedulable, 0);
-//     assert_eq!(
-//         collect_work(&work_receivers[0]).1,
-//         vec![txids!([1]), txids!([0])]
-//     );
-// }
-//
-// #[test]
-// fn test_schedule_consume_single_threaded_multi_batch() {
-//     let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(1);
-//     let mut container = create_container(
-//         (0..4 * TARGET_NUM_TRANSACTIONS_PER_BATCH)
-//             .map(|i| (Keypair::new(), [Pubkey::new_unique()], i as u64, 1)),
-//     );
-//
-//     // expect 4 full batches to be scheduled
-//     let scheduling_summary = scheduler
-//         .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter)
-//         .unwrap();
-//     assert_eq!(
-//         scheduling_summary.num_scheduled,
-//         4 * TARGET_NUM_TRANSACTIONS_PER_BATCH
-//     );
-//     assert_eq!(scheduling_summary.num_unschedulable, 0);
-//
-//     let thread0_work_counts: Vec<_> = work_receivers[0]
-//         .try_iter()
-//         .map(|work| work.ids.len())
-//         .collect();
-//     assert_eq!(thread0_work_counts, [TARGET_NUM_TRANSACTIONS_PER_BATCH; 4]);
-// }
-//
-// #[test]
-// fn test_schedule_simple_thread_selection() {
-//     let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(2);
-//     let mut container =
-//         create_container((0..4).map(|i| (Keypair::new(), [Pubkey::new_unique()], 1, i)));
-//
-//     let scheduling_summary = scheduler
-//         .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter)
-//         .unwrap();
-//     assert_eq!(scheduling_summary.num_scheduled, 4);
-//     assert_eq!(scheduling_summary.num_unschedulable, 0);
-//     assert_eq!(collect_work(&work_receivers[0]).1, [txids!([3, 1])]);
-//     assert_eq!(collect_work(&work_receivers[1]).1, [txids!([2, 0])]);
-// }
-//
-// #[test]
-// fn test_schedule_priority_guard() {
-//     let (mut scheduler, work_receivers, finished_work_sender) = create_test_frame(2);
-//     // intentionally shorten the look-ahead window to cause unschedulable conflicts
-//     scheduler.look_ahead_window_size = 2;
-//
-//     let accounts = (0..8).map(|_| Keypair::new()).collect_vec();
-//     let mut container = create_container([
-//         (&accounts[0], &[accounts[1].pubkey()], 1, 6),
-//         (&accounts[2], &[accounts[3].pubkey()], 1, 5),
-//         (&accounts[4], &[accounts[5].pubkey()], 1, 4),
-//         (&accounts[6], &[accounts[7].pubkey()], 1, 3),
-//         (&accounts[1], &[accounts[2].pubkey()], 1, 2),
-//         (&accounts[2], &[accounts[3].pubkey()], 1, 1),
-//     ]);
-//
-//     // The look-ahead window is intentionally shortened, high priority transactions
-//     // [0, 1, 2, 3] do not conflict, and are scheduled onto threads in a
-//     // round-robin fashion. This leads to transaction [4] being unschedulable due
-//     // to conflicts with [0] and [1], which were scheduled to different threads.
-//     // Transaction [5] is technically schedulable, onto thread 1 since it only
-//     // conflicts with transaction [1]. However, [5] will not be scheduled because
-//     // it conflicts with a higher-priority transaction [4] that is unschedulable.
-//     // The full prio-graph can be visualized as:
-//     //   [0] \
-//     //        -> [4] -> [5]
-//     //   [1] / ------/
-//     //   [2]
-//     //   [3]
-//     // Because the look-ahead window is shortened to a size of 4, the scheduler does
-//     // not have knowledge of the joining at transaction [4] until after [0] and [1]
-//     // have been scheduled.
-//     let scheduling_summary = scheduler
-//         .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter)
-//         .unwrap();
-//     assert_eq!(scheduling_summary.num_scheduled, 4);
-//     assert_eq!(scheduling_summary.num_unschedulable, 2);
-//     let (thread_0_work, thread_0_ids) = collect_work(&work_receivers[0]);
-//     assert_eq!(thread_0_ids, [txids!([0]), txids!([2])]);
-//     assert_eq!(
-//         collect_work(&work_receivers[1]).1,
-//         [txids!([1]), txids!([3])]
-//     );
-//
-//     // Cannot schedule even on next pass because of lock conflicts
-//     let scheduling_summary = scheduler
-//         .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter)
-//         .unwrap();
-//     assert_eq!(scheduling_summary.num_scheduled, 0);
-//     assert_eq!(scheduling_summary.num_unschedulable, 2);
-//
-//     // Complete batch on thread 0. Remaining txs can be scheduled onto thread 1
-//     finished_work_sender
-//         .send(FinishedConsumeWork {
-//             work: thread_0_work.into_iter().next().unwrap(),
-//             retryable_indexes: vec![],
-//         })
-//         .unwrap();
-//     scheduler.receive_completed(&mut container).unwrap();
-//     let scheduling_summary = scheduler
-//         .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter)
-//         .unwrap();
-//     assert_eq!(scheduling_summary.num_scheduled, 2);
-//     assert_eq!(scheduling_summary.num_unschedulable, 0);
-//
-//     assert_eq!(
-//         collect_work(&work_receivers[1]).1,
-//         [txids!([4]), txids!([5])]
-//     );
-// }
-//
-// #[test]
-// fn test_schedule_pre_lock_filter() {
-//     let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(1);
-//     let pubkey = Pubkey::new_unique();
-//     let keypair = Keypair::new();
-//     let mut container = create_container([
-//         (&Keypair::new(), &[pubkey], 1, 1),
-//         (&keypair, &[pubkey], 1, 2),
-//         (&Keypair::new(), &[pubkey], 1, 3),
-//     ]);
-//
-//     // 2nd transaction should be filtered out and dropped before locking.
-//     let pre_lock_filter =
-//         |tx: &SanitizedTransaction| tx.message().fee_payer() != &keypair.pubkey();
-//     let scheduling_summary = scheduler
-//         .schedule(&mut container, test_pre_graph_filter, pre_lock_filter)
-//         .unwrap();
-//     assert_eq!(scheduling_summary.num_scheduled, 2);
-//     assert_eq!(scheduling_summary.num_unschedulable, 0);
-//     assert_eq!(
-//         collect_work(&work_receivers[0]).1,
-//         vec![txids!([2]), txids!([0])]
-//     );
-// }
-// }
diff --git a/scheduler/src/impls/prio_graph_scheduler/read_write_account_set.rs b/scheduler/src/impls/prio_graph_scheduler/read_write_account_set.rs
index 4e23919..0e70837 100644
--- a/scheduler/src/impls/prio_graph_scheduler/read_write_account_set.rs
+++ b/scheduler/src/impls/prio_graph_scheduler/read_write_account_set.rs
@@ -144,7 +144,7 @@ mod tests {
             bank,
             bank.get_reserved_account_keys(),
         )
-        .unwrap()
+            .unwrap()
     }
 
     fn create_test_address_lookup_table(
diff --git a/prio-graph-scheduler/src/prio_graph_scheduler.rs b/scheduler/src/impls/prio_graph_scheduler/scheduler.rs
similarity index 92%
rename from prio-graph-scheduler/src/prio_graph_scheduler.rs
rename to scheduler/src/impls/prio_graph_scheduler/scheduler.rs
index bca3252..1c2e9d8 100644
--- a/prio-graph-scheduler/src/prio_graph_scheduler.rs
+++ b/scheduler/src/impls/prio_graph_scheduler/scheduler.rs
@@ -1,17 +1,21 @@
+use crate::scheduler_messages::MaxAge;
 use {
+    super::{
+        transaction_state::{SanitizedTransactionTTL, TransactionState},
+        transaction_state_container::TransactionStateContainer,
+    },
     crate::{
-        deserializable_packet::DeserializableTxPacket,
-        in_flight_tracker::InFlightTracker,
-        read_write_account_set::ReadWriteAccountSet,
-        scheduler_error::SchedulerError,
+        impls::prio_graph_scheduler::in_flight_tracker::InFlightTracker,
+        impls::prio_graph_scheduler::read_write_account_set::ReadWriteAccountSet,
+        impls::prio_graph_scheduler::scheduler_error::SchedulerError,
+        impls::prio_graph_scheduler::thread_aware_account_locks::{
+            ThreadAwareAccountLocks, ThreadId, ThreadSet,
+        },
+        impls::prio_graph_scheduler::transaction_priority_id::TransactionPriorityId,
+        impls::prio_graph_scheduler::TARGET_NUM_TRANSACTIONS_PER_BATCH,
         scheduler_messages::{
-            ConsumeWork, FinishedConsumeWork, MaxAge, TransactionBatchId, TransactionId,
+            SchedulingBatch, SchedulingBatchResult, TransactionBatchId, TransactionId,
         },
-        thread_aware_account_locks::{ThreadAwareAccountLocks, ThreadId, ThreadSet},
-        transaction_priority_id::TransactionPriorityId,
-        transaction_state::{SanitizedTransactionTTL, TransactionState},
-        transaction_state_container::TransactionStateContainer,
-        TARGET_NUM_TRANSACTIONS_PER_BATCH,
     },
     crossbeam_channel::{Receiver, Sender, TryRecvError},
     itertools::izip,
@@ -21,19 +25,18 @@ use {
     solana_sdk::{pubkey::Pubkey, saturating_add_assign, transaction::SanitizedTransaction},
 };
 
-pub struct PrioGraphScheduler<P: DeserializableTxPacket> {
+pub struct PrioGraphScheduler {
     in_flight_tracker: InFlightTracker,
     account_locks: ThreadAwareAccountLocks,
-    consume_work_senders: Vec<Sender<ConsumeWork<P>>>,
-    finished_consume_work_receiver: Receiver<FinishedConsumeWork<P>>,
+    consume_work_senders: Vec<Sender<SchedulingBatch>>,
+    finished_consume_work_receiver: Receiver<SchedulingBatchResult>,
     look_ahead_window_size: usize,
-    phantom: std::marker::PhantomData<P>,
 }
 
-impl<P: DeserializableTxPacket> PrioGraphScheduler<P> {
+impl PrioGraphScheduler {
     pub fn new(
-        consume_work_senders: Vec<Sender<ConsumeWork<P>>>,
-        finished_consume_work_receiver: Receiver<FinishedConsumeWork<P>>,
+        consume_work_senders: Vec<Sender<SchedulingBatch>>,
+        finished_consume_work_receiver: Receiver<SchedulingBatchResult>,
     ) -> Self {
         let num_threads = consume_work_senders.len();
         Self {
@@ -42,7 +45,6 @@ impl<P: DeserializableTxPacket> PrioGraphScheduler<P> {
             consume_work_senders,
             finished_consume_work_receiver,
             look_ahead_window_size: 2048,
-            phantom: std::marker::PhantomData,
         }
     }
@@ -64,7 +66,7 @@ impl<P: DeserializableTxPacket> PrioGraphScheduler<P> {
     /// not cause conflicts in the near future.
     pub fn schedule(
         &mut self,
-        container: &mut TransactionStateContainer<P>,
+        container: &mut TransactionStateContainer,
         pre_graph_filter: impl Fn(&[&SanitizedTransaction], &mut [bool]),
         pre_lock_filter: impl Fn(&SanitizedTransaction) -> bool,
     ) -> Result<SchedulingSummary, SchedulerError> {
@@ -100,7 +102,7 @@ impl<P: DeserializableTxPacket> PrioGraphScheduler<P> {
         let mut total_filter_time_us: u64 = 0;
 
         let mut window_budget = self.look_ahead_window_size;
-        let mut chunked_pops = |container: &mut TransactionStateContainer<P>,
+        let mut chunked_pops = |container: &mut TransactionStateContainer,
                                 prio_graph: &mut PrioGraph<_, _, _, _>,
                                 window_budget: &mut usize| {
             while *window_budget > 0 {
@@ -279,7 +281,7 @@ impl<P: DeserializableTxPacket> PrioGraphScheduler<P> {
     /// Returns (num_transactions, num_retryable_transactions) on success.
     pub fn receive_completed(
         &mut self,
-        container: &mut TransactionStateContainer<P>,
+        container: &mut TransactionStateContainer,
     ) -> Result<(usize, usize), SchedulerError> {
         let mut total_num_transactions: usize = 0;
         let mut total_num_retryable: usize = 0;
@@ -298,12 +300,12 @@ impl<P: DeserializableTxPacket> PrioGraphScheduler<P> {
     /// Returns `Ok((num_transactions, num_retryable))` if a batch was received, `Ok((0, 0))` if no batch was received.
     fn try_receive_completed(
         &mut self,
-        container: &mut TransactionStateContainer<P>,
+        container: &mut TransactionStateContainer,
     ) -> Result<(usize, usize), SchedulerError> {
         match self.finished_consume_work_receiver.try_recv() {
-            Ok(FinishedConsumeWork {
-                work:
-                    ConsumeWork {
+            Ok(SchedulingBatchResult {
+                batch:
+                    SchedulingBatch {
                         batch_id,
                         ids,
                         transactions,
@@ -397,7 +399,7 @@ impl<P: DeserializableTxPacket> PrioGraphScheduler<P> {
             .track_batch(ids.len(), total_cus, thread_index);
 
         let num_scheduled = ids.len();
-        let work = ConsumeWork {
+        let work = SchedulingBatch {
             batch_id,
             ids,
             transactions,
@@ -533,8 +535,8 @@ enum TransactionSchedulingError {
     UnschedulableConflicts,
 }
 
-fn try_schedule_transaction<P: DeserializableTxPacket>(
-    transaction_state: &mut TransactionState<P>,
+fn try_schedule_transaction(
+    transaction_state: &mut TransactionState,
     pre_lock_filter: impl Fn(&SanitizedTransaction) -> bool,
     blocking_locks: &mut ReadWriteAccountSet,
     account_locks: &mut ThreadAwareAccountLocks,
@@ -590,18 +592,17 @@ fn try_schedule_transaction<P: DeserializableTxPacket>(
 
 #[cfg(test)]
 mod tests {
+    use assert_matches::assert_matches;
     use {
        super::*,
-        crate::tests::MockImmutableDeserializedPacket,
-        crate::TARGET_NUM_TRANSACTIONS_PER_BATCH,
         crossbeam_channel::{unbounded, Receiver},
         itertools::Itertools,
         solana_sdk::{
             clock::Slot, compute_budget::ComputeBudgetInstruction, hash::Hash, message::Message,
-            packet::Packet, pubkey::Pubkey, signature::Keypair, signer::Signer, system_instruction,
+            pubkey::Pubkey, signature::Keypair, signer::Signer, system_instruction,
             transaction::Transaction,
         },
-        std::{borrow::Borrow, sync::Arc},
+        std::borrow::Borrow,
     };
 
     macro_rules! txid {
@@ -619,17 +620,15 @@ mod tests {
     fn create_test_frame(
         num_threads: usize,
     ) -> (
-        PrioGraphScheduler<MockImmutableDeserializedPacket>,
-        Vec<Receiver<ConsumeWork<MockImmutableDeserializedPacket>>>,
-        Sender<FinishedConsumeWork<MockImmutableDeserializedPacket>>,
+        PrioGraphScheduler,
+        Vec<Receiver<SchedulingBatch>>,
+        Sender<SchedulingBatchResult>,
     ) {
         let (consume_work_senders, consume_work_receivers) =
             (0..num_threads).map(|_| unbounded()).unzip();
         let (finished_consume_work_sender, finished_consume_work_receiver) = unbounded();
-        let scheduler = PrioGraphScheduler::<MockImmutableDeserializedPacket>::new(
-            consume_work_senders,
-            finished_consume_work_receiver,
-        );
+        let scheduler =
+            PrioGraphScheduler::new(consume_work_senders, finished_consume_work_receiver);
         (
             scheduler,
             consume_work_receivers,
             finished_consume_work_sender,
@@ -666,9 +665,8 @@ mod tests {
             u64,
         ),
     >,
-    ) -> TransactionStateContainer<MockImmutableDeserializedPacket> {
-        let mut container =
-            TransactionStateContainer::<MockImmutableDeserializedPacket>::with_capacity(10 * 1024);
+    ) -> TransactionStateContainer {
+        let mut container = TransactionStateContainer::with_capacity(10 * 1024);
         for (index, (from_keypair, to_pubkeys, lamports, compute_unit_price)) in
             tx_infos.into_iter().enumerate()
         {
@@ -679,12 +677,7 @@ mod tests {
             lamports,
             compute_unit_price,
         );
-        let packet = Arc::new(
-            MockImmutableDeserializedPacket::new(
-                Packet::from_data(None, transaction.to_versioned_transaction()).unwrap(),
-            )
-            .unwrap(),
-        );
+
         let transaction_ttl = SanitizedTransactionTTL {
             transaction,
             max_age: MaxAge {
@@ -696,7 +689,6 @@ mod tests {
         container.insert_new_transaction(
             id,
             transaction_ttl,
-            packet,
             compute_unit_price,
             TEST_TRANSACTION_COST,
         );
@@ -706,8 +698,8 @@ mod tests {
     }
 
     fn collect_work(
-        receiver: &Receiver<ConsumeWork<MockImmutableDeserializedPacket>>,
-    ) -> (Vec<ConsumeWork<MockImmutableDeserializedPacket>>, Vec<Vec<TransactionId>>) {
+        receiver: &Receiver<SchedulingBatch>,
+    ) -> (Vec<SchedulingBatch>, Vec<Vec<TransactionId>>) {
         receiver
             .try_iter()
             .map(|work| {
                 let ids = work.ids.clone();
                 (work, ids)
@@ -866,8 +858,8 @@ mod tests {
 
         // Complete batch on thread 0. Remaining txs can be scheduled onto thread 1
         finished_work_sender
-            .send(FinishedConsumeWork {
-                work: thread_0_work.into_iter().next().unwrap(),
+            .send(SchedulingBatchResult {
+                batch: thread_0_work.into_iter().next().unwrap(),
                 retryable_indexes: vec![],
             })
             .unwrap();
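With `ConsumeWork`/`FinishedConsumeWork` replaced by `SchedulingBatch`/`SchedulingBatchResult`, the worker-side contract is: drain the per-thread batch channel, execute, and echo the batch back with the indexes that should be retried. A minimal sketch of such a worker loop; the `execute_batch` helper is hypothetical and not part of this patch:

    use crossbeam_channel::{Receiver, Sender};

    use crate::scheduler_messages::{SchedulingBatch, SchedulingBatchResult};

    /// Placeholder executor: a real worker would run `batch.transactions`
    /// and collect the indexes of transactions that failed transiently.
    fn execute_batch(_batch: &SchedulingBatch) -> Vec<usize> {
        Vec::new()
    }

    fn worker_loop(batches: Receiver<SchedulingBatch>, results: Sender<SchedulingBatchResult>) {
        while let Ok(batch) = batches.recv() {
            let retryable_indexes = execute_batch(&batch);
            // Echo the whole batch back so the scheduler can release the
            // account locks it took when scheduling these transactions.
            let receipt = SchedulingBatchResult {
                batch,
                retryable_indexes,
            };
            if results.send(receipt).is_err() {
                break; // the scheduler side hung up
            }
        }
    }
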
diff --git a/scheduler/src/impls/prio_graph_scheduler/thread_aware_account_locks.rs b/scheduler/src/impls/prio_graph_scheduler/thread_aware_account_locks.rs
index d8563c6..6d9d4c2 100644
--- a/scheduler/src/impls/prio_graph_scheduler/thread_aware_account_locks.rs
+++ b/scheduler/src/impls/prio_graph_scheduler/thread_aware_account_locks.rs
@@ -154,9 +154,9 @@ impl ThreadAwareAccountLocks {
         match self.locks.get(account) {
             None => ThreadSet::any(self.num_threads),
             Some(AccountLocks {
-                write_locks: None,
-                read_locks: Some(read_locks),
-            }) => {
+                     write_locks: None,
+                     read_locks: Some(read_locks),
+                 }) => {
                 if WRITE {
                     read_locks
                         .thread_set
@@ -168,13 +168,13 @@ impl ThreadAwareAccountLocks {
                 }
             }
             Some(AccountLocks {
-                write_locks: Some(write_locks),
-                read_locks: None,
-            }) => ThreadSet::only(write_locks.thread_id),
+                     write_locks: Some(write_locks),
+                     read_locks: None,
+                 }) => ThreadSet::only(write_locks.thread_id),
             Some(AccountLocks {
-                write_locks: Some(write_locks),
-                read_locks: Some(read_locks),
-            }) => {
+                     write_locks: Some(write_locks),
+                     read_locks: Some(read_locks),
+                 }) => {
                 assert_eq!(
                     read_locks.thread_set.only_one_contained(),
                     Some(write_locks.thread_id)
@@ -182,9 +182,9 @@ impl ThreadAwareAccountLocks {
                 );
                 read_locks.thread_set
             }
             Some(AccountLocks {
-                write_locks: None,
-                read_locks: None,
-            }) => unreachable!(),
+                     write_locks: None,
+                     read_locks: None,
+                 }) => unreachable!(),
         }
     }
diff --git a/scheduler/src/impls/prio_graph_scheduler/transaction_state.rs b/scheduler/src/impls/prio_graph_scheduler/transaction_state.rs
index be6bdc6..76b4328 100644
--- a/scheduler/src/impls/prio_graph_scheduler/transaction_state.rs
+++ b/scheduler/src/impls/prio_graph_scheduler/transaction_state.rs
@@ -1,5 +1,5 @@
 use crate::scheduler_messages::MaxAge;
-use {solana_sdk::transaction::SanitizedTransaction, std::sync::Arc};
+use solana_sdk::transaction::SanitizedTransaction;
 
 /// Simple wrapper type to tie a sanitized transaction to max age slot.
 #[derive(Clone, Debug)]
@@ -114,11 +114,7 @@ impl TransactionState {
     pub fn transition_to_unprocessed(&mut self, transaction_ttl: SanitizedTransactionTTL) {
         match self.take() {
             TransactionState::Unprocessed { .. } => panic!("already unprocessed"),
-            TransactionState::Pending {
-                transaction_ttl,
-                priority,
-                cost,
-            } => {
+            TransactionState::Pending { priority, cost, .. } => {
                 *self = Self::Unprocessed {
                     transaction_ttl,
                     priority,
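The `transition_to_unprocessed` hunk appears to be a bug fix as much as a cleanup: the removed `Pending` arm re-bound `transaction_ttl` from the matched state, shadowing the method's `transaction_ttl` argument, so the TTL handed back by the caller was silently discarded. Matching with `TransactionState::Pending { priority, cost, .. }` lets the argument's value be the one stored in the rebuilt `Unprocessed` state.
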
diff --git a/scheduler/src/impls/prio_graph_scheduler/transaction_state_container.rs b/scheduler/src/impls/prio_graph_scheduler/transaction_state_container.rs
index 17d80eb..9bbaa1c 100644
--- a/scheduler/src/impls/prio_graph_scheduler/transaction_state_container.rs
+++ b/scheduler/src/impls/prio_graph_scheduler/transaction_state_container.rs
@@ -6,7 +6,7 @@ use {
     crate::scheduler_messages::TransactionId,
     itertools::MinMaxResult,
     min_max_heap::MinMaxHeap,
-    std::{collections::HashMap, sync::Arc},
+    std::collections::HashMap,
 };
 
 /// This structure will hold `TransactionState` for the entirety of a
@@ -47,6 +47,10 @@ impl TransactionStateContainer {
         }
     }
 
+    pub fn len(&self) -> usize {
+        self.priority_queue.len()
+    }
+
     /// Returns true if the queue is empty.
     pub fn is_empty(&self) -> bool {
         self.priority_queue.is_empty()
@@ -160,13 +164,7 @@ mod tests {
     };
 
     /// Returns (transaction_ttl, priority, cost)
-    fn test_transaction(
-        priority: u64,
-    ) -> (
-        SanitizedTransactionTTL,
-        u64,
-        u64,
-    ) {
+    fn test_transaction(priority: u64) -> (SanitizedTransactionTTL, u64, u64) {
         let from_keypair = Keypair::new();
         let ixs = vec![
             system_instruction::transfer(
@@ -193,10 +191,7 @@ mod tests {
         (transaction_ttl, priority, TEST_TRANSACTION_COST)
     }
 
-    fn push_to_container(
-        container: &mut TransactionStateContainer<MockImmutableDeserializedPacket>,
-        num: usize,
-    ) {
+    fn push_to_container(container: &mut TransactionStateContainer, num: usize) {
         for id in 0..num as u64 {
             let priority = id;
             let (transaction_ttl, priority, cost) = test_transaction(priority);
diff --git a/scheduler/src/lib.rs b/scheduler/src/lib.rs
index 0739c71..b2faedf 100644
--- a/scheduler/src/lib.rs
+++ b/scheduler/src/lib.rs
@@ -1,6 +1,6 @@
+pub mod id_generator;
 pub mod impls;
 pub mod scheduler;
 pub mod scheduler_messages;
 pub mod status_slicing;
 pub mod stopwatch;
-pub mod id_generator;
diff --git a/scheduler/src/scheduler.rs b/scheduler/src/scheduler.rs
index a15a7fe..e5bcf95 100644
--- a/scheduler/src/scheduler.rs
+++ b/scheduler/src/scheduler.rs
@@ -1,3 +1,4 @@
+use crate::impls::prio_graph_scheduler::scheduler_error::SchedulerError;
 use crate::scheduler_messages::{SchedulingBatch, SchedulingBatchResult};
 use crossbeam_channel::{Receiver, Sender};
 
@@ -11,13 +12,16 @@ use crossbeam_channel::{Receiver, Sender};
 /// Scheduler -- Task channelK -> [workerK] -> Task finish callback -> Scheduler
 ///              |     ...     |
 ///              -> Task channelN -> [workerN] -> Task finish callback ->
+///
+/// A dedicated scheduler thread should therefore accept the upstream transaction flow
+/// from RPC and drive the scheduler in a loop, calling `schedule_batch` and `receive_complete`.
 pub trait Scheduler {
     fn new(
         schedule_task_senders: Vec<Sender<SchedulingBatch>>,
         task_finished_receivers: Receiver<SchedulingBatchResult>,
     ) -> Self;
 
-    fn schedule_batch(&mut self, txs: SchedulingBatch);
+    fn schedule_batch(&mut self, txs: SchedulingBatch) -> Result<(), SchedulerError>;
 
-    fn receive_complete(&mut self, receipt: SchedulingBatchResult);
+    fn receive_complete(&mut self) -> Result<(), SchedulerError>;
 }
diff --git a/scheduler/src/scheduler_messages.rs b/scheduler/src/scheduler_messages.rs
index 8c2572e..2bcace6 100644
--- a/scheduler/src/scheduler_messages.rs
+++ b/scheduler/src/scheduler_messages.rs
@@ -58,6 +58,13 @@ pub struct SchedulingBatch {
     pub batch_id: TransactionBatchId,
     pub ids: Vec<TransactionId>,
     pub transactions: Vec<SanitizedTransaction>,
+    pub max_ages: Vec<MaxAge>,
+}
+
+impl SchedulingBatch {
+    pub fn valid(&self) -> bool {
+        self.transactions.len() == self.ids.len() && self.ids.len() == self.max_ages.len()
+    }
 }
 
 /// The scheduling result from worker one time.
@@ -71,10 +78,9 @@ pub struct SchedulingBatchResult {
     pub retryable_indexes: Vec<usize>,
 }
 
-
 /// A TTL flag for a transaction.
-#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+#[derive(Copy, Clone, Debug, PartialEq, Eq, Default)]
 pub struct MaxAge {
     pub epoch_invalidation_slot: Slot,
     pub alt_invalidation_slot: Slot,
-}
\ No newline at end of file
+}
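Taken together with the trait changes above, the scheduler thread described in the doc comment reduces to a plain loop over the upstream feed. A sketch under assumed names; the `upstream` receiver and its wiring are illustrative, not part of this patch:

    use crossbeam_channel::Receiver;

    use crate::impls::prio_graph_scheduler::scheduler_error::SchedulerError;
    use crate::scheduler::Scheduler;
    use crate::scheduler_messages::SchedulingBatch;

    fn scheduler_loop<S: Scheduler>(
        mut scheduler: S,
        upstream: Receiver<SchedulingBatch>,
    ) -> Result<(), SchedulerError> {
        while let Ok(batch) = upstream.recv() {
            // `valid()` guards the invariant that ids, transactions and
            // max_ages stay index-aligned before the batch enters scheduling.
            if !batch.valid() {
                continue;
            }
            scheduler.schedule_batch(batch)?;
            // Drain completions so finished batches release their account
            // locks before the next scheduling pass.
            scheduler.receive_complete()?;
        }
        Ok(())
    }
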