diff --git a/README.md b/README.md
index 302d112..5f0561c 100644
--- a/README.md
+++ b/README.md
@@ -91,6 +91,19 @@ IONMESH_PROFILE_NUM_WINDOWS # windows.
 ```
 
+### Profiling
+
+Mostly notes to self, but someone else might benefit ...
+
+```
+# Making a cool flamegraph of a run
+sudo RUST_LOG=debug CARGO_PROFILE_RELEASE_DEBUG=true IONMESH_PROFILE_NUM_WINDOWS=2 \
+    cargo flamegraph \
+    --output flamegraph_secondversion.svg \
+    --features less_parallel \
+    -- --config ./benchmark/default_ionmesh_config.toml benchmark/${MYFAVORITEFILE}.d -o tmp/
+```
+
 ## Roadmap
 
 1. Use aggregation metrics to re-score sage search.
diff --git a/src/aggregation/aggregators.rs b/src/aggregation/aggregators.rs
index ce23c0b..6d4bdec 100644
--- a/src/aggregation/aggregators.rs
+++ b/src/aggregation/aggregators.rs
@@ -181,33 +181,35 @@ fn parallel_aggregate_clusters<
 
     let mut cluster_vecs = out2.into_iter().flatten().collect::<Vec<_>>();
 
-    let unclustered_elems: Vec<usize> = cluster_labels
-        .iter()
-        .enumerate()
-        .filter(|(_, x)| match x {
-            ClusterLabel::Unassigned => true,
-            ClusterLabel::Noise => keep_unclustered,
-            _ => false,
-        })
-        .map(|(i, _elem)| i)
-        .collect();
+    if keep_unclustered {
+        let unclustered_elems: Vec<usize> = cluster_labels
+            .iter()
+            .enumerate()
+            .filter(|(_, x)| match x {
+                ClusterLabel::Unassigned => true, // Should there be any unassigned?
+                ClusterLabel::Noise => true,
+                ClusterLabel::Cluster(_) => false,
+            })
+            .map(|(i, _elem)| i)
+            .collect();
 
-    // if unclustered_elems.len() > 0 {
-    //     log::debug!("Total Orig elems: {}", cluster_labels.len());
-    //     log::debug!("Unclustered elems: {}", unclustered_elems.len());
-    //     log::debug!("Clustered elems: {}", cluster_vecs.len());
-    // }
+        // if unclustered_elems.len() > 0 {
+        //     log::debug!("Total Orig elems: {}", cluster_labels.len());
+        //     log::debug!("Unclustered elems: {}", unclustered_elems.len());
+        //     log::debug!("Clustered elems: {}", cluster_vecs.len());
+        // }
 
-    let unclustered_elems = unclustered_elems
-        .iter()
-        .map(|i| {
-            let mut oe = def_aggregator();
-            oe.add(&elements.get_aggregable_at_index(*i));
-            oe
-        })
-        .collect::<Vec<_>>();
+        let unclustered_elems = unclustered_elems
+            .iter()
+            .map(|i| {
+                let mut oe = def_aggregator();
+                oe.add(&elements.get_aggregable_at_index(*i));
+                oe
+            })
+            .collect::<Vec<_>>();
 
-    cluster_vecs.extend(unclustered_elems);
+        cluster_vecs.extend(unclustered_elems);
+    }
 
     timer.stop(true);
     cluster_vecs
diff --git a/src/aggregation/chromatograms.rs b/src/aggregation/chromatograms.rs
index 17e976b..37c7b2d 100644
--- a/src/aggregation/chromatograms.rs
+++ b/src/aggregation/chromatograms.rs
@@ -254,7 +254,12 @@ impl<
         let mut mag_a = T::default();
         let mut mag_b = T::default();
         for i in 0..NBINS {
-            let other_index = i + other_vs_self_offset as usize;
+            let other_index = i as i32 + other_vs_self_offset;
+            if other_index < 0 {
+                continue;
+            }
+
+            let other_index = other_index as usize;
             if other_index >= other.chromatogram.len() {
                 continue;
             }
diff --git a/src/aggregation/dbscan/dbscan.rs b/src/aggregation/dbscan/dbscan.rs
index 7d78f4c..ac4408d 100644
--- a/src/aggregation/dbscan/dbscan.rs
+++ b/src/aggregation/dbscan/dbscan.rs
@@ -20,7 +20,7 @@ fn reassign_centroid<
     const N: usize,
     T: Send + Clone + Copy,
     C: NDPointConverter<R, N>,
-    I: QueriableIndexedPoints<'a, N> + std::marker::Sync,
+    I: QueriableIndexedPoints<N> + std::marker::Sync,
     G: Sync + Send + ClusterAggregator<T, R>,
     R: Send,
     RE: Send + Sync + AsAggregableAtIndex<T> + ?Sized,
@@ -167,7 +167,7 @@ pub fn dbscan_aggregate<
         + Sync
         + std::fmt::Debug
         + ?Sized,
-    IND: QueriableIndexedPoints<'a, N> + std::marker::Sync + Send + std::fmt::Debug,
+    IND: QueriableIndexedPoints<N> + std::marker::Sync + Send + std::fmt::Debug,
     NAI: AsNDPointsAtIndex<N> + std::marker::Sync + Send,
     T: HasIntensity + Send + Clone + Copy + Sync,
     D: Send + Sync,
diff --git a/src/aggregation/dbscan/runner.rs b/src/aggregation/dbscan/runner.rs
index 007ec8f..f161d91 100644
--- a/src/aggregation/dbscan/runner.rs
+++ b/src/aggregation/dbscan/runner.rs
@@ -1,13 +1,14 @@
 use crate::space::space_generics::{
     convert_to_bounds_query, AsNDPointsAtIndex, DistantAtIndex, HasIntensity, IntenseAtIndex,
-    NDPoint, QueriableIndexedPoints,
+    NDBoundary, NDPoint, QueriableIndexedPoints,
 };
 use crate::space::space_generics::{AsAggregableAtIndex, NDPointConverter};
-use std::marker::PhantomData;
-
 use crate::utils;
+use core::fmt::Debug;
 use indicatif::ProgressIterator;
-use log::trace;
+use log::{debug, trace};
+use std::marker::PhantomData;
+use std::sync::Arc;
 
 use crate::aggregation::aggregators::ClusterLabel;
 use crate::aggregation::dbscan::utils::FilterFunCache;
@@ -105,6 +106,7 @@ impl ClusterLabels {
 
 struct DBScanTimers {
     main: utils::ContextTimer,
+    // TODO aux timers can probably be a hashmap
     filter_fun_cache_timer: utils::ContextTimer,
     outer_loop_nn_timer: utils::ContextTimer,
     inner_loop_nn_timer: utils::ContextTimer,
@@ -209,20 +211,45 @@ impl DBSCANRunnerState {
     }
 }
 
-struct DBSCANRunner<'a, const N: usize, D> {
+struct DBSCANRunner<'a, const N: usize, D, FF>
+where
+    FF: Fn(&D) -> bool + Send + Sync + ?Sized,
+    D: Send + Sync,
+{
     min_n: usize,
     min_intensity: u64,
-    filter_fun: Option<&'a (dyn Fn(&D) -> bool + Send + Sync)>,
+    filter_fun: Option<&'a FF>,
     progress: bool,
     max_extension_distances: &'a [f32; N],
+    _phantom: PhantomData<D>,
 }
 
+impl<'a, const N: usize, D, FF> Debug for DBSCANRunner<'a, N, D, FF>
+where
+    FF: Fn(&D) -> bool + Send + Sync + ?Sized,
+    D: Send + Sync,
+{
+    fn fmt(
+        &self,
+        f: &mut std::fmt::Formatter<'_>,
+    ) -> std::fmt::Result {
+        f.debug_struct("DBSCANRunner")
+            .field("min_n", &self.min_n)
+            .field("min_intensity", &self.min_intensity)
+            .field("filter_fun", &"Some<dyn Fn(&D) -> bool>???")
+            .field("progress", &self.progress)
+            .field("max_extension_distances", &self.max_extension_distances)
+            .finish()
+    }
+}
+
+#[derive(Clone)]
 struct DBSCANPoints<'a, const N: usize, PP, PE, DAI, E, QIP>
 where
-    PP: IntenseAtIndex + std::marker::Send + ?Sized,
-    PE: AsNDPointsAtIndex<N> + ?Sized,
+    PP: IntenseAtIndex + Send + Sync + ?Sized,
+    PE: AsNDPointsAtIndex<N> + Send + Sync + ?Sized,
     DAI: DistantAtIndex<E> + ?Sized,
-    QIP: QueriableIndexedPoints<'a, N> + std::marker::Sync,
+    QIP: QueriableIndexedPoints<N> + Sync,
 {
     raw_elements: &'a PP, // &'a Vec<T>,
     intensity_sorted_indices: Vec<(usize, u64)>,
     _phantom_metric: PhantomData<E>,
 }
 
-impl<'a, const N: usize, PP, QQ, D, E, QIP> DBSCANPoints<'a, N, PP, QQ, D, E, QIP>
+impl<'a, const N: usize, PP, QQ, DAI, E, QIP> QueriableIndexedPoints<N>
+    for DBSCANPoints<'a, N, PP, QQ, DAI, E, QIP>
+where
+    PP: IntenseAtIndex + Send + Sync + ?Sized,
+    QQ: AsNDPointsAtIndex<N> + Send + Sync + ?Sized,
+    DAI: DistantAtIndex<E> + ?Sized,
+    QIP: QueriableIndexedPoints<N> + Sync,
+{
+    fn query_ndpoint(
+        &self,
+        point: &NDPoint<N>,
+    ) -> Vec<usize> {
+        self.indexed_points.query_ndpoint(point)
+    }
+
+    fn query_ndrange(
+        &self,
+        boundary: &NDBoundary<N>,
+        reference_point: Option<&NDPoint<N>>,
+    ) -> Vec<usize> {
+        self.indexed_points.query_ndrange(boundary, reference_point)
+    }
+}
+
+impl<'a, const N: usize, PP, QQ, DAI, E, QIP> DistantAtIndex<E>
+    for DBSCANPoints<'a, N, PP, QQ, DAI, E, QIP>
 where
-    PP: IntenseAtIndex + std::marker::Send + ?Sized,
-    QQ: AsNDPointsAtIndex<N> + ?Sized,
+    PP: IntenseAtIndex + Sync + Send + ?Sized,
+    QQ: AsNDPointsAtIndex<N> + Send + Sync + ?Sized,
+    DAI: DistantAtIndex<E> + ?Sized,
+    QIP: QueriableIndexedPoints<N> + std::marker::Sync,
+{
+    fn distance_at_indices(
+        &self,
+        a: usize,
+        b: usize,
+    ) -> E {
+        self.raw_dist.distance_at_indices(a, b)
+    }
+}
+
+impl<'a, const N: usize, PP, QQ, D, E, QIP> IntenseAtIndex
+    for DBSCANPoints<'a, N, PP, QQ, D, E, QIP>
+where
+    PP: IntenseAtIndex + std::marker::Send + Sync + ?Sized,
+    QQ: AsNDPointsAtIndex<N> + Send + Sync + ?Sized,
     D: DistantAtIndex<E> + ?Sized,
-    QIP: QueriableIndexedPoints<'a, N> + std::marker::Sync,
+    QIP: QueriableIndexedPoints<N> + std::marker::Sync,
 {
-    fn get_intensity_at_index(
+    fn intensity_at_index(
         &self,
         index: usize,
     ) -> u64 {
         self.raw_elements.intensity_at_index(index)
     }
 
-    fn get_ndpoint_at_index(
+    fn weight_at_index(
+        &self,
+        index: usize,
+    ) -> u64 {
+        self.raw_elements.weight_at_index(index)
+    }
+
+    fn intensity_index_length(&self) -> usize {
+        self.raw_elements.intensity_index_length()
+    }
+}
+
+impl<'a, const N: usize, PP, QQ, D, E, QIP> AsNDPointsAtIndex<N>
+    for DBSCANPoints<'a, N, PP, QQ, D, E, QIP>
+where
+    PP: IntenseAtIndex + std::marker::Send + Sync + ?Sized,
+    QQ: AsNDPointsAtIndex<N> + Send + Sync + ?Sized,
+    D: DistantAtIndex<E> + ?Sized,
+    QIP: QueriableIndexedPoints<N> + std::marker::Sync,
+{
+    fn get_ndpoint(
         &self,
         index: usize,
     ) -> NDPoint<N> {
         self.projected_elements.get_ndpoint(index)
     }
 
-    fn get_distance_at_indices(
-        &self,
-        a: usize,
-        b: usize,
-    ) -> E {
-        self.raw_dist.distance_at_indices(a, b)
+    fn num_ndpoints(&self) -> usize {
+        self.projected_elements.num_ndpoints()
     }
 }
 
-impl<'a, 'b: 'a, const N: usize, D> DBSCANRunner<'a, N, D>
+impl<'c, 'b: 'c, 'a: 'b, const N: usize, D, FF> DBSCANRunner<'b, N, D, FF>
 where
-    D: Sync,
+    D: Sync + Send + 'a,
+    FF: Fn(&D) -> bool + Send + Sync + 'a + ?Sized,
 {
     fn run<PP, PE, DAI, QIP>(
         &self,
-        raw_elements: &'b PP, // Vec<T>, // trait impl Index<usize>
+        raw_elements: &'a PP, // Vec<T>, // trait impl Index<usize>
         intensity_sorted_indices: Vec<(usize, u64)>,
-        indexed_points: &'b QIP,
-        projected_elements: &'b PE, // [NDPoint<N>], // trait impl AsNDPointAtIndex<N>
-        raw_distance_calculator: &'b DAI,
+        indexed_points: &'a QIP,
+        projected_elements: &'a PE, // [NDPoint<N>], // trait impl AsNDPointAtIndex<N>
+        raw_distance_calculator: &'a DAI,
     ) -> ClusterLabels
     where
         PP: IntenseAtIndex + Send + Sync + ?Sized,
-        PE: AsNDPointsAtIndex<N> + ?Sized,
+        PE: AsNDPointsAtIndex<N> + Send + Sync + ?Sized,
         DAI: DistantAtIndex<D> + Send + Sync + ?Sized,
-        QIP: QueriableIndexedPoints<'a, N> + std::marker::Sync + std::fmt::Debug,
+        QIP: QueriableIndexedPoints<N> + std::marker::Sync + std::fmt::Debug,
     {
+        if self.progress {
+            debug!("Starting DBSCAN");
+            debug!("Params: {:?}", self);
+        }
         let usize_filterfun = match self.filter_fun {
             Some(filterfun) => {
                 let cl = |a: &usize, b: &usize| {
@@ -306,7 +396,7 @@ where
             _phantom_metric: PhantomData,
         };
         // Q: if filter fun is required ... why is it an option?
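[Editor's note — illustration only, not part of the patch. With `filter_fun` now generic over `FF: Fn(&D) -> bool + Send + Sync + ?Sized` instead of a concrete `&dyn Fn` reference, a caller that opts out of filtering still has to name a type for `FF` so inference can resolve the generic. The `None::<&(dyn Fn(&f32) -> bool + Send + Sync)>` turbofish used by `combine_single_window_traces2` later in this patch is exactly that; a minimal sketch of such a call site, with placeholder argument names:]

```rust
// Hypothetical call site; `index`, `min_n`, etc. are placeholders.
let labels = dbscan_label_clusters(
    &index,
    &index,
    &index,
    min_n,
    min_intensity,
    intensity_sorted_indices,
    // Names a concrete FF even though no filter is passed:
    None::<&(dyn Fn(&f32) -> bool + Send + Sync)>,
    true,
    &max_extension_distances,
);
```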
-        state = self.process_points(state, &points);
+        state = self.process_points(state, Arc::new(points));
         state = self.report_timers(state);
 
         self.take_cluster_labels(state)
@@ -325,19 +415,26 @@ where
         &self,
         state: DBSCANRunnerState,
     ) -> ClusterLabels {
+        if self.progress {
+            debug!("Finished DBSCAN");
+            debug!(
+                "Exporting Num clusters: {}",
+                state.cluster_labels.num_clusters
+            );
+        }
         state.cluster_labels
     }
 
     fn process_points<PP, PE, DAI, QIP>(
         &self,
         mut state: DBSCANRunnerState,
-        points: &DBSCANPoints<'a, N, PP, PE, DAI, D, QIP>,
+        points: Arc<DBSCANPoints<'a, N, PP, PE, DAI, D, QIP>>,
     ) -> DBSCANRunnerState
     where
-        PP: IntenseAtIndex + Send + ?Sized,
-        PE: AsNDPointsAtIndex<N> + ?Sized,
+        PP: IntenseAtIndex + Send + Sync + ?Sized,
+        PE: AsNDPointsAtIndex<N> + Sync + Send + ?Sized,
         DAI: DistantAtIndex<D> + Send + Sync + ?Sized,
-        QIP: QueriableIndexedPoints<'a, N> + std::marker::Sync,
+        QIP: QueriableIndexedPoints<N> + std::marker::Sync,
     {
         let my_progbar =
             state.create_progress_bar(points.intensity_sorted_indices.len(), self.progress);
@@ -349,7 +446,7 @@
         {
             self.process_single_point(
                 *point_index,
-                &points,
+                Arc::clone(&points),
                 &mut state.cluster_labels,
                 &mut state.filter_fun_cache,
                 &mut state.timers,
@@ -361,18 +458,18 @@
 
     /// This method gets applied to every point in decreasing intensity order.
     fn process_single_point<PP, PE, DAI, QIP>(
-        &self,
+        &'b self,
         point_index: usize,
-        points: &DBSCANPoints<'a, N, PP, PE, DAI, D, QIP>,
+        points: Arc<DBSCANPoints<'a, N, PP, PE, DAI, D, QIP>>,
         cluster_labels: &mut ClusterLabels,
         filter_fun_cache: &mut Option<FilterFunCache>,
         timers: &mut DBScanTimers,
         cc_metrics: &mut CandidateCountMetrics,
     ) where
-        PP: IntenseAtIndex + Send + ?Sized,
-        PE: AsNDPointsAtIndex<N> + ?Sized,
+        PP: IntenseAtIndex + Send + Sync + ?Sized,
+        PE: AsNDPointsAtIndex<N> + Send + Sync + ?Sized,
         DAI: DistantAtIndex<D> + Send + Sync + ?Sized,
-        QIP: QueriableIndexedPoints<'a, N> + std::marker::Sync,
+        QIP: QueriableIndexedPoints<N> + Sync,
     {
         if cluster_labels.get(point_index) != ClusterLabel::Unassigned {
             return;
@@ -380,7 +477,7 @@
 
         let neighbors = self.find_main_loop_neighbors(
             point_index,
-            points,
+            Arc::clone(&points),
             filter_fun_cache,
             timers,
             cc_metrics,
@@ -388,7 +485,7 @@
 
         // trace!("Neighbors: {:?}", neighbors);
 
-        if !self.is_core_point(&neighbors, points.raw_elements, timers) {
+        if !self.is_core_point(&neighbors, Arc::clone(&points), timers) {
             cluster_labels.set_noise(point_index);
             return;
         }
@@ -403,43 +500,49 @@
         );
     }
 
-    fn find_main_loop_neighbors<PP, PE, DAI, QIP>(
+    fn find_main_loop_neighbors<PTS>(
         &self,
         point_index: usize,
-        points: &DBSCANPoints<'a, N, PP, PE, DAI, D, QIP>,
+        points: Arc<PTS>,
         filter_fun_cache: &mut Option<FilterFunCache>,
         timers: &mut DBScanTimers,
         cc_metrics: &mut CandidateCountMetrics,
     ) -> Vec<usize>
     where
-        PP: IntenseAtIndex + Send + ?Sized,
-        PE: AsNDPointsAtIndex<N> + ?Sized,
-        DAI: DistantAtIndex<D> + Send + Sync + ?Sized,
-        QIP: QueriableIndexedPoints<'a, N> + std::marker::Sync,
+        PTS: AsNDPointsAtIndex<N>
+            + DistantAtIndex<D>
+            + QueriableIndexedPoints<N>
+            + IntenseAtIndex
+            + Send
+            + Sync
+            + ?Sized,
     {
         timers.outer_loop_nn_timer.reset_start();
-        let binding = points.projected_elements.get_ndpoint(point_index);
-        let query_elems = convert_to_bounds_query(&binding);
-        let mut candidate_neighbors = points
-            .indexed_points
-            .query_ndrange(&query_elems.0, query_elems.1);
+        let binding = points.get_ndpoint(point_index);
+        let mut candidate_neighbors = points.query_ndpoint(&binding);
+        // Every point should have at least itself as a neighbor.
+        debug_assert!(
+            !candidate_neighbors.is_empty(),
+            "No neighbors found, {}, {:?}, at least itself should be a neighbor",
+            point_index,
+            binding
+        );
 
         // trace!("Query elems: {:?}", query_elems);
         // trace!("Candidate neighbors: {:?}", candidate_neighbors);
 
         if cfg!(debug_assertions) {
+            let max_i = candidate_neighbors.iter().max().unwrap();
             // Make sure all generated neighbors are within the bounds.
-            for i in candidate_neighbors.iter() {
-                assert!(
-                    *i < points.projected_elements.num_ndpoints(),
-                    "Index: {} out of proj elems bounds",
-                    i
-                );
-                assert!(
-                    *i < points.raw_elements.intensity_index_length(),
-                    "Index: {} out of intensity bounds",
-                    i
-                );
-            }
+            assert!(
+                *max_i < points.num_ndpoints(),
+                "Index: {} out of proj elems bounds",
+                max_i,
+            );
+            assert!(
+                *max_i < points.intensity_index_length(),
+                "Index: {} out of intensity bounds",
+                max_i
+            );
         }
         timers.outer_loop_nn_timer.stop(false);
@@ -454,9 +557,8 @@ where
                     match res_in_cache {
                         Some(res) => res,
                         None => {
-                            let res = (self.filter_fun.unwrap())(
-                                &points.get_distance_at_indices(*i, point_index),
-                            );
+                            let res =
+                                (self.filter_fun.unwrap())(&points.distance_at_indices(*i, point_index));
                             tmp.set(*i, point_index, res);
                             res
                         },
@@ -478,7 +580,7 @@ where
     fn is_core_point<PTS>(
         &self,
         neighbors: &[usize],
-        raw_elements: &'a PP,
+        points: Arc<PTS>,
         timers: &mut DBScanTimers,
     ) -> bool
     where
@@ -487,7 +589,7 @@ where
     {
         timers.outer_intensity_calculation.reset_start();
         let neighbor_intensity_total = neighbors
            .iter()
-            .map(|i| raw_elements.intensity_at_index(*i))
+            .map(|i| points.intensity_at_index(*i))
            .sum::<u64>();
         timers.outer_intensity_calculation.stop(false);
         return neighbor_intensity_total >= self.min_intensity;
@@ -497,15 +599,15 @@ where
         &self,
         apex_point_index: usize,
         neighbors: Vec<usize>,
-        points: &DBSCANPoints<'a, N, PP, PE, DAI, D, QIP>,
+        points: Arc<DBSCANPoints<'a, N, PP, PE, DAI, D, QIP>>,
         cluster_labels: &mut ClusterLabels,
         filter_fun_cache: &mut Option<FilterFunCache>,
         timers: &mut DBScanTimers,
     ) where
-        PP: IntenseAtIndex + Send + ?Sized,
-        PE: AsNDPointsAtIndex<N> + ?Sized,
+        PP: IntenseAtIndex + Sync + Send + ?Sized,
+        PE: AsNDPointsAtIndex<N> + Send + Sync + ?Sized,
         DAI: DistantAtIndex<D> + Send + Sync + ?Sized,
-        QIP: QueriableIndexedPoints<'a, N> + std::marker::Sync,
+        QIP: QueriableIndexedPoints<N> + std::marker::Sync,
     {
         cluster_labels.set_new_cluster(apex_point_index);
         let mut seed_set: Vec<usize> = neighbors;
@@ -515,12 +617,13 @@ where
                 continue;
             }
 
-            let local_neighbors = self.find_local_neighbors(neighbor_index, points, timers);
+            let local_neighbors =
+                self.find_local_neighbors(neighbor_index, Arc::clone(&points), timers);
             let filtered_neighbors = self.filter_neighbors_inner_loop(
                 local_neighbors,
                 apex_point_index,
                 neighbor_index,
-                points,
+                Arc::clone(&points),
                 cluster_labels,
                 filter_fun_cache,
                 timers,
@@ -548,113 +651,108 @@ where
         }
     }
 
-    fn find_local_neighbors<PP, PE, DAI, QIP>(
+    fn find_local_neighbors<PTS>(
         &self,
         neighbor_index: usize,
-        points: &DBSCANPoints<'a, N, PP, PE, DAI, D, QIP>,
+        points: Arc<PTS>,
         timers: &mut DBScanTimers,
     ) -> Vec<usize>
     where
-        PP: IntenseAtIndex + Send + ?Sized,
-        PE: AsNDPointsAtIndex<N> + ?Sized,
-        DAI: DistantAtIndex<D> + Send + Sync + ?Sized,
-        QIP: QueriableIndexedPoints<'a, N> + std::marker::Sync,
+        PTS: AsNDPointsAtIndex<N> + ?Sized + QueriableIndexedPoints<N> + std::marker::Sync + 'a,
     {
         timers.inner_loop_nn_timer.reset_start();
-        let binding = points.projected_elements.get_ndpoint(neighbor_index);
-        let inner_query_elems = convert_to_bounds_query(&binding);
+        let binding = Arc::clone(&points).get_ndpoint(neighbor_index);
         let local_neighbors: Vec<usize> = points
-            .indexed_points
-            .query_ndrange(&inner_query_elems.0, inner_query_elems.1)
-            .iter_mut()
+            .query_ndpoint(&binding)
+            .iter()
             .map(|x| *x)
             .collect::<Vec<usize>>();
+        // Should I warn if nothing is gotten here?
+        // every point should have at least itself as a neighbor ...
+        debug_assert!(!local_neighbors.is_empty());
         timers.inner_loop_nn_timer.stop(false);
         local_neighbors
     }
 
-    fn filter_neighbors_inner_loop<PP, PE, DAI, QIP>(
+    fn filter_neighbors_inner_loop<PTS>(
         &self,
         local_neighbors: Vec<usize>,
         cluster_apex_point_index: usize,
         current_center_point_index: usize,
-        points: &DBSCANPoints<'a, N, PP, PE, DAI, D, QIP>,
+        points: Arc<PTS>,
         cluster_labels: &ClusterLabels,
         filter_fun_cache: &mut Option<FilterFunCache>,
         timers: &mut DBScanTimers,
     ) -> Vec<usize>
     where
-        PP: IntenseAtIndex + Send + ?Sized,
-        PE: AsNDPointsAtIndex<N> + ?Sized,
-        DAI: DistantAtIndex<D> + Send + Sync + ?Sized,
-        QIP: QueriableIndexedPoints<'a, N> + std::marker::Sync,
+        PTS:
+            IntenseAtIndex + Send + AsNDPointsAtIndex<N> + DistantAtIndex<D> + Send + Sync + ?Sized,
     {
         let filtered = self.apply_filter_fun(
             local_neighbors,
             cluster_apex_point_index,
-            points,
+            Arc::clone(&points),
             filter_fun_cache,
         );
 
-        if !self.is_extension_core_point(&filtered, current_center_point_index, points, timers) {
+        if !self.is_extension_core_point(
+            &filtered,
+            current_center_point_index,
+            Arc::clone(&points),
+            timers,
+        ) {
             return Vec::new();
         }
 
         let unassigned = self.filter_unassigned(filtered, cluster_labels);
-        let unassigned_in_global_distance =
-            self.filter_by_apex_distance(unassigned, cluster_apex_point_index, points, timers);
+        let unassigned_in_global_distance = self.filter_by_apex_distance(
+            unassigned,
+            cluster_apex_point_index,
+            Arc::clone(&points),
+            timers,
+        );
         self.filter_by_local_intensity_and_distance(
             unassigned_in_global_distance,
             current_center_point_index,
-            points,
+            Arc::clone(&points),
             timers,
         )
     }
 
-    fn filter_by_apex_distance<PP, PE, DAI, QIP>(
+    fn filter_by_apex_distance<PTS>(
         &self,
         mut neighbors: Vec<usize>,
         apex_point_index: usize,
-        points: &DBSCANPoints<'a, N, PP, PE, DAI, D, QIP>,
+        points: Arc<PTS>,
         timers: &mut DBScanTimers,
     ) -> Vec<usize>
     where
-        PP: IntenseAtIndex + Send + ?Sized,
-        PE: AsNDPointsAtIndex<N> + ?Sized,
-        DAI: DistantAtIndex<D> + Send + Sync + ?Sized,
-        QIP: QueriableIndexedPoints<'a, N> + std::marker::Sync,
+        PTS: AsNDPointsAtIndex<N> + ?Sized,
     {
         timers.local_neighbor_filter_timer.reset_start();
-        let query_point = &points.projected_elements.get_ndpoint(apex_point_index);
-        neighbors.retain(|&i| {
-            self.is_within_max_distance(&points.projected_elements.get_ndpoint(i), query_point)
-        });
+        let query_point = &points.get_ndpoint(apex_point_index);
+        neighbors.retain(|&i| self.is_within_max_distance(&points.get_ndpoint(i), query_point));
         timers.local_neighbor_filter_timer.stop(false);
         neighbors
     }
 
-    fn is_extension_core_point<PP, PE, DAI, QIP>(
+    fn is_extension_core_point<PTS>(
         &self,
         neighbors: &[usize],
         current_center_point_index: usize,
-        points: &DBSCANPoints<'a, N, PP, PE, DAI, D, QIP>,
+        points: Arc<PTS>,
         timers: &mut DBScanTimers,
     ) -> bool
     where
-        PP: IntenseAtIndex + Send + ?Sized,
-        PE: AsNDPointsAtIndex<N> + ?Sized,
-        DAI: DistantAtIndex<D> + Send + Sync + ?Sized,
-        QIP: QueriableIndexedPoints<'a, N> + std::marker::Sync,
+        PTS: IntenseAtIndex + Sync + Send + ?Sized,
     {
         timers.inner_intensity_calculation.reset_start();
         let mut neighbor_intensity_total: u64 = neighbors
             .iter()
-            .map(|&i| points.raw_elements.intensity_at_index(i))
+            .map(|&i| points.intensity_at_index(i))
             .sum();
-        neighbor_intensity_total += points
-            .raw_elements
-            .intensity_at_index(current_center_point_index);
+        neighbor_intensity_total += points.intensity_at_index(current_center_point_index);
         timers.inner_intensity_calculation.stop(false);
 
         neighbors.len() >= self.min_n && neighbor_intensity_total >= self.min_intensity
@@ -668,27 +766,23 @@ where
     /// one could pass a function that checks if the chromatograms a high correlation.
     /// Because two might share the same point in space, intensity is not really
     /// relevant but co-elution might be critical.
-    fn apply_filter_fun<PP, PE, DAI, QIP>(
+    fn apply_filter_fun<PTS>(
         &self,
         local_neighbors: Vec<usize>,
         point_index: usize,
-        points: &DBSCANPoints<'a, N, PP, PE, DAI, D, QIP>,
+        points: Arc<PTS>,
         filter_fun_cache: &mut Option<FilterFunCache>,
     ) -> Vec<usize>
     where
-        PP: IntenseAtIndex + Send + ?Sized,
-        PE: AsNDPointsAtIndex<N> + ?Sized,
-        DAI: DistantAtIndex<D> + Send + Sync + ?Sized,
-        QIP: QueriableIndexedPoints<'a, N> + std::marker::Sync,
+        PTS: DistantAtIndex<D> + IntenseAtIndex + Sync + Send + ?Sized,
     {
         if let Some(cache) = filter_fun_cache {
             local_neighbors
                 .into_iter()
                 .filter(|&i| {
                     cache.get(i, point_index).unwrap_or_else(|| {
-                        let res = (self.filter_fun.unwrap())(
-                            &points.get_distance_at_indices(i, point_index),
-                        );
+                        let res =
+                            (self.filter_fun.unwrap())(&points.distance_at_indices(i, point_index));
                         cache.set(i, point_index, res);
                         res
                     })
@@ -708,27 +802,23 @@ where
         neighbors
     }
 
-    fn filter_by_local_intensity_and_distance<PP, PE, DAI, QIP>(
+    fn filter_by_local_intensity_and_distance<PTS>(
         &self,
         mut neighbors: Vec<usize>,
         neighbor_index: usize,
-        points: &DBSCANPoints<'a, N, PP, PE, DAI, D, QIP>,
+        points: Arc<PTS>,
         timers: &mut DBScanTimers,
     ) -> Vec<usize>
     where
-        PP: IntenseAtIndex + Send + ?Sized,
-        PE: AsNDPointsAtIndex<N> + ?Sized,
-        DAI: DistantAtIndex<D> + Send + Sync + ?Sized,
-        QIP: QueriableIndexedPoints<'a, N> + std::marker::Sync,
+        PTS: IntenseAtIndex + AsNDPointsAtIndex<N> + Sync + Send + ?Sized,
     {
         timers.local_neighbor_filter_timer.reset_start();
-        let query_intensity = points.raw_elements.intensity_at_index(neighbor_index);
-        let query_point = &points.projected_elements.get_ndpoint(neighbor_index);
+        let query_intensity = points.intensity_at_index(neighbor_index);
+        let query_point = &points.get_ndpoint(neighbor_index);
         neighbors.retain(|&i| {
-            let going_downhill = points.raw_elements.intensity_at_index(i) <= query_intensity;
-            let within_distance =
-                self.is_within_max_distance(&points.projected_elements.get_ndpoint(i), query_point);
+            let going_downhill = points.intensity_at_index(i) <= query_intensity;
+            let within_distance = self.is_within_max_distance(&points.get_ndpoint(i), query_point);
             going_downhill && within_distance
         });
 
@@ -753,10 +843,11 @@ pub fn dbscan_label_clusters<
     'a,
     const N: usize,
     RE: IntenseAtIndex + DistantAtIndex<D> + Send + Sync + AsAggregableAtIndex<E> + ?Sized,
-    T: QueriableIndexedPoints<'a, N> + Send + std::marker::Sync + std::fmt::Debug,
+    T: QueriableIndexedPoints<N> + Send + std::marker::Sync + std::fmt::Debug,
     PE: AsNDPointsAtIndex<N> + Send + Sync + ?Sized,
     D: Send + Sync,
     E: Send + Sync + Copy,
+    FF: Fn(&D) -> bool + Send + Sync + ?Sized,
 >(
     indexed_points: &'a T,
     raw_elements: &'a RE,
@@ -764,7 +855,7 @@ pub fn dbscan_label_clusters<
     min_n: usize,
     min_intensity: u64,
     intensity_sorted_indices: Vec<(usize, u64)>,
-    filter_fun: Option<&'a (dyn Fn(&D) -> bool + Send + Sync)>,
+    filter_fun: Option<&'a FF>,
     progress: bool,
     max_extension_distances: &'a [f32; N],
 ) -> ClusterLabels {
@@ -774,6 +865,7 @@ pub fn dbscan_label_clusters<
         progress,
         filter_fun: filter_fun,
         max_extension_distances,
+        _phantom: PhantomData::<D>,
     };
 
     let cluster_labels = runner.run(
diff --git a/src/aggregation/ms_denoise.rs b/src/aggregation/ms_denoise.rs
index adfc542..763d67e 100644
--- a/src/aggregation/ms_denoise.rs
+++ b/src/aggregation/ms_denoise.rs
@@ -1,3 +1,4 @@
+use core::fmt::Debug;
 use core::panic;
 
 use serde::{Deserialize, Serialize};
@@ -19,7 +20,7 @@ use crate::utils;
 use crate::utils::maybe_save_json_if_debugging;
 
 use indicatif::ParallelProgressIterator;
-use log::{info, trace, warn};
+use log::{debug, info, trace, warn};
 use rayon::prelude::*;
 use timsrust::Frame;
 
@@ -163,20 +164,6 @@ fn denoise_frame_slice_window(
     let ref_frame_parent_index = fsw.window[fsw.reference_index].parent_frame_index;
     let saved_first =
         maybe_save_json_if_debugging(&fsw, &*format!("fsw_{}", ref_frame_parent_index), false);
-    // dbscan_aggregate(
-    //     &fsw,
-    //     &fsw,
-    //     &fsw,
-    //     timer,
-    //     min_n,
-    //     min_intensity,
-    //     TimsPeakAggregator::default,
-    //     None::<&(dyn Fn(&f32) -> bool + Send + Sync)>,
-    //     utils::LogLevel::TRACE,
-    //     false,
-    //     &[max_mz_extension as f32, max_ims_extension],
-    //     false,
-    // );
 
     let mut intensity_sorted_indices = Vec::with_capacity(fsw.num_ndpoints());
     for i in 0..fsw.num_ndpoints() {
@@ -184,7 +171,16 @@ fn denoise_frame_slice_window(
         let intensity = fsw.intensity_at_index(i);
         intensity_sorted_indices.push((i, intensity));
     }
-    intensity_sorted_indices.par_sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
+    intensity_sorted_indices.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
+    if cfg!(debug_assertions) {
+        // I know this should be obviously always true, but I don't trust myself
+        // when thinking about orderings.
+        let mut last_intensity = u64::MAX;
+        for (_i, intensity) in intensity_sorted_indices.iter() {
+            assert!(*intensity <= last_intensity);
+            last_intensity = *intensity;
+        }
+    }
 
     let mut i_timer = timer.start_sub_timer("dbscan");
     let cluster_labels = dbscan_label_clusters(
@@ -329,7 +325,7 @@ where
     where
         Self: Sync,
     {
-        info!("Denoising {} frames", elems.len());
+        debug!("Denoising {} frames", elems.len());
 
         // randomly viz 1/200 frames
         // Selecting a slice of 1/200 frames
@@ -417,7 +413,7 @@ impl<'a> Denoiser<'a, Frame, Vec, Converters, Option>
     where
         Self: Sync,
     {
-        info!("Denoising {} frames", elems.len());
+        info!("Denoising (centroiding) {} frames", elems.len());
         let mut frame_window_slices = self.dia_frame_info.split_frame_windows(&elems);
 
@@ -430,6 +426,12 @@ impl<'a> Denoiser<'a, Frame, Vec, Converters, Option>
             frame_window_slices.truncate(num_windows);
         }
 
+        // This warning refers to denoise_frame_slice_window.
+        // To have the scalings not hard-coded I need a way to convert
+        // m/z space ranges to tof indices ... which is not exposed
+        // by timsrust ...
+ warn!("Using prototype function for denoising, scalings are hard-coded"); + let mut out = Vec::with_capacity(frame_window_slices.len()); let num_windows = frame_window_slices.len(); for (i, sv) in frame_window_slices.iter().enumerate() { @@ -470,16 +472,16 @@ impl<'a> Denoiser<'a, Frame, Vec, Converters, Option> .collect::>() }; - info!("Denoised {} frames", denoised_elements.len()); + debug!("Denoised {} frames", denoised_elements.len()); denoised_elements .retain(|x| x.frame.raw_peaks.iter().map(|y| y.intensity).sum::() > 20); - info!("Retained {} frames", denoised_elements.len()); + debug!("Retained {} frames", denoised_elements.len()); let end_tot_peaks = denoised_elements .iter() .map(|x| x.frame.raw_peaks.len() as u64) .sum::(); let ratio = end_tot_peaks as f64 / start_tot_peaks as f64; - info!( + debug!( "Start peaks: {}, End peaks: {} -> ratio: {:.2}", start_tot_peaks, end_tot_peaks, ratio ); diff --git a/src/aggregation/tracing.rs b/src/aggregation/tracing.rs index 74c8e61..1ab9d68 100644 --- a/src/aggregation/tracing.rs +++ b/src/aggregation/tracing.rs @@ -1,15 +1,17 @@ -use crate::aggregation::aggregators::ClusterAggregator; +use crate::aggregation::aggregators::{aggregate_clusters, ClusterAggregator}; use crate::aggregation::chromatograms::{ BTreeChromatogram, ChromatogramArray, NUM_LOCAL_CHROMATOGRAM_BINS, }; use crate::aggregation::dbscan::dbscan::dbscan_generic; +use crate::aggregation::dbscan::runner::dbscan_label_clusters; use crate::ms::frames::DenseFrameWindow; use crate::space::space_generics::{ - AsAggregableAtIndex, DistantAtIndex, HasIntensity, NDPoint, NDPointConverter, TraceLike, + AsAggregableAtIndex, AsNDPointsAtIndex, DistantAtIndex, HasIntensity, NDPoint, + NDPointConverter, QueriableIndexedPoints, TraceLike, }; use crate::space::space_generics::{IntenseAtIndex, NDBoundary}; use crate::utils; -use crate::utils::RollingSDCalculator; +use crate::utils::{binary_search_slice, RollingSDCalculator}; use core::panic; use log::{debug, error, info, warn}; @@ -227,45 +229,34 @@ pub fn combine_traces( .map(_flatten_denseframe_vec) .collect(); + let combine_lambda = |x: Vec| { + combine_single_window_traces2( + x, + config.mz_scaling.into(), + config.max_mz_expansion_ratio, + config.rt_scaling.into(), + config.max_rt_expansion_ratio, + config.ims_scaling.into(), + config.max_ims_expansion_ratio, + config.min_n.into(), + config.min_neighbor_intensity, + rt_binsize, + ) + }; + // Combine the traces let out: Vec = if cfg!(feature = "less_parallel") { warn!("Running in single-threaded mode"); grouped_windows .into_iter() - .map(|x| { - combine_single_window_traces( - x, - config.mz_scaling.into(), - config.max_mz_expansion_ratio, - config.rt_scaling.into(), - config.max_rt_expansion_ratio, - config.ims_scaling.into(), - config.max_ims_expansion_ratio, - config.min_n.into(), - config.min_neighbor_intensity, - rt_binsize, - ) - }) + .map(combine_lambda) .flatten() .collect() } else { grouped_windows - .into_par_iter() - .map(|x| { - combine_single_window_traces( - x, - config.mz_scaling.into(), - config.max_mz_expansion_ratio, - config.rt_scaling.into(), - config.max_rt_expansion_ratio, - config.ims_scaling.into(), - config.max_ims_expansion_ratio, - config.min_n.into(), - config.min_neighbor_intensity, - rt_binsize, - ) - }) - .flatten() + .into_par_iter() + .map(combine_lambda) + .flatten() .collect() }; @@ -465,6 +456,313 @@ impl DistantAtIndex for Vec { // Needed to specify the generic in dbscan_generic type FFTimeTimsPeak = fn(&TimeTimsPeak, &TimeTimsPeak) -> bool; 
+#[derive(Debug)]
+struct TimeTimsPeakScaling {
+    mz_scaling: f32,
+    rt_scaling: f32,
+    ims_scaling: f32,
+    quad_scaling: f32,
+}
+
+#[derive(Debug)]
+struct QueriableTimeTimsPeaks {
+    peaks: Vec<TimeTimsPeak>,
+    min_bucket_mz_vals: Vec<f32>,
+    bucket_size: usize,
+    scalings: TimeTimsPeakScaling,
+}
+
+impl QueriableTimeTimsPeaks {
+    fn new(
+        mut peaks: Vec<TimeTimsPeak>,
+        scalings: TimeTimsPeakScaling,
+    ) -> Self {
+        const BUCKET_SIZE: usize = 16384;
+        // Sort all of our peaks by m/z, from low to high
+        peaks.par_sort_unstable_by(|a, b| a.mz.partial_cmp(&b.mz).unwrap());
+
+        let mut min_bucket_mz_vals = peaks
+            .par_chunks_mut(BUCKET_SIZE)
+            .map(|bucket| {
+                let min = bucket[0].mz;
+                bucket.par_sort_unstable_by(|a, b| a.rt.partial_cmp(&b.rt).unwrap());
+                min as f32
+            })
+            .collect::<Vec<f32>>();
+
+        // Get the max value of the last bucket
+        let max_bucket_mz = peaks[peaks.len().saturating_sub(BUCKET_SIZE)..peaks.len()]
+            .iter()
+            .max_by(|a, b| a.mz.partial_cmp(&b.mz).unwrap())
+            .unwrap()
+            .mz as f32;
+        min_bucket_mz_vals.push(max_bucket_mz);
+
+        QueriableTimeTimsPeaks {
+            peaks,
+            min_bucket_mz_vals,
+            bucket_size: BUCKET_SIZE,
+            scalings,
+        }
+    }
+
+    fn get_bucket_at(
+        &self,
+        index: usize,
+    ) -> Result<&[TimeTimsPeak], ()> {
+        let page_start = index * self.bucket_size;
+        if page_start >= self.peaks.len() {
+            return Err(());
+        }
+        let page_end = (page_start + self.bucket_size).min(self.peaks.len());
+        let tmp = &self.peaks[page_start..page_end];
+
+        if cfg!(debug_assertions) {
+            // Make sure all rts are sorted within the bucket
+            for i in 1..tmp.len() {
+                if tmp[i - 1].rt > tmp[i].rt {
+                    panic!("RTs are not sorted within the bucket");
+                }
+            }
+        }
+        Ok(tmp)
+    }
+
+    fn get_intensity_sorted_indices(&self) -> Vec<(usize, u64)> {
+        let mut indices: Vec<(usize, u64)> = (0..self.peaks.len())
+            .map(|i| (i, self.peaks[i].intensity))
+            .collect();
+        // Descending, to match the decreasing-intensity order the dbscan runner expects.
+        indices.par_sort_unstable_by_key(|&x| std::cmp::Reverse(x.1));
+
+        debug_assert!(indices.len() == self.peaks.len());
+        if cfg!(debug_assertions) {
+            if indices.len() > 1 {
+                for i in 1..indices.len() {
+                    if indices[i - 1].1 < indices[i].1 {
+                        panic!("Indices are not sorted in decreasing intensity order");
+                    }
+                }
+            }
+        }
+        indices
+    }
+}
+
+impl AsNDPointsAtIndex<3> for QueriableTimeTimsPeaks {
+    fn get_ndpoint(
+        &self,
+        index: usize,
+    ) -> NDPoint<3> {
+        NDPoint {
+            values: [
+                self.peaks[index].mz as f32,
+                self.peaks[index].rt,
+                self.peaks[index].ims,
+            ],
+        }
+    }
+
+    fn num_ndpoints(&self) -> usize {
+        self.peaks.len()
+    }
+}
+
+impl IntenseAtIndex for QueriableTimeTimsPeaks {
+    fn intensity_at_index(
+        &self,
+        index: usize,
+    ) -> u64 {
+        self.peaks[index].intensity
+    }
+
+    fn intensity_index_length(&self) -> usize {
+        self.peaks.len()
+    }
+}
+
+impl AsAggregableAtIndex<TimeTimsPeak> for QueriableTimeTimsPeaks {
+    fn get_aggregable_at_index(
+        &self,
+        index: usize,
+    ) -> TimeTimsPeak {
+        self.peaks[index]
+    }
+
+    fn num_aggregable(&self) -> usize {
+        self.peaks.len()
+    }
+}
+
+impl DistantAtIndex<f32> for QueriableTimeTimsPeaks {
+    fn distance_at_indices(
+        &self,
+        index: usize,
+        other: usize,
+    ) -> f32 {
+        let a = self.peaks[index];
+        let b = self.peaks[other];
+        let mz = (a.mz - b.mz) as f32 / self.scalings.mz_scaling;
+        let rt = (a.rt - b.rt) as f32 / self.scalings.rt_scaling;
+        let ims = (a.ims - b.ims) as f32 / self.scalings.ims_scaling;
+        (mz * mz + rt * rt + ims * ims).sqrt()
+    }
+}
+
+impl QueriableIndexedPoints<3> for QueriableTimeTimsPeaks {
+    fn query_ndpoint(
+        &self,
+        point: &NDPoint<3>,
+    ) -> Vec<usize> {
+        let boundary = NDBoundary::new(
+            [
+                (point.values[0] - self.scalings.mz_scaling) - f32::EPSILON,
+                (point.values[1] - self.scalings.rt_scaling),
+                (point.values[2] - self.scalings.ims_scaling) - f32::EPSILON,
+            ],
+            [
+                (point.values[0] + self.scalings.mz_scaling) + f32::EPSILON,
+                (point.values[1] + self.scalings.rt_scaling),
+                (point.values[2] + self.scalings.ims_scaling) + f32::EPSILON,
+            ],
+        );
+        let out = self.query_ndrange(&boundary, None);
+        out
+    }
+
+    fn query_ndrange(
+        &self,
+        boundary: &NDBoundary<3>,
+        reference_point: Option<&NDPoint<3>>,
+    ) -> Vec<usize> {
+        let mut out = Vec::new();
+        let mz_range = (boundary.starts[0], boundary.ends[0]);
+        let _mz_range_f64 = (boundary.starts[0] as f64, boundary.ends[0] as f64);
+        let rt_range = (boundary.starts[1], boundary.ends[1]);
+        let ims_range = (boundary.starts[2], boundary.ends[2]);
+
+        let (bstart, bend) = binary_search_slice(
+            &self.min_bucket_mz_vals,
+            |a, b| a.total_cmp(b),
+            mz_range.0,
+            mz_range.1,
+        );
+
+        let bstart = bstart.saturating_sub(1);
+        let bend_new = bend.saturating_add(1).min(self.min_bucket_mz_vals.len());
+
+        for bnum in bstart..bend_new {
+            let c_bucket = self.get_bucket_at(bnum);
+            if c_bucket.is_err() {
+                continue;
+            }
+            let c_bucket = c_bucket.unwrap();
+            let page_start = bnum * self.bucket_size;
+
+            let (istart, iend) =
+                binary_search_slice(c_bucket, |a, b| a.rt.total_cmp(b), rt_range.0, rt_range.1);
+
+            for (j, peak) in self.peaks[(istart + page_start)..(iend + page_start)]
+                .iter()
+                .enumerate()
+            {
+                debug_assert!(
+                    peak.rt >= rt_range.0 && peak.rt <= rt_range.1,
+                    "RT out of range -> {} {} {}; istart {}, page_start {}, j {}; window rts: {:?}",
+                    peak.rt,
+                    rt_range.0,
+                    rt_range.1,
+                    istart,
+                    page_start,
+                    j,
+                    &self.peaks[(j + istart + page_start).saturating_sub(5)
+                        ..(j + istart + page_start + 5).min(self.peaks.len())]
+                        .iter()
+                        .map(|x| x.rt)
+                        .collect::<Vec<_>>()
+                );
+                if peak.ims >= ims_range.0 && peak.ims <= ims_range.1 {
+                    if peak.mz as f32 >= mz_range.0 && peak.mz as f32 <= mz_range.1 {
+                        out.push(j + istart + page_start);
+                    }
+                }
+            }
+        }
+
+        out
+    }
+}
+
+// QueriableIndexedPoints
+
+fn combine_single_window_traces2(
+    prefiltered_peaks: Vec<TimeTimsPeak>,
+    mz_scaling: f64,
+    max_mz_expansion_ratio: f32,
+    rt_scaling: f64,
+    max_rt_expansion_ratio: f32,
+    ims_scaling: f64,
+    max_ims_expansion_ratio: f32,
+    min_n: usize,
+    min_intensity: u32,
+    rt_binsize: f32,
+) -> Vec<BaseTrace> {
+    let timer = utils::ContextTimer::new("dbscan_wt2", true, utils::LogLevel::DEBUG);
+    info!("Peaks in window: {}", prefiltered_peaks.len());
+    let scalings = TimeTimsPeakScaling {
+        mz_scaling: mz_scaling as f32,
+        rt_scaling: rt_scaling as f32,
+        ims_scaling: ims_scaling as f32,
+        quad_scaling: 1.,
+    };
+    let window_quad_low_high = (
+        prefiltered_peaks[0].quad_low_high.0,
+        prefiltered_peaks[0].quad_low_high.1,
+    );
+    let index = QueriableTimeTimsPeaks::new(prefiltered_peaks, scalings);
+    let intensity_sorted_indices = index.get_intensity_sorted_indices();
+    let max_extension_distances: [f32; 3] = [
+        max_mz_expansion_ratio * mz_scaling as f32,
+        max_rt_expansion_ratio * rt_scaling as f32,
+        max_ims_expansion_ratio * ims_scaling as f32,
+    ];
+
+    let mut i_timer = timer.start_sub_timer("dbscan");
+    let cluster_labels = dbscan_label_clusters(
+        &index,
+        &index,
+        &index,
+        min_n,
+        min_intensity.into(),
+        intensity_sorted_indices,
+        None::<&(dyn Fn(&f32) -> bool + Send + Sync)>,
+        true,
+        &max_extension_distances,
+    );
+
+    i_timer.stop(true);
+
+    let centroids = aggregate_clusters(
+        cluster_labels.num_clusters,
+        cluster_labels.cluster_labels,
+        &index,
+        &|| TraceAggregator {
+            mz: RollingSDCalculator::default(),
+            intensity: 0,
+            rt: RollingSDCalculator::default(),
+            ims: RollingSDCalculator::default(),
+            num_peaks: 0,
+            num_rt_peaks: 0,
+            quad_low_high: window_quad_low_high,
+            btree_chromatogram: BTreeChromatogram::new_lazy(rt_binsize),
+        },
+        utils::LogLevel::TRACE,
+        false,
+    );
+
+    debug!("Combined traces: {}", centroids.len());
+    centroids
+}
+
 // TODO maybe this can be a builder-> executor pattern
 fn combine_single_window_traces(
     prefiltered_peaks: Vec<TimeTimsPeak>,
diff --git a/src/main.rs b/src/main.rs
index 8ae427a..105cb54 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -146,7 +146,7 @@ fn main() {
     }
 
     println!("traces: {:?}", traces.len());
-    traces.retain(|x| x.num_agg > 5);
+    traces.retain(|x| x.num_agg > 3);
     println!("traces: {:?}", traces.len());
     if traces.len() > 5 {
         println!("sample_trace: {:?}", traces[traces.len() - 4])
diff --git a/src/ms/frames/frame_slice.rs b/src/ms/frames/frame_slice.rs
index 0dd98df..d9e1d89 100644
--- a/src/ms/frames/frame_slice.rs
+++ b/src/ms/frames/frame_slice.rs
@@ -719,7 +719,7 @@ impl<'a> AsNDPointsAtIndex<2> for FrameSlice<'a> {
     }
 }
 
-impl QueriableIndexedPoints<'_, 2> for ExpandedFrameSlice {
+impl QueriableIndexedPoints<2> for ExpandedFrameSlice {
     fn query_ndpoint(
         &self,
         point: &NDPoint<2>,
@@ -760,9 +760,9 @@ impl QueriableIndexedPoints<'_, 2> for ExpandedFrameSlice {
     }
 }
 
-impl<'a> QueriableIndexedPoints<'a, 2> for FrameSlice<'a> {
+impl<'a> QueriableIndexedPoints<2> for FrameSlice<'a> {
     fn query_ndpoint(
-        &'a self,
+        &self,
         point: &NDPoint<2>,
     ) -> Vec<usize> {
         let tof_index = point.values[0] as i32;
@@ -788,7 +788,7 @@ impl<'a> QueriableIndexedPoints<'a, 2> for FrameSlice<'a> {
     }
 
     fn query_ndrange(
-        &'a self,
+        &self,
         boundary: &NDBoundary<2>,
         reference_point: Option<&NDPoint<2>>,
     ) -> Vec<usize> {
diff --git a/src/ms/frames/frame_slice_rt_window.rs b/src/ms/frames/frame_slice_rt_window.rs
index 73cffb0..69565cf 100644
--- a/src/ms/frames/frame_slice_rt_window.rs
+++ b/src/ms/frames/frame_slice_rt_window.rs
@@ -128,9 +128,9 @@ impl IntenseAtIndex for FrameSliceWindow<'_> {
     }
 }
 
-impl<'a> QueriableIndexedPoints<'a, 2> for FrameSliceWindow<'a> {
+impl<'a> QueriableIndexedPoints<2> for FrameSliceWindow<'a> {
     fn query_ndpoint(
-        &'a self,
+        &self,
         point: &NDPoint<2>,
     ) -> Vec<usize> {
         let mut out = Vec::new();
@@ -147,7 +147,7 @@ impl<'a> QueriableIndexedPoints<'a, 2> for FrameSliceWindow<'a> {
     }
 
     fn query_ndrange(
-        &'a self,
+        &self,
         boundary: &crate::space::space_generics::NDBoundary<2>,
         reference_point: Option<&NDPoint<2>>,
     ) -> Vec<usize> {
diff --git a/src/space/kdtree.rs b/src/space/kdtree.rs
index 5df1beb..fa4f348 100644
--- a/src/space/kdtree.rs
+++ b/src/space/kdtree.rs
@@ -260,16 +260,16 @@ impl<'a, const D: usize, T> RadiusKDTree<'a, T, D> {
     }
 }
 
-impl<'a, const D: usize> QueriableIndexedPoints<'a, D> for RadiusKDTree<'a, usize, D> {
+impl<'a, const D: usize> QueriableIndexedPoints<D> for RadiusKDTree<'a, usize, D> {
     fn query_ndpoint(
-        &'a self,
+        &self,
         point: &NDPoint<D>,
     ) -> Vec<usize> {
         self.query(point).into_iter().map(|x| *x).collect()
     }
 
     fn query_ndrange(
-        &'a self,
+        &self,
         boundary: &NDBoundary<D>,
         reference_point: Option<&NDPoint<D>>,
     ) -> Vec<usize> {
diff --git a/src/space/quad.rs b/src/space/quad.rs
index dbb0710..6dc8851 100644
--- a/src/space/quad.rs
+++ b/src/space/quad.rs
@@ -255,9 +255,9 @@ impl<'a, T> RadiusQuadTree<'a, T> {
 
 // TODO: rename count_neigh_monotonocally_increasing
 // because it can do more than just count neighbors....
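[Editor's note — a minimal repro, with hypothetical trait names, of why this patch drops the lifetime parameter from `QueriableIndexedPoints` in the impls above and in the trait below: with `&'a self` methods on a trait generic over `'a`, every query borrows the index for the whole `'a`, which is what blocked sharing the points behind an `Arc` in the dbscan runner; with plain `&self`, each call is an ordinary short reborrow:]

```rust
trait OldIndex<'a> {
    fn query(&'a self) -> Vec<usize>; // borrow is tied to the trait lifetime 'a
}

trait NewIndex {
    fn query(&self) -> Vec<usize>; // borrow lasts only for the call
}
```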
-impl<'a> QueriableIndexedPoints<'a, 2> for RadiusQuadTree<'a, usize> {
+impl<'a> QueriableIndexedPoints<2> for RadiusQuadTree<'a, usize> {
     fn query_ndpoint(
-        &'a self,
+        &self,
         point: &NDPoint<2>,
     ) -> Vec<usize> {
         self.query(point)
@@ -267,7 +267,7 @@
     }
 
     fn query_ndrange(
-        &'a self,
+        &self,
         boundary: &NDBoundary<2>,
         reference_point: Option<&NDPoint<2>>,
     ) -> Vec<usize> {
diff --git a/src/space/space_generics.rs b/src/space/space_generics.rs
index d506d5b..cd806c0 100644
--- a/src/space/space_generics.rs
+++ b/src/space/space_generics.rs
@@ -104,13 +104,13 @@ pub struct NDPoint<const DIMENSIONALITY: usize> {
     pub values: [f32; DIMENSIONALITY],
 }
 
-pub trait QueriableIndexedPoints<'a, const N: usize> {
+pub trait QueriableIndexedPoints<const N: usize> {
     fn query_ndpoint(
-        &'a self,
+        &self,
         point: &NDPoint<N>,
     ) -> Vec<usize>;
     fn query_ndrange(
-        &'a self,
+        &self,
         boundary: &NDBoundary<N>,
         reference_point: Option<&NDPoint<N>>,
     ) -> Vec<usize>;
diff --git a/src/utils.rs b/src/utils.rs
index d25f6d5..8d8b705 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -2,6 +2,7 @@ use log::{debug, info, trace, warn};
 use num::cast::AsPrimitive;
 use std::{
     cmp::Ordering,
+    fmt::Debug,
     time::{Duration, Instant},
 };
 
@@ -397,19 +398,41 @@ pub fn get_stats(data: &[f64]) -> Stats {
     }
 }
 
-/// This has been shamelessly copied from sage.
+/// This has been shamelessly copied and very minorly modified from sage.
/// https://github.com/lazear/sage/blob/93a9a8a7c9f717238fc6c582c0dd501a56159be7/crates/sage/src/database.rs#L498
 /// Although it really feels like this should be in the standard lib.
 ///
 /// Usage:
 /// ```rust
-/// let data = [1.0, 1.5, 1.5, 1.5, 1.5, 2.0, 2.5, 3.0, 3.0, 3.5, 4.0];
+/// use ionmesh::utils::binary_search_slice;
+/// let data: [f64; 11] = [1.0, 1.5, 1.5, 1.5, 1.5, 2.0, 2.5, 3.0, 3.0, 3.5, 4.0];
 /// let (left, right) = binary_search_slice(&data, |a: &f64, b| a.total_cmp(b), 1.5, 3.25);
-/// assert!(data[left] <= 1.5);
+/// assert!(data[left] == 1.5);
 /// assert!(data[right] > 3.25);
 /// assert_eq!(
 ///     &data[left..right],
-///     &[1.0, 1.5, 1.5, 1.5, 1.5, 2.0, 2.5, 3.0, 3.0]
+///     &[1.5, 1.5, 1.5, 1.5, 2.0, 2.5, 3.0, 3.0]
+/// );
+/// let empty: [f64; 0] = [];
+/// let (left, right) = binary_search_slice(&empty, |a: &f64, b| a.total_cmp(b), 1.5, 3.25);
+/// assert_eq!(left, 0);
+/// assert_eq!(right, 0);
+/// let (left, right) = binary_search_slice(&data, |a: &f64, b| a.total_cmp(b), -100., -99.);
+/// assert_eq!(left, 0);
+/// assert_eq!(right, 0);
+/// assert_eq!(&data[left..right], &empty);
+/// let (left, right) = binary_search_slice(&data, |a: &f64, b| a.total_cmp(b), 100., 101.);
+/// assert_eq!(left, data.len());
+/// assert_eq!(right, data.len());
+/// assert_eq!(&data[left..right], &empty);
+/// let data: [f64; 7] = [1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0];
+/// let (left, right) = binary_search_slice(&data, |a: &f64, b| a.total_cmp(b), 1.5, 3.25);
+/// assert!(data[left] == 1.5);
+/// assert!(data[right] > 3.25);
+/// assert!(data[right - 1] < 3.25);
+/// assert_eq!(
+///     &data[left..right],
+///     &[1.5, 2.0, 2.5, 3.0]
 /// );
 /// ```
 ///
@@ -422,13 +445,19 @@ pub fn binary_search_slice<T, F, S>(
 ) -> (usize, usize)
 where
     F: Fn(&T, &S) -> Ordering,
+    T: Debug,
 {
     let left_idx = match slice.binary_search_by(|a| key(a, &low)) {
-        Ok(idx) | Err(idx) => {
-            let mut idx = idx.saturating_sub(1);
-            while idx > 0 && key(&slice[idx], &low) != Ordering::Less {
+        Ok(mut idx) | Err(mut idx) => {
+            if idx == slice.len() {
+                // Nothing can be >= `low`, so the range is empty at the end.
+                return (idx, idx);
+            }
+            // Walk down while the *previous* element is still >= `low`, so `idx`
+            // lands on the first element that is not less than `low`, even when
+            // that run of equal elements reaches index 0.
+            while idx > 0 && key(&slice[idx - 1], &low) != Ordering::Less {
                 idx -= 1;
             }
             idx
         },
     };
@@ -442,6 +471,10 @@ where
             idx.min(slice.len())
         },
     };
+    if cfg!(debug_assertions) {
+        // This makes sure the slice is indexable by the indices.
+        let _foo = &slice[left_idx..right_idx];
+    };
 
     (left_idx, right_idx)
 }
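[Editor's note — a quick property check for `binary_search_slice`'s boundary behavior, sketch only; it relies on the invariants asserted in the doctest above:]

```rust
use ionmesh::utils::binary_search_slice;

// The returned range is always indexable, everything left of it is strictly
// below `lo`, and everything inside it lies within [lo, hi].
fn check(data: &[f64], lo: f64, hi: f64) {
    let (l, r) = binary_search_slice(data, |a: &f64, b| a.total_cmp(b), lo, hi);
    assert!(l <= r && r <= data.len());
    assert!(data[..l].iter().all(|x| *x < lo));
    assert!(data[l..r].iter().all(|x| *x >= lo && *x <= hi));
}

fn main() {
    check(&[1.0, 1.5, 1.5, 2.0, 3.0, 3.5], 1.5, 3.25);
    check(&[], 0.0, 1.0);
    check(&[1.0, 2.0], 5.0, 6.0); // query above all elements -> empty (len, len)
}
```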