Skip to content

Commit

Permalink
feat: prio-graph crate remove dependency of solana-core
Browse files Browse the repository at this point in the history
  • Loading branch information
flame4 committed Oct 14, 2024
1 parent 32ce898 commit d554cec
Show file tree
Hide file tree
Showing 8 changed files with 507 additions and 75 deletions.
9 changes: 8 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 16 additions & 1 deletion prio-graph-scheduler/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ license.workspace = true
edition.workspace = true

[dependencies]
solana-core = { workspace = true }
solana-sdk = { workspace = true }
solana-poh = { workspace = true }
solana-metrics = { workspace = true }
Expand All @@ -31,9 +30,25 @@ min-max-heap = { workspace = true }

[dev-dependencies]
assert_matches = { workspace = true }
solana-compute-budget = { workspace = true }
solana-perf = { workspace = true }
solana-runtime-transaction = { workspace = true }
solana-sanitize = { workspace = true }
solana-short-vec = { workspace = true }
solana-svm-transaction = { workspace = true }
# let dev-context-only-utils works when running this crate.
solana-prio-graph-scheduler = { path = ".", features = [
"dev-context-only-utils",
] }
solana-sdk = { workspace = true, features = ["dev-context-only-utils"] }

bincode = { workspace = true }

[package.metadata.docs.rs]
targets = ["x86_64-unknown-linux-gnu"]

[lints]
workspace = true

[features]
dev-context-only-utils = ["solana-runtime/dev-context-only-utils"]
44 changes: 0 additions & 44 deletions prio-graph-scheduler/src/deserializable_packet.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::collections::HashSet;
use solana_core::banking_stage::immutable_deserialized_packet::{DeserializedPacketError, ImmutableDeserializedPacket};
use solana_sdk::hash::Hash;
use solana_sdk::message::AddressLoader;
use solana_sdk::packet::Packet;
Expand Down Expand Up @@ -37,47 +36,4 @@ pub trait DeserializableTxPacket: PartialEq + PartialOrd + Eq + Sized {
fn compute_unit_price(&self) -> u64;

fn compute_unit_limit(&self) -> u64;
}


/// TODO: migrate to solana_core
impl DeserializableTxPacket for ImmutableDeserializedPacket {
type DeserializeError = DeserializedPacketError;

fn from_packet(packet: Packet) -> Result<Self, Self::DeserializeError> {
ImmutableDeserializedPacket::new(packet)
}

fn build_sanitized_transaction(
&self,
votes_only: bool,
address_loader: impl AddressLoader,
reserved_account_keys: &HashSet<Pubkey>,
) -> Option<SanitizedTransaction> {
self.build_sanitized_transaction(votes_only, address_loader, reserved_account_keys)
}

fn original_packet(&self) -> &Packet {
&self.original_packet
}

fn transaction(&self) -> &SanitizedVersionedTransaction {
&self.transaction
}

fn message_hash(&self) -> &Hash {
&self.message_hash
}

fn is_simple_vote(&self) -> bool {
self.is_simple_vote
}

fn compute_unit_price(&self) -> u64 {
self.compute_unit_price
}

fn compute_unit_limit(&self) -> u64 {
u64::from(self.compute_unit_limit)
}
}
186 changes: 179 additions & 7 deletions prio-graph-scheduler/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,192 @@
//! Solana Priority Graph Scheduler.
pub mod transaction_state;
pub mod scheduler_messages;
pub mod id_generator;
pub mod in_flight_tracker;
pub mod thread_aware_account_locks;
pub mod transaction_priority_id;
pub mod scheduler_error;
pub mod scheduler_messages;
pub mod scheduler_metrics;
pub mod thread_aware_account_locks;
pub mod transaction_priority_id;
pub mod transaction_state;
// pub mod scheduler_controller;
pub mod transaction_state_container;
pub mod prio_graph_scheduler;
pub mod deserializable_packet;
pub mod prio_graph_scheduler;
pub mod transaction_state_container;

#[macro_use]
extern crate solana_metrics;

#[cfg(test)]
#[macro_use]
extern crate assert_matches;
extern crate assert_matches;

/// Consumer will create chunks of transactions from buffer with up to this size.
pub const TARGET_NUM_TRANSACTIONS_PER_BATCH: usize = 64;

mod read_write_account_set;

#[cfg(test)]
mod tests {
use {
crate::deserializable_packet::DeserializableTxPacket,
solana_compute_budget::compute_budget_limits::ComputeBudgetLimits,
solana_perf::packet::Packet,
solana_runtime_transaction::instructions_processor::process_compute_budget_instructions,
solana_sanitize::SanitizeError,
solana_sdk::{
hash::Hash,
message::Message,
pubkey::Pubkey,
signature::Signature,
transaction::{
AddressLoader, SanitizedTransaction, SanitizedVersionedTransaction,
VersionedTransaction,
},
},
solana_short_vec::decode_shortu16_len,
solana_svm_transaction::instruction::SVMInstruction,
std::{cmp::Ordering, collections::HashSet, mem::size_of},
thiserror::Error,
};

#[derive(Debug, Error)]
pub enum MockDeserializedPacketError {
#[error("ShortVec Failed to Deserialize")]
// short_vec::decode_shortu16_len() currently returns () on error
ShortVecError(()),
#[error("Deserialization Error: {0}")]
DeserializationError(#[from] bincode::Error),
#[error("overflowed on signature size {0}")]
SignatureOverflowed(usize),
#[error("packet failed sanitization {0}")]
SanitizeError(#[from] SanitizeError),
#[error("transaction failed prioritization")]
PrioritizationFailure,
}

#[derive(Debug, Eq)]
pub struct MockImmutableDeserializedPacket {
pub original_packet: Packet,
pub transaction: SanitizedVersionedTransaction,
pub message_hash: Hash,
pub is_simple_vote: bool,
pub compute_unit_price: u64,
pub compute_unit_limit: u32,
}

impl DeserializableTxPacket for MockImmutableDeserializedPacket {
type DeserializeError = MockDeserializedPacketError;
fn from_packet(packet: Packet) -> Result<Self, Self::DeserializeError> {
let versioned_transaction: VersionedTransaction = packet.deserialize_slice(..)?;
let sanitized_transaction =
SanitizedVersionedTransaction::try_from(versioned_transaction)?;
let message_bytes = packet_message(&packet)?;
let message_hash = Message::hash_raw_message(message_bytes);
let is_simple_vote = packet.meta().is_simple_vote_tx();

// drop transaction if prioritization fails.
let ComputeBudgetLimits {
mut compute_unit_price,
compute_unit_limit,
..
} = process_compute_budget_instructions(
sanitized_transaction
.get_message()
.program_instructions_iter()
.map(|(pubkey, ix)| (pubkey, SVMInstruction::from(ix))),
)
.map_err(|_| MockDeserializedPacketError::PrioritizationFailure)?;

// set compute unit price to zero for vote transactions
if is_simple_vote {
compute_unit_price = 0;
};

Ok(Self {
original_packet: packet,
transaction: sanitized_transaction,
message_hash,
is_simple_vote,
compute_unit_price,
compute_unit_limit,
})
}

fn original_packet(&self) -> &Packet {
&self.original_packet
}

fn transaction(&self) -> &SanitizedVersionedTransaction {
&self.transaction
}

fn message_hash(&self) -> &Hash {
&self.message_hash
}

fn is_simple_vote(&self) -> bool {
self.is_simple_vote
}

fn compute_unit_price(&self) -> u64 {
self.compute_unit_price
}

fn compute_unit_limit(&self) -> u64 {
u64::from(self.compute_unit_limit)
}

// This function deserializes packets into transactions, computes the blake3 hash of transaction
// messages.
fn build_sanitized_transaction(
&self,
votes_only: bool,
address_loader: impl AddressLoader,
reserved_account_keys: &HashSet<Pubkey>,
) -> Option<SanitizedTransaction> {
if votes_only && !self.is_simple_vote() {
return None;
}
let tx = SanitizedTransaction::try_new(
self.transaction().clone(),
*self.message_hash(),
self.is_simple_vote(),
address_loader,
reserved_account_keys,
)
.ok()?;
Some(tx)
}
}

// PartialEq MUST be consistent with PartialOrd and Ord
impl PartialEq for MockImmutableDeserializedPacket {
fn eq(&self, other: &Self) -> bool {
self.compute_unit_price() == other.compute_unit_price()
}
}

impl PartialOrd for MockImmutableDeserializedPacket {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

impl Ord for MockImmutableDeserializedPacket {
fn cmp(&self, other: &Self) -> Ordering {
self.compute_unit_price().cmp(&other.compute_unit_price())
}
}

/// Read the transaction message from packet data
fn packet_message(packet: &Packet) -> Result<&[u8], MockDeserializedPacketError> {
let (sig_len, sig_size) = packet
.data(..)
.and_then(|bytes| decode_shortu16_len(bytes).ok())
.ok_or(MockDeserializedPacketError::ShortVecError(()))?;
sig_len
.checked_mul(size_of::<Signature>())
.and_then(|v| v.checked_add(sig_size))
.and_then(|msg_start| packet.data(msg_start..))
.ok_or(MockDeserializedPacketError::SignatureOverflowed(sig_size))
}
}
21 changes: 9 additions & 12 deletions prio-graph-scheduler/src/prio_graph_scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,18 @@ use {
crate::{
deserializable_packet::DeserializableTxPacket,
in_flight_tracker::InFlightTracker,
read_write_account_set::ReadWriteAccountSet,
scheduler_error::SchedulerError,
scheduler_messages::{ConsumeWork, FinishedConsumeWork, TransactionBatchId, TransactionId},
thread_aware_account_locks::{ThreadAwareAccountLocks, ThreadId, ThreadSet},
transaction_priority_id::TransactionPriorityId,
transaction_state::{SanitizedTransactionTTL, TransactionState},
transaction_state_container::TransactionStateContainer,
TARGET_NUM_TRANSACTIONS_PER_BATCH,
},
crossbeam_channel::{Receiver, Sender, TryRecvError},
itertools::izip,
prio_graph::{AccessKind, PrioGraph},
solana_core::banking_stage::{
consumer::TARGET_NUM_TRANSACTIONS_PER_BATCH, read_write_account_set::ReadWriteAccountSet,
},
solana_cost_model::block_cost_limits::MAX_BLOCK_UNITS,
solana_measure::measure_us,
solana_sdk::{
Expand Down Expand Up @@ -592,12 +591,10 @@ fn try_schedule_transaction<P: DeserializableTxPacket>(
mod tests {
use {
super::*,
crate::tests::MockImmutableDeserializedPacket,
crate::TARGET_NUM_TRANSACTIONS_PER_BATCH,
crossbeam_channel::{unbounded, Receiver},
itertools::Itertools,
solana_core::banking_stage::{
consumer::TARGET_NUM_TRANSACTIONS_PER_BATCH,
immutable_deserialized_packet::ImmutableDeserializedPacket,
},
solana_sdk::{
compute_budget::ComputeBudgetInstruction, hash::Hash, message::Message, packet::Packet,
pubkey::Pubkey, signature::Keypair, signer::Signer, system_instruction,
Expand All @@ -621,14 +618,14 @@ mod tests {
fn create_test_frame(
num_threads: usize,
) -> (
PrioGraphScheduler<ImmutableDeserializedPacket>,
PrioGraphScheduler<MockImmutableDeserializedPacket>,
Vec<Receiver<ConsumeWork>>,
Sender<FinishedConsumeWork>,
) {
let (consume_work_senders, consume_work_receivers) =
(0..num_threads).map(|_| unbounded()).unzip();
let (finished_consume_work_sender, finished_consume_work_receiver) = unbounded();
let scheduler = PrioGraphScheduler::<ImmutableDeserializedPacket>::new(
let scheduler = PrioGraphScheduler::<MockImmutableDeserializedPacket>::new(
consume_work_senders,
finished_consume_work_receiver,
);
Expand Down Expand Up @@ -668,9 +665,9 @@ mod tests {
u64,
),
>,
) -> TransactionStateContainer<ImmutableDeserializedPacket> {
) -> TransactionStateContainer<MockImmutableDeserializedPacket> {
let mut container =
TransactionStateContainer::<ImmutableDeserializedPacket>::with_capacity(10 * 1024);
TransactionStateContainer::<MockImmutableDeserializedPacket>::with_capacity(10 * 1024);
for (index, (from_keypair, to_pubkeys, lamports, compute_unit_price)) in
tx_infos.into_iter().enumerate()
{
Expand All @@ -682,7 +679,7 @@ mod tests {
compute_unit_price,
);
let packet = Arc::new(
ImmutableDeserializedPacket::new(
MockImmutableDeserializedPacket::from_packet(
Packet::from_data(None, transaction.to_versioned_transaction()).unwrap(),
)
.unwrap(),
Expand Down
Loading

0 comments on commit d554cec

Please sign in to comment.