Skip to content

Commit

Permalink
feat(torii-core): parallelization (#2423)
Browse files Browse the repository at this point in the history
  • Loading branch information
Larkooo authored Sep 17, 2024
1 parent 9b5cebd commit 91a0fd0
Show file tree
Hide file tree
Showing 11 changed files with 356 additions and 123 deletions.
38 changes: 23 additions & 15 deletions bin/torii/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
//! documentation for usage details. This is **not recommended on Windows**. See [here](https://rust-lang.github.io/rfcs/1974-global-allocators.html#jemalloc)
//! for more info.
use std::cmp;
use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::Arc;
Expand Down Expand Up @@ -125,6 +126,10 @@ struct Args {
/// Polling interval in ms
#[arg(long, default_value = "500")]
polling_interval: u64,

/// Max concurrent tasks
#[arg(long, default_value = "100")]
max_concurrent_tasks: usize,
}

#[tokio::main]
Expand Down Expand Up @@ -157,32 +162,34 @@ async fn main() -> anyhow::Result<()> {
.connect_with(options)
.await?;

if args.database == ":memory:" {
// Disable auto-vacuum
sqlx::query("PRAGMA auto_vacuum = NONE;").execute(&pool).await?;
// Disable auto-vacuum
sqlx::query("PRAGMA auto_vacuum = NONE;").execute(&pool).await?;
sqlx::query("PRAGMA journal_mode = WAL;").execute(&pool).await?;
sqlx::query("PRAGMA synchronous = NORMAL;").execute(&pool).await?;

// Switch DELETE journal mode
sqlx::query("PRAGMA journal_mode=DELETE;").execute(&pool).await?;
}
// Set the number of threads based on CPU count
let cpu_count = std::thread::available_parallelism().unwrap().get();
let thread_count = cmp::min(cpu_count, 8);
sqlx::query(&format!("PRAGMA threads = {};", thread_count)).execute(&pool).await?;

sqlx::migrate!("../../crates/torii/migrations").run(&pool).await?;

let provider: Arc<_> = JsonRpcClient::new(HttpTransport::new(args.rpc)).into();

// Get world address
let world = WorldContractReader::new(args.world_address, &provider);
let world = WorldContractReader::new(args.world_address, provider.clone());

let db = Sql::new(pool.clone(), args.world_address).await?;

let processors = Processors {
event: generate_event_processors_map(vec![
Box::new(RegisterModelProcessor),
Box::new(StoreSetRecordProcessor),
Box::new(MetadataUpdateProcessor),
Box::new(StoreDelRecordProcessor),
Box::new(EventMessageProcessor),
Box::new(StoreUpdateRecordProcessor),
Box::new(StoreUpdateMemberProcessor),
Arc::new(RegisterModelProcessor),
Arc::new(StoreSetRecordProcessor),
Arc::new(MetadataUpdateProcessor),
Arc::new(StoreDelRecordProcessor),
Arc::new(EventMessageProcessor),
Arc::new(StoreUpdateRecordProcessor),
Arc::new(StoreUpdateMemberProcessor),
])?,
transaction: vec![Box::new(StoreTransactionProcessor)],
..Processors::default()
Expand All @@ -193,9 +200,10 @@ async fn main() -> anyhow::Result<()> {
let mut engine = Engine::new(
world,
db.clone(),
&provider,
provider.clone(),
processors,
EngineConfig {
max_concurrent_tasks: args.max_concurrent_tasks,
start_block: args.start_block,
events_chunk_size: args.events_chunk_size,
index_pending: args.index_pending,
Expand Down
5 changes: 5 additions & 0 deletions crates/torii/core/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ impl ModelCache {
Ok(model)
}

pub async fn set(&self, selector: Felt, model: Model) {
let mut cache = self.cache.write().await;
cache.insert(selector, model);
}

pub async fn clear(&self) {
self.cache.write().await.clear();
}
Expand Down
133 changes: 108 additions & 25 deletions crates/torii/core/src/engine.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::collections::{BTreeMap, HashMap};
use std::fmt::Debug;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
use std::time::Duration;

use anyhow::Result;
Expand All @@ -13,6 +15,8 @@ use starknet::core::types::{
use starknet::providers::Provider;
use tokio::sync::broadcast::Sender;
use tokio::sync::mpsc::Sender as BoundedSender;
use tokio::sync::Semaphore;
use tokio::task::JoinSet;
use tokio::time::sleep;
use tracing::{debug, error, info, trace, warn};

Expand All @@ -21,14 +25,14 @@ use crate::processors::{BlockProcessor, EventProcessor, TransactionProcessor};
use crate::sql::Sql;

#[allow(missing_debug_implementations)]
pub struct Processors<P: Provider + Send + Sync + std::fmt::Debug> {
pub struct Processors<P: Provider + Send + Sync + std::fmt::Debug + 'static> {
pub block: Vec<Box<dyn BlockProcessor<P>>>,
pub transaction: Vec<Box<dyn TransactionProcessor<P>>>,
pub event: HashMap<Felt, Box<dyn EventProcessor<P>>>,
pub event: HashMap<Felt, Arc<dyn EventProcessor<P>>>,
pub catch_all_event: Box<dyn EventProcessor<P>>,
}

impl<P: Provider + Send + Sync + std::fmt::Debug> Default for Processors<P> {
impl<P: Provider + Send + Sync + std::fmt::Debug + 'static> Default for Processors<P> {
fn default() -> Self {
Self {
block: vec![],
Expand All @@ -48,6 +52,7 @@ pub struct EngineConfig {
pub start_block: u64,
pub events_chunk_size: u64,
pub index_pending: bool,
pub max_concurrent_tasks: usize,
}

impl Default for EngineConfig {
Expand All @@ -57,6 +62,7 @@ impl Default for EngineConfig {
start_block: 0,
events_chunk_size: 1024,
index_pending: true,
max_concurrent_tasks: 100,
}
}
}
Expand All @@ -83,23 +89,32 @@ pub struct FetchPendingResult {
pub block_number: u64,
}

#[derive(Debug)]
pub struct ParallelizedEvent {
pub block_number: u64,
pub block_timestamp: u64,
pub event_id: String,
pub event: Event,
}

#[allow(missing_debug_implementations)]
pub struct Engine<P: Provider + Send + Sync + std::fmt::Debug> {
world: WorldContractReader<P>,
pub struct Engine<P: Provider + Send + Sync + std::fmt::Debug + 'static> {
world: Arc<WorldContractReader<P>>,
db: Sql,
provider: Box<P>,
processors: Processors<P>,
processors: Arc<Processors<P>>,
config: EngineConfig,
shutdown_tx: Sender<()>,
block_tx: Option<BoundedSender<u64>>,
tasks: HashMap<u64, Vec<ParallelizedEvent>>,
}

struct UnprocessedEvent {
keys: Vec<String>,
data: Vec<String>,
}

impl<P: Provider + Send + Sync + std::fmt::Debug> Engine<P> {
impl<P: Provider + Send + Sync + std::fmt::Debug + 'static> Engine<P> {
pub fn new(
world: WorldContractReader<P>,
db: Sql,
Expand All @@ -109,7 +124,16 @@ impl<P: Provider + Send + Sync + std::fmt::Debug> Engine<P> {
shutdown_tx: Sender<()>,
block_tx: Option<BoundedSender<u64>>,
) -> Self {
Self { world, db, provider: Box::new(provider), processors, config, shutdown_tx, block_tx }
Self {
world: Arc::new(world),
db,
provider: Box::new(provider),
processors: Arc::new(processors),
config,
shutdown_tx,
block_tx,
tasks: HashMap::new(),
}
}

pub async fn start(&mut self) -> Result<()> {
Expand Down Expand Up @@ -397,11 +421,14 @@ impl<P: Provider + Send + Sync + std::fmt::Debug> Engine<P> {
}
}

// Process parallelized events
self.process_tasks().await?;

// Set the head to the last processed pending transaction
// Head block number should still be latest block number
self.db.set_head(data.block_number - 1, last_pending_block_world_tx, last_pending_block_tx);

self.db.execute().await?;

Ok(())
}

Expand Down Expand Up @@ -436,18 +463,55 @@ impl<P: Provider + Send + Sync + std::fmt::Debug> Engine<P> {
}
}

// We return None for the pending_block_tx because our process_range
// gets only specific events from the world. so some transactions
// might get ignored and wont update the cursor.
// so once the sync range is done, we assume all of the tx of the block
// have been processed.
// Process parallelized events
self.process_tasks().await?;

self.db.set_head(data.latest_block_number, None, None);
self.db.execute().await?;

Ok(())
}

async fn process_tasks(&mut self) -> Result<()> {
// We use a semaphore to limit the number of concurrent tasks
let semaphore = Arc::new(Semaphore::new(self.config.max_concurrent_tasks));

// Run all tasks concurrently
let mut set = JoinSet::new();
for (task_id, events) in self.tasks.drain() {
let db = self.db.clone();
let world = self.world.clone();
let processors = self.processors.clone();
let semaphore = semaphore.clone();

set.spawn(async move {
let _permit = semaphore.acquire().await.unwrap();
let mut local_db = db.clone();
for ParallelizedEvent { event_id, event, block_number, block_timestamp } in events {
if let Some(processor) = processors.event.get(&event.keys[0]) {
debug!(target: LOG_TARGET, event_name = processor.event_key(), task_id = %task_id, "Processing parallelized event.");

if let Err(e) = processor
.process(&world, &mut local_db, block_number, block_timestamp, &event_id, &event)
.await
{
error!(target: LOG_TARGET, event_name = processor.event_key(), error = %e, task_id = %task_id, "Processing parallelized event.");
}
}
}
Ok::<_, anyhow::Error>(local_db)
});
}

// Join all tasks
while let Some(result) = set.join_next().await {
let local_db = result??;
self.db.merge(local_db)?;
}

Ok(())
}

async fn get_block_timestamp(&self, block_number: u64) -> Result<u64> {
match self.provider.get_block_with_tx_hashes(BlockId::Number(block_number)).await? {
MaybePendingBlockWithTxHashes::Block(block) => Ok(block.timestamp),
Expand Down Expand Up @@ -477,7 +541,7 @@ impl<P: Provider + Send + Sync + std::fmt::Debug> Engine<P> {
block_timestamp,
&event_id,
&event,
transaction_hash,
// transaction_hash,
)
.await?;
}
Expand Down Expand Up @@ -527,7 +591,7 @@ impl<P: Provider + Send + Sync + std::fmt::Debug> Engine<P> {
block_timestamp,
&event_id,
event,
*transaction_hash,
// *transaction_hash,
)
.await?;
}
Expand Down Expand Up @@ -587,9 +651,9 @@ impl<P: Provider + Send + Sync + std::fmt::Debug> Engine<P> {
block_timestamp: u64,
event_id: &str,
event: &Event,
transaction_hash: Felt,
// transaction_hash: Felt,
) -> Result<()> {
self.db.store_event(event_id, event, transaction_hash, block_timestamp);
// self.db.store_event(event_id, event, transaction_hash, block_timestamp);
let event_key = event.keys[0];

let Some(processor) = self.processors.event.get(&event_key) else {
Expand Down Expand Up @@ -627,14 +691,33 @@ impl<P: Provider + Send + Sync + std::fmt::Debug> Engine<P> {
return Ok(());
};

// if processor.validate(event) {
if let Err(e) = processor
.process(&self.world, &mut self.db, block_number, block_timestamp, event_id, event)
.await
{
error!(target: LOG_TARGET, event_name = processor.event_key(), error = %e, "Processing event.");
let task_identifier = match processor.event_key().as_str() {
"StoreSetRecord" | "StoreUpdateRecord" | "StoreUpdateMember" | "StoreDelRecord" => {
let mut hasher = DefaultHasher::new();
event.data[0].hash(&mut hasher);
event.data[1].hash(&mut hasher);
hasher.finish()
}
_ => 0,
};

// if we have a task identifier, we queue the event to be parallelized
if task_identifier != 0 {
self.tasks.entry(task_identifier).or_default().push(ParallelizedEvent {
event_id: event_id.to_string(),
event: event.clone(),
block_number,
block_timestamp,
});
} else {
// if we dont have a task identifier, we process the event immediately
if let Err(e) = processor
.process(&self.world, &mut self.db, block_number, block_timestamp, event_id, event)
.await
{
error!(target: LOG_TARGET, event_name = processor.event_key(), error = %e, "Processing event.");
}
}
// }

Ok(())
}
Expand Down
11 changes: 6 additions & 5 deletions crates/torii/core/src/processors/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::sync::Arc;

use anyhow::{Error, Result};
use async_trait::async_trait;
Expand All @@ -23,7 +24,7 @@ const ENTITY_ID_INDEX: usize = 1;
const NUM_KEYS_INDEX: usize = 2;

#[async_trait]
pub trait EventProcessor<P>
pub trait EventProcessor<P>: Send + Sync
where
P: Provider + Sync,
{
Expand All @@ -48,7 +49,7 @@ where
}

#[async_trait]
pub trait BlockProcessor<P: Provider + Sync> {
pub trait BlockProcessor<P: Provider + Sync>: Send + Sync {
fn get_block_number(&self) -> String;
async fn process(
&self,
Expand All @@ -60,7 +61,7 @@ pub trait BlockProcessor<P: Provider + Sync> {
}

#[async_trait]
pub trait TransactionProcessor<P: Provider + Sync> {
pub trait TransactionProcessor<P: Provider + Sync>: Send + Sync {
#[allow(clippy::too_many_arguments)]
async fn process(
&self,
Expand All @@ -75,8 +76,8 @@ pub trait TransactionProcessor<P: Provider + Sync> {

/// Given a list of event processors, generate a map of event keys to the event processor
pub fn generate_event_processors_map<P: Provider + Sync + Send>(
event_processor: Vec<Box<dyn EventProcessor<P>>>,
) -> Result<HashMap<Felt, Box<dyn EventProcessor<P>>>> {
event_processor: Vec<Arc<dyn EventProcessor<P>>>,
) -> Result<HashMap<Felt, Arc<dyn EventProcessor<P>>>> {
let mut event_processors = HashMap::new();

for processor in event_processor {
Expand Down
Loading

0 comments on commit 91a0fd0

Please sign in to comment.