diff --git a/necsim/partitioning/core/src/lib.rs b/necsim/partitioning/core/src/lib.rs index 6fb9537df..54a934d29 100644 --- a/necsim/partitioning/core/src/lib.rs +++ b/necsim/partitioning/core/src/lib.rs @@ -29,15 +29,24 @@ pub trait Partitioning: Sized { fn get_size(&self) -> PartitionSize; - fn with_local_partition, A: Send + Clone, Q>( + fn with_local_partition< + R: Reporter, + P: ReporterContext, + A: Data, + Q: Data + serde::Serialize + serde::de::DeserializeOwned, + >( self, reporter_context: P, auxiliary: Self::Auxiliary, args: A, inner: for<'p> fn(Self::LocalPartition<'p, R>, A) -> Q, + fold: fn(Q, Q) -> Q, ) -> anyhow::Result; } +pub trait Data: Send + Clone {} +impl Data for T {} + #[derive(Copy, Clone)] pub enum MigrationMode { Force, diff --git a/necsim/partitioning/monolithic/src/lib.rs b/necsim/partitioning/monolithic/src/lib.rs index a7c360c97..1a29426b5 100644 --- a/necsim/partitioning/monolithic/src/lib.rs +++ b/necsim/partitioning/monolithic/src/lib.rs @@ -62,12 +62,13 @@ impl Partitioning for MonolithicPartitioning { /// # Errors /// /// Returns an error if the provided event log is not empty. - fn with_local_partition, A: Send + Clone, Q>( + fn with_local_partition, A, Q>( self, reporter_context: P, event_log: Self::Auxiliary, args: A, inner: for<'p> fn(Self::LocalPartition<'p, R>, A) -> Q, + _fold: fn(Q, Q) -> Q, ) -> anyhow::Result { let local_partition = if let Some(event_log) = event_log { MonolithicLocalPartition::Recorded(Box::new( diff --git a/necsim/partitioning/mpi/src/lib.rs b/necsim/partitioning/mpi/src/lib.rs index 19188de9f..f1e00c1b3 100644 --- a/necsim/partitioning/mpi/src/lib.rs +++ b/necsim/partitioning/mpi/src/lib.rs @@ -20,7 +20,9 @@ use thiserror::Error; use necsim_core::{lineage::MigratingLineage, reporter::Reporter}; use necsim_impls_std::event_log::recorder::EventLogRecorder; -use necsim_partitioning_core::{context::ReporterContext, partition::PartitionSize, Partitioning}; +use necsim_partitioning_core::{ + context::ReporterContext, partition::PartitionSize, Data, Partitioning, +}; mod partition; mod request; @@ -169,12 +171,19 @@ impl Partitioning for MpiPartitioning { /// Returns `MissingEventLog` if the local partition is non-monolithic and /// the `event_log` is `None`. /// Returns `InvalidEventSubLog` if creating a sub-`event_log` failed. - fn with_local_partition, A: Send + Clone, Q>( + fn with_local_partition< + R: Reporter, + P: ReporterContext, + A: Data, + Q: Data + serde::Serialize + serde::de::DeserializeOwned, + >( self, reporter_context: P, event_log: Self::Auxiliary, args: A, inner: for<'p> fn(Self::LocalPartition<'p, R>, A) -> Q, + // TODO: use fold to return the same result in all partitions, then deprecate + _fold: fn(Q, Q) -> Q, ) -> anyhow::Result { let Some(event_log) = event_log else { anyhow::bail!(MpiLocalPartitionError::MissingEventLog) diff --git a/necsim/partitioning/threads/src/lib.rs b/necsim/partitioning/threads/src/lib.rs index 78ab39825..c8b67037e 100644 --- a/necsim/partitioning/threads/src/lib.rs +++ b/necsim/partitioning/threads/src/lib.rs @@ -23,7 +23,9 @@ use necsim_core::reporter::{ }; use necsim_impls_std::event_log::recorder::EventLogRecorder; -use necsim_partitioning_core::{context::ReporterContext, partition::PartitionSize, Partitioning}; +use necsim_partitioning_core::{ + context::ReporterContext, partition::PartitionSize, Data, Partitioning, +}; mod partition; mod vote; @@ -127,17 +129,24 @@ impl Partitioning for ThreadsPartitioning { self.size } + #[allow(clippy::too_many_lines)] /// # Errors /// /// Returns `MissingEventLog` if the local partition is non-monolithic and /// the `event_log` is `None`. /// Returns `InvalidEventSubLog` if creating a sub-`event_log` failed. - fn with_local_partition, A: Send + Clone, Q>( + fn with_local_partition< + R: Reporter, + P: ReporterContext, + A: Data, + Q: Data + serde::Serialize + serde::de::DeserializeOwned, + >( self, reporter_context: P, event_log: Self::Auxiliary, args: A, inner: for<'p> fn(Self::LocalPartition<'p, R>, A) -> Q, + fold: fn(Q, Q) -> Q, ) -> anyhow::Result { // TODO: add support for multithread live reporting let Some(event_log) = event_log else { @@ -188,7 +197,7 @@ impl Partitioning for ThreadsPartitioning { .map(|_| args.clone()) .collect::>(); - std::thread::scope(|scope| { + let result = std::thread::scope(|scope| { let vote_any = &vote_any; let vote_min_time = &vote_min_time; let vote_time_steps = &vote_time_steps; @@ -196,36 +205,36 @@ impl Partitioning for ThreadsPartitioning { let emigration_channels = emigration_channels.as_slice(); let sync_barrier = &sync_barrier; - for ((((partition, immigration_channel), event_log), progress_channel), args) in self + let thread_handles = self .size .partitions() .zip(immigration_channels) .zip(event_logs) .zip(progress_channels) .zip(args) - { - let thread_handle = scope.spawn::<_, ()>(move || { - let local_partition = ThreadsLocalPartition::::new( - partition, - vote_any, - vote_min_time, - vote_time_steps, - vote_termination, - emigration_channels, - immigration_channel, - self.migration_interval, - event_log, - progress_channel, - self.progress_interval, - sync_barrier, - ); - - let _result = inner(local_partition, args); - }); - - // we don't need the thread result and implicitly propagate thread panics - std::mem::drop(thread_handle); - } + .map( + |((((partition, immigration_channel), event_log), progress_channel), args)| { + scope.spawn(move || { + let local_partition = ThreadsLocalPartition::::new( + partition, + vote_any, + vote_min_time, + vote_time_steps, + vote_termination, + emigration_channels, + immigration_channel, + self.migration_interval, + event_log, + progress_channel, + self.progress_interval, + sync_barrier, + ); + + inner(local_partition, args) + }) + }, + ) + .collect::>(); let mut progress_remaining = vec![0; self.size.get() as usize].into_boxed_slice(); for (remaining, rank) in progress_receiver { @@ -239,9 +248,22 @@ impl Partitioning for ThreadsPartitioning { .into(), ); } + + let mut folded_result = None; + for handle in thread_handles { + let result = match handle.join() { + Ok(result) => result, + Err(payload) => std::panic::resume_unwind(payload), + }; + folded_result = Some(match folded_result.take() { + Some(acc) => fold(acc, result), + None => result, + }); + } + folded_result.expect("at least one thread partitioning result") }); - todo!() + Ok(result) } } diff --git a/rustcoalescence/algorithms/src/result.rs b/rustcoalescence/algorithms/src/result.rs index 6fa2d4b4d..b009fb4f4 100644 --- a/rustcoalescence/algorithms/src/result.rs +++ b/rustcoalescence/algorithms/src/result.rs @@ -8,6 +8,8 @@ use necsim_core_bond::NonNegativeF64; use necsim_impls_no_std::cogs::active_lineage_sampler::resuming::lineage::ExceptionalLineage; +#[derive(Clone, serde::Serialize, serde::Deserialize)] +#[serde(bound = "")] pub enum SimulationOutcome> { Done { time: NonNegativeF64, @@ -18,6 +20,7 @@ pub enum SimulationOutcome> { steps: u64, lineages: Vec, rng: G, + #[serde(skip)] marker: PhantomData, }, } diff --git a/rustcoalescence/src/cli/simulate/dispatch/valid/partitioning.rs b/rustcoalescence/src/cli/simulate/dispatch/valid/partitioning.rs index 6bdfc5b6b..b29c66451 100644 --- a/rustcoalescence/src/cli/simulate/dispatch/valid/partitioning.rs +++ b/rustcoalescence/src/cli/simulate/dispatch/valid/partitioning.rs @@ -4,12 +4,12 @@ use necsim_core::{ }; use necsim_core_bond::NonNegativeF64; use necsim_impls_std::event_log::recorder::EventLogRecorder; -use necsim_partitioning_core::{context::ReporterContext, Partitioning as _}; +use necsim_partitioning_core::{context::ReporterContext, LocalPartition, Partitioning as _}; use necsim_partitioning_monolithic::MonolithicLocalPartition; #[cfg(feature = "necsim-partitioning-mpi")] use necsim_partitioning_mpi::MpiLocalPartition; -use rustcoalescence_algorithms::{result::SimulationOutcome, AlgorithmDispatch}; +use rustcoalescence_algorithms::{result::SimulationOutcome, Algorithm, AlgorithmDispatch}; use rustcoalescence_scenarios::Scenario; use crate::args::config::{partitioning::Partitioning, sample::Sample}; @@ -58,7 +58,7 @@ where |partition, (sample, rng, scenario, algorithm_args, pause_before, normalised_args)| { match partition { MonolithicLocalPartition::Live(partition) => { - info::dispatch::, O, R, _>( + wrap::, O, R, _>( *partition, sample, rng, @@ -69,7 +69,7 @@ where ) }, MonolithicLocalPartition::Recorded(partition) => { - info::dispatch::, O, R, _>( + wrap::, O, R, _>( *partition, sample, rng, @@ -81,6 +81,7 @@ where }, } }, + fold, ), #[cfg(feature = "necsim-partitioning-mpi")] Partitioning::Mpi(partitioning) => partitioning.with_local_partition( @@ -90,7 +91,7 @@ where |partition, (sample, rng, scenario, algorithm_args, pause_before, normalised_args)| { match partition { MpiLocalPartition::Root(partition) => { - info::dispatch::, O, R, _>( + wrap::, O, R, _>( *partition, sample, rng, @@ -101,7 +102,7 @@ where ) }, MpiLocalPartition::Parallel(partition) => { - info::dispatch::, O, R, _>( + wrap::, O, R, _>( *partition, sample, rng, @@ -113,7 +114,71 @@ where }, } }, + fold, ), } - .flatten() + .and_then(|result| result.map_err(anyhow::Error::msg)) +} + +fn wrap< + 'p, + M: MathsCore, + G: RngCore, + A: Algorithm<'p, M, G, O, R, P>, + O: Scenario, + R: Reporter, + P: LocalPartition<'p, R>, +>( + local_partition: P, + + sample: Sample, + rng: G, + scenario: O, + algorithm_args: A::Arguments, + pause_before: Option, + + normalised_args: &BufferingSimulateArgsBuilder, +) -> Result, String> +where + Result, A::Error>: anyhow::Context, A::Error>, +{ + info::dispatch::, O, R, _>( + local_partition, + sample, + rng, + scenario, + algorithm_args, + pause_before, + normalised_args, + ) + .map_err(|err| format!("{err:?}")) +} + +fn fold>( + a: Result, String>, + b: Result, String>, +) -> Result, String> { + match (a, b) { + ( + Ok(SimulationOutcome::Done { + time: time_a, + steps: steps_a, + }), + Ok(SimulationOutcome::Done { + time: time_b, + steps: steps_b, + }), + ) => Ok(SimulationOutcome::Done { + time: time_a.max(time_b), + steps: steps_a + steps_b, + }), + (Ok(SimulationOutcome::Paused { .. }), _) | (_, Ok(SimulationOutcome::Paused { .. })) => { + Err(String::from( + "parallel pausing is not yet supported, catching this at simulation completion is \ + a bug", + )) + }, + (Err(err), Ok(_)) | (Ok(_), Err(err)) => Err(err), + (Err(err_a), Err(err_b)) => Err(format!("{err_a}\n|\n{err_b}")), + } }