diff --git a/configs/shortest_paths/coordinator.toml b/configs/shortest_paths/coordinator.toml index a04587cb..abc4c3cb 100644 --- a/configs/shortest_paths/coordinator.toml +++ b/configs/shortest_paths/coordinator.toml @@ -1,6 +1,7 @@ source = "https://www.cdc.gov/healthywater/swimming/" host = "0.0.0.0:5000" output_path = "data/shortest_paths" +max_distance = 4 [gossip] addr = "0.0.0.0:5001" diff --git a/crates/core/src/config/mod.rs b/crates/core/src/config/mod.rs index cd881c9a..4c5552da 100644 --- a/crates/core/src/config/mod.rs +++ b/crates/core/src/config/mod.rs @@ -686,6 +686,7 @@ pub struct ShortestPathCoordinatorConfig { pub gossip: GossipConfig, pub host: SocketAddr, pub output_path: String, + pub max_distance: Option, } #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] diff --git a/crates/core/src/entrypoint/ampc/shortest_path/coordinator.rs b/crates/core/src/entrypoint/ampc/shortest_path/coordinator.rs index 51f3cc7a..375ab67e 100644 --- a/crates/core/src/entrypoint/ampc/shortest_path/coordinator.rs +++ b/crates/core/src/entrypoint/ampc/shortest_path/coordinator.rs @@ -78,9 +78,12 @@ impl Setup for ShortestPathSetup { } fn setup_round(&self, dht: &Self::DhtTables) { + let meta = dht.meta.get(()).unwrap(); + dht.meta.set( (), Meta { + round: meta.round + 1, round_had_changes: false, }, ); @@ -92,18 +95,28 @@ impl Setup for ShortestPathSetup { (), Meta { round_had_changes: true, + round: 0, }, ); } } -pub struct ShortestPathFinish; +pub struct ShortestPathFinish { + max_distance: Option, +} impl Finisher for ShortestPathFinish { type Job = ShortestPathJob; fn is_finished(&self, dht: &ShortestPathTables) -> bool { - !dht.meta.get(()).unwrap().round_had_changes + let meta = dht.meta.get(()).unwrap(); + if let Some(max_distance) = self.max_distance { + if meta.round >= max_distance { + return true; + } + } + + !meta.round_had_changes } } @@ -202,7 +215,12 @@ pub fn run(config: ShortestPathCoordinatorConfig) -> Result<()> { tracing::info!("starting {} jobs", jobs.len()); let coordinator = build(&cluster.dht, cluster.workers.clone(), source); - let res = coordinator.run(jobs, ShortestPathFinish)?; + let res = coordinator.run( + jobs, + ShortestPathFinish { + max_distance: config.max_distance, + }, + )?; let output_path = Path::new(&config.output_path); diff --git a/crates/core/src/entrypoint/ampc/shortest_path/mapper.rs b/crates/core/src/entrypoint/ampc/shortest_path/mapper.rs index 378a358f..48a9257a 100644 --- a/crates/core/src/entrypoint/ampc/shortest_path/mapper.rs +++ b/crates/core/src/entrypoint/ampc/shortest_path/mapper.rs @@ -21,7 +21,7 @@ use rustc_hash::FxHashMap; use super::{ updated_nodes::{UpdatedNodes, UpdatedNodesKind}, worker::ShortestPathWorker, - DhtTable as _, Mapper, Meta, ShortestPathJob, ShortestPathTables, + DhtTable as _, Mapper, ShortestPathJob, ShortestPathTables, }; use crate::{ ampc::{ @@ -237,13 +237,12 @@ impl Mapper for ShortestPathMapper { dht.next() .changed_nodes .set(worker.shard(), new_changed_nodes.lock().unwrap().clone()); - dht.next().meta.set( - (), - Meta { - round_had_changes: round_had_changes - .load(std::sync::atomic::Ordering::Relaxed), - }, - ); + + let mut meta = dht.next().meta.get(()).unwrap(); + meta.round_had_changes = + round_had_changes.load(std::sync::atomic::Ordering::Relaxed); + + dht.next().meta.set((), meta); } ShortestPathMapper::UpdateChangedNodes => { let all_changed_nodes: Vec<_> = diff --git a/crates/core/src/entrypoint/ampc/shortest_path/mod.rs b/crates/core/src/entrypoint/ampc/shortest_path/mod.rs index 099f73fd..8ec6604e 100644 --- a/crates/core/src/entrypoint/ampc/shortest_path/mod.rs +++ b/crates/core/src/entrypoint/ampc/shortest_path/mod.rs @@ -44,6 +44,7 @@ use self::worker::{RemoteShortestPathWorker, ShortestPathWorker}; )] pub struct Meta { round_had_changes: bool, + round: u64, } #[derive(bincode::Encode, bincode::Decode, Debug, Clone)]