Skip to content

Commit

Permalink
Add minimal partitioning folding support
Browse files Browse the repository at this point in the history
  • Loading branch information
juntyr committed May 26, 2024
1 parent 7a9568b commit 74a85f5
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 39 deletions.
11 changes: 10 additions & 1 deletion necsim/partitioning/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,24 @@ pub trait Partitioning: Sized {

fn get_size(&self) -> PartitionSize;

fn with_local_partition<R: Reporter, P: ReporterContext<Reporter = R>, A: Send + Clone, Q>(
fn with_local_partition<
R: Reporter,
P: ReporterContext<Reporter = R>,
A: Data,
Q: Data + serde::Serialize + serde::de::DeserializeOwned,
>(
self,
reporter_context: P,
auxiliary: Self::Auxiliary,
args: A,
inner: for<'p> fn(Self::LocalPartition<'p, R>, A) -> Q,
fold: fn(Q, Q) -> Q,
) -> anyhow::Result<Q>;
}

pub trait Data: Send + Clone {}
impl<T: Send + Clone> Data for T {}

#[derive(Copy, Clone)]
pub enum MigrationMode {
Force,
Expand Down
3 changes: 2 additions & 1 deletion necsim/partitioning/monolithic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,13 @@ impl Partitioning for MonolithicPartitioning {
/// # Errors
///
/// Returns an error if the provided event log is not empty.
fn with_local_partition<R: Reporter, P: ReporterContext<Reporter = R>, A: Send + Clone, Q>(
fn with_local_partition<R: Reporter, P: ReporterContext<Reporter = R>, A, Q>(
self,
reporter_context: P,
event_log: Self::Auxiliary,
args: A,
inner: for<'p> fn(Self::LocalPartition<'p, R>, A) -> Q,
_fold: fn(Q, Q) -> Q,
) -> anyhow::Result<Q> {
let local_partition = if let Some(event_log) = event_log {
MonolithicLocalPartition::Recorded(Box::new(
Expand Down
13 changes: 11 additions & 2 deletions necsim/partitioning/mpi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ use thiserror::Error;
use necsim_core::{lineage::MigratingLineage, reporter::Reporter};

use necsim_impls_std::event_log::recorder::EventLogRecorder;
use necsim_partitioning_core::{context::ReporterContext, partition::PartitionSize, Partitioning};
use necsim_partitioning_core::{
context::ReporterContext, partition::PartitionSize, Data, Partitioning,
};

mod partition;
mod request;
Expand Down Expand Up @@ -169,12 +171,19 @@ impl Partitioning for MpiPartitioning {
/// Returns `MissingEventLog` if the local partition is non-monolithic and
/// the `event_log` is `None`.
/// Returns `InvalidEventSubLog` if creating a sub-`event_log` failed.
fn with_local_partition<R: Reporter, P: ReporterContext<Reporter = R>, A: Send + Clone, Q>(
fn with_local_partition<
R: Reporter,
P: ReporterContext<Reporter = R>,
A: Data,
Q: Data + serde::Serialize + serde::de::DeserializeOwned,
>(
self,
reporter_context: P,
event_log: Self::Auxiliary,
args: A,
inner: for<'p> fn(Self::LocalPartition<'p, R>, A) -> Q,
// TODO: use fold to return the same result in all partitions, then deprecate
_fold: fn(Q, Q) -> Q,
) -> anyhow::Result<Q> {
let Some(event_log) = event_log else {
anyhow::bail!(MpiLocalPartitionError::MissingEventLog)
Expand Down
78 changes: 50 additions & 28 deletions necsim/partitioning/threads/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ use necsim_core::reporter::{
};

use necsim_impls_std::event_log::recorder::EventLogRecorder;
use necsim_partitioning_core::{context::ReporterContext, partition::PartitionSize, Partitioning};
use necsim_partitioning_core::{
context::ReporterContext, partition::PartitionSize, Data, Partitioning,
};

mod partition;
mod vote;
Expand Down Expand Up @@ -127,17 +129,24 @@ impl Partitioning for ThreadsPartitioning {
self.size
}

#[allow(clippy::too_many_lines)]
/// # Errors
///
/// Returns `MissingEventLog` if the local partition is non-monolithic and
/// the `event_log` is `None`.
/// Returns `InvalidEventSubLog` if creating a sub-`event_log` failed.
fn with_local_partition<R: Reporter, P: ReporterContext<Reporter = R>, A: Send + Clone, Q>(
fn with_local_partition<
R: Reporter,
P: ReporterContext<Reporter = R>,
A: Data,
Q: Data + serde::Serialize + serde::de::DeserializeOwned,
>(
self,
reporter_context: P,
event_log: Self::Auxiliary,
args: A,
inner: for<'p> fn(Self::LocalPartition<'p, R>, A) -> Q,
fold: fn(Q, Q) -> Q,
) -> anyhow::Result<Q> {
// TODO: add support for multithread live reporting
let Some(event_log) = event_log else {
Expand Down Expand Up @@ -188,44 +197,44 @@ impl Partitioning for ThreadsPartitioning {
.map(|_| args.clone())
.collect::<Vec<_>>();

std::thread::scope(|scope| {
let result = std::thread::scope(|scope| {
let vote_any = &vote_any;
let vote_min_time = &vote_min_time;
let vote_time_steps = &vote_time_steps;
let vote_termination = &vote_termination;
let emigration_channels = emigration_channels.as_slice();
let sync_barrier = &sync_barrier;

for ((((partition, immigration_channel), event_log), progress_channel), args) in self
let thread_handles = self
.size
.partitions()
.zip(immigration_channels)
.zip(event_logs)
.zip(progress_channels)
.zip(args)
{
let thread_handle = scope.spawn::<_, ()>(move || {
let local_partition = ThreadsLocalPartition::<R>::new(
partition,
vote_any,
vote_min_time,
vote_time_steps,
vote_termination,
emigration_channels,
immigration_channel,
self.migration_interval,
event_log,
progress_channel,
self.progress_interval,
sync_barrier,
);

let _result = inner(local_partition, args);
});

// we don't need the thread result and implicitly propagate thread panics
std::mem::drop(thread_handle);
}
.map(
|((((partition, immigration_channel), event_log), progress_channel), args)| {
scope.spawn(move || {
let local_partition = ThreadsLocalPartition::<R>::new(
partition,
vote_any,
vote_min_time,
vote_time_steps,
vote_termination,
emigration_channels,
immigration_channel,
self.migration_interval,
event_log,
progress_channel,
self.progress_interval,
sync_barrier,
);

inner(local_partition, args)
})
},
)
.collect::<Vec<_>>();

let mut progress_remaining = vec![0; self.size.get() as usize].into_boxed_slice();
for (remaining, rank) in progress_receiver {
Expand All @@ -239,9 +248,22 @@ impl Partitioning for ThreadsPartitioning {
.into(),
);
}

let mut folded_result = None;
for handle in thread_handles {
let result = match handle.join() {
Ok(result) => result,
Err(payload) => std::panic::resume_unwind(payload),
};
folded_result = Some(match folded_result.take() {
Some(acc) => fold(acc, result),
None => result,
});
}
folded_result.expect("at least one thread partitioning result")
});

todo!()
Ok(result)
}
}

Expand Down
3 changes: 3 additions & 0 deletions rustcoalescence/algorithms/src/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use necsim_core_bond::NonNegativeF64;

use necsim_impls_no_std::cogs::active_lineage_sampler::resuming::lineage::ExceptionalLineage;

#[derive(Clone, serde::Serialize, serde::Deserialize)]
#[serde(bound = "")]
pub enum SimulationOutcome<M: MathsCore, G: RngCore<M>> {
Done {
time: NonNegativeF64,
Expand All @@ -18,6 +20,7 @@ pub enum SimulationOutcome<M: MathsCore, G: RngCore<M>> {
steps: u64,
lineages: Vec<Lineage>,
rng: G,
#[serde(skip)]
marker: PhantomData<M>,
},
}
Expand Down
79 changes: 72 additions & 7 deletions rustcoalescence/src/cli/simulate/dispatch/valid/partitioning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ use necsim_core::{
};
use necsim_core_bond::NonNegativeF64;
use necsim_impls_std::event_log::recorder::EventLogRecorder;
use necsim_partitioning_core::{context::ReporterContext, Partitioning as _};
use necsim_partitioning_core::{context::ReporterContext, LocalPartition, Partitioning as _};

use necsim_partitioning_monolithic::MonolithicLocalPartition;
#[cfg(feature = "necsim-partitioning-mpi")]
use necsim_partitioning_mpi::MpiLocalPartition;
use rustcoalescence_algorithms::{result::SimulationOutcome, AlgorithmDispatch};
use rustcoalescence_algorithms::{result::SimulationOutcome, Algorithm, AlgorithmDispatch};
use rustcoalescence_scenarios::Scenario;

use crate::args::config::{partitioning::Partitioning, sample::Sample};
Expand Down Expand Up @@ -58,7 +58,7 @@ where
|partition, (sample, rng, scenario, algorithm_args, pause_before, normalised_args)| {
match partition {
MonolithicLocalPartition::Live(partition) => {
info::dispatch::<M, G, A::Algorithm<'_, _>, O, R, _>(
wrap::<M, G, A::Algorithm<'_, _>, O, R, _>(
*partition,
sample,
rng,
Expand All @@ -69,7 +69,7 @@ where
)
},
MonolithicLocalPartition::Recorded(partition) => {
info::dispatch::<M, G, A::Algorithm<'_, _>, O, R, _>(
wrap::<M, G, A::Algorithm<'_, _>, O, R, _>(
*partition,
sample,
rng,
Expand All @@ -81,6 +81,7 @@ where
},
}
},
fold,
),
#[cfg(feature = "necsim-partitioning-mpi")]
Partitioning::Mpi(partitioning) => partitioning.with_local_partition(
Expand All @@ -90,7 +91,7 @@ where
|partition, (sample, rng, scenario, algorithm_args, pause_before, normalised_args)| {
match partition {
MpiLocalPartition::Root(partition) => {
info::dispatch::<M, G, A::Algorithm<'_, _>, O, R, _>(
wrap::<M, G, A::Algorithm<'_, _>, O, R, _>(
*partition,
sample,
rng,
Expand All @@ -101,7 +102,7 @@ where
)
},
MpiLocalPartition::Parallel(partition) => {
info::dispatch::<M, G, A::Algorithm<'_, _>, O, R, _>(
wrap::<M, G, A::Algorithm<'_, _>, O, R, _>(
*partition,
sample,
rng,
Expand All @@ -113,7 +114,71 @@ where
},
}
},
fold,
),
}
.flatten()
.and_then(|result| result.map_err(anyhow::Error::msg))
}

fn wrap<
'p,
M: MathsCore,
G: RngCore<M>,
A: Algorithm<'p, M, G, O, R, P>,
O: Scenario<M, G>,
R: Reporter,
P: LocalPartition<'p, R>,
>(
local_partition: P,

sample: Sample,
rng: G,
scenario: O,
algorithm_args: A::Arguments,
pause_before: Option<NonNegativeF64>,

normalised_args: &BufferingSimulateArgsBuilder,
) -> Result<SimulationOutcome<M, G>, String>
where
Result<SimulationOutcome<M, G>, A::Error>: anyhow::Context<SimulationOutcome<M, G>, A::Error>,
{
info::dispatch::<M, G, A::Algorithm<'_, _>, O, R, _>(
local_partition,
sample,
rng,
scenario,
algorithm_args,
pause_before,
normalised_args,
)
.map_err(|err| format!("{err:?}"))
}

fn fold<M: MathsCore, G: RngCore<M>>(
a: Result<SimulationOutcome<M, G>, String>,
b: Result<SimulationOutcome<M, G>, String>,
) -> Result<SimulationOutcome<M, G>, String> {
match (a, b) {
(
Ok(SimulationOutcome::Done {
time: time_a,
steps: steps_a,
}),
Ok(SimulationOutcome::Done {
time: time_b,
steps: steps_b,
}),
) => Ok(SimulationOutcome::Done {
time: time_a.max(time_b),
steps: steps_a + steps_b,
}),
(Ok(SimulationOutcome::Paused { .. }), _) | (_, Ok(SimulationOutcome::Paused { .. })) => {
Err(String::from(
"parallel pausing is not yet supported, catching this at simulation completion is \
a bug",
))
},
(Err(err), Ok(_)) | (Ok(_), Err(err)) => Err(err),
(Err(err_a), Err(err_b)) => Err(format!("{err_a}\n|\n{err_b}")),
}
}

0 comments on commit 74a85f5

Please sign in to comment.