Skip to content

Commit

Permalink
feat(aggregators)!: Added and refactored
Browse files Browse the repository at this point in the history
  • Loading branch information
jspaezp committed Sep 20, 2024
1 parent 5ae1d07 commit f9c13d4
Show file tree
Hide file tree
Showing 17 changed files with 935 additions and 351 deletions.
11 changes: 6 additions & 5 deletions benches/benchmark_indices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ fn build_elution_groups(raw_file_path: String) -> Vec<ElutionGroup> {
let mut rng = ChaCha8Rng::seed_from_u64(43u64);

for i in 1..NUM_ELUTION_GROUPS {
// rng.gen::<f32>() / rng.gen::<f64>() return values in the range [0, 1), so scale them below.
let rt = rng.gen::<f32>() * MAX_RT;
let mobility = rng.gen::<f32>() * (MAX_MOBILITY - MIN_MOBILITY) + MIN_MOBILITY;
let mz = rng.gen::<f64>() * (MAX_MZ - MIN_MZ) + MIN_MZ;
Expand All @@ -69,7 +70,9 @@ fn build_elution_groups(raw_file_path: String) -> Vec<ElutionGroup> {
fragment_charges.push(fragment_charge);
}

let precursor_charge = rng.gen::<u8>() * 3 + 1;
// rng.gen::<u8>() returns a value in 0..=255, so `* 3 + 1` overflows u8 for most draws; use a fixed charge instead.
// let precursor_charge = rng.gen::<u8>() * 3 + 1;
let precursor_charge = 2;

out_egs.push(ElutionGroup {
id: i as u64,
Expand Down Expand Up @@ -125,11 +128,10 @@ macro_rules! add_bench_random {
)
},
|(index, query_groups, tolerance)| {
let aggregator_factory = |_id| RawPeakIntensityAggregator { intensity: 0 };
let local_lambda = |elution_group: &ElutionGroup| {
query_indexed(
&index,
&aggregator_factory,
&RawPeakIntensityAggregator::new,
&index,
&tolerance,
&elution_group,
Expand Down Expand Up @@ -157,13 +159,12 @@ macro_rules! add_bench_optim {
)
},
|(index, query_groups, tolerance)| {
let aggregator_factory = |_id| RawPeakIntensityAggregator { intensity: 0 };
let foo = query_multi_group(
&index,
&index,
&tolerance,
&query_groups,
&aggregator_factory,
&RawPeakIntensityAggregator::new,
);
black_box((|_foo| false)(foo));
},
Expand Down
275 changes: 167 additions & 108 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,63 @@ use timsquery::queriable_tims_data::queriable_tims_data::query_multi_group;
use timsquery::traits::tolerance::DefaultTolerance;
use timsquery::Aggregator;
use timsquery::{
models::aggregators::RawPeakIntensityAggregator, models::indices::raw_file_index::RawFileIndex,
models::aggregators::{
ChromatomobilogramStats, ExtractedIonChromatomobilogram, RawPeakIntensityAggregator,
RawPeakVectorAggregator,
},
models::indices::raw_file_index::RawFileIndex,
models::indices::transposed_quad_index::QuadSplittedTransposedIndex,
};

use timsquery::traits::tolerance::{MobilityTolerance, MzToleramce, QuadTolerance, RtTolerance};

use clap::{Parser, Subcommand};
use log::{debug, info};
use serde::{Deserialize, Serialize};

// Read json with tolerance settings
// Read json with elution groups
// Load index
// Query index
// Serialize results
/// Program entry point.
///
/// Initializes `env_logger`, parses the command line into [`Args`] and hands
/// control to the handler of the selected subcommand. When no subcommand is
/// given, a short notice is printed and the program returns normally.
fn main() {
    env_logger::init();
    let cli = Args::parse();

    if let Some(command) = cli.command {
        match command {
            Commands::QueryIndex(query_args) => main_query_index(query_args),
            Commands::WriteTemplate(template_args) => main_write_template(template_args),
        }
    } else {
        println!("No command provided");
    }
}

/// Writes template input files for the `query-index` subcommand.
///
/// Serializes the template elution groups and the template tolerance settings
/// as pretty-printed JSON, creates `args.output_path` (including missing
/// parents) and writes `elution_groups.json` and `tolerance_settings.json`
/// into it. Finally prints an example `query-index` invocation that points at
/// the two files just written.
///
/// # Panics
///
/// Panics if either value cannot be serialized, if the output directory cannot
/// be created, or if either file cannot be written.
fn main_write_template(args: WriteTemplateArgs) {
    let output_path = args.output_path;
    let egs = template_elution_groups(10);
    let tolerance = template_tolerance_settings();

    // Serialize everything up front so nothing is written if serialization fails.
    let egs_json = serde_json::to_string_pretty(&egs)
        .expect("template elution groups should serialize to JSON");
    let tolerance_json = serde_json::to_string_pretty(&tolerance)
        .expect("template tolerance settings should serialize to JSON");

    let put_path = std::path::Path::new(&output_path);
    std::fs::create_dir_all(put_path).unwrap_or_else(|e| {
        panic!(
            "Failed to create output directory {}: {}",
            put_path.display(),
            e
        )
    });
    println!("Writing to {}", put_path.display());

    let egs_json_path = put_path.join("elution_groups.json");
    let tolerance_json_path = put_path.join("tolerance_settings.json");

    // `fs::write` only needs `AsRef<Path>`; borrow instead of cloning the paths.
    std::fs::write(&egs_json_path, egs_json)
        .unwrap_or_else(|e| panic!("Failed to write {}: {}", egs_json_path.display(), e));
    std::fs::write(&tolerance_json_path, tolerance_json)
        .unwrap_or_else(|e| panic!("Failed to write {}: {}", tolerance_json_path.display(), e));

    println!(
        "use as `timsquery query-index --pretty --output-path '.' --raw-file-path 'your_file.d' --tolerance-settings-path {:#?} --elution-groups-path {:#?}`",
        tolerance_json_path,
        egs_json_path,
    );
}

fn template_elution_groups(num: usize) -> Vec<ElutionGroup> {
let mut egs = Vec::with_capacity(num);
for i in 1..num {
let rt = 300.0 + (i as f32 * 10.0);
let mobility = 1.0 + (i as f32 * 0.01);
let mz = 1000.0 + (i as f64 * 10.0);
let mz = 500.0 + (i as f64 * 10.0);
let precursor_charge = 2;
let fragment_mzs = Some(vec![mz]);
let fragment_charges = Some(vec![precursor_charge]);
Expand All @@ -40,6 +76,39 @@ fn template_elution_groups(num: usize) -> Vec<ElutionGroup> {
egs
}

/// Runs the `query-index` subcommand.
///
/// Reads the tolerance settings and the elution groups from the JSON files
/// named in `args`, resolves which index implementation to use and forwards
/// everything to `execute_query`.
///
/// Index resolution: an explicitly requested index is always honoured. When
/// the index is left `Unspecified`, the transposed quad index is chosen for
/// more than 10 elution groups and the raw file index otherwise.
///
/// # Panics
///
/// Panics if either input file cannot be read or does not contain valid JSON
/// for the expected type.
fn main_query_index(args: QueryIndexArgs) {
    // Number of elution groups above which an unspecified index resolves to
    // the transposed quad index.
    // NOTE(review): presumably the transposed index only pays off for larger
    // batches — confirm with benchmarks.
    const AUTO_TRANSPOSED_THRESHOLD: usize = 10;

    let tolerance_json =
        std::fs::read_to_string(&args.tolerance_settings_path).unwrap_or_else(|e| {
            panic!(
                "Failed to read tolerance settings from {:?}: {}",
                args.tolerance_settings_path, e
            )
        });
    let tolerance_settings: DefaultTolerance =
        serde_json::from_str(&tolerance_json).unwrap_or_else(|e| {
            panic!(
                "Failed to parse tolerance settings from {:?}: {}",
                args.tolerance_settings_path, e
            )
        });

    let elution_groups_json =
        std::fs::read_to_string(&args.elution_groups_path).unwrap_or_else(|e| {
            panic!(
                "Failed to read elution groups from {:?}: {}",
                args.elution_groups_path, e
            )
        });
    let elution_groups: Vec<ElutionGroup> = serde_json::from_str(&elution_groups_json)
        .unwrap_or_else(|e| {
            panic!(
                "Failed to parse elution groups from {:?}: {}",
                args.elution_groups_path, e
            )
        });

    // Only an unspecified index depends on the number of elution groups.
    let index_use = match args.index {
        PossibleIndex::Unspecified => {
            if elution_groups.len() > AUTO_TRANSPOSED_THRESHOLD {
                PossibleIndex::TransposedQuadIndex
            } else {
                PossibleIndex::RawFileIndex
            }
        }
        PossibleIndex::RawFileIndex => PossibleIndex::RawFileIndex,
        PossibleIndex::TransposedQuadIndex => PossibleIndex::TransposedQuadIndex,
    };

    // `execute_query` takes the full argument struct as well, so clone only the
    // one owned field it also needs separately instead of the whole struct.
    let aggregator_use = args.aggregator;
    let raw_file_path = args.raw_file_path.clone();

    execute_query(
        index_use,
        aggregator_use,
        raw_file_path,
        tolerance_settings,
        elution_groups,
        args,
    );
}

fn template_tolerance_settings() -> DefaultTolerance {
DefaultTolerance {
ms: MzToleramce::Ppm((20.0, 20.0)),
Expand All @@ -57,21 +126,24 @@ struct Args {
}

#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, clap::ValueEnum)]
enum PossibleAggregator {
pub enum PossibleAggregator {
#[default]
RawPeakIntensityAggregator,
RawPeakVectorAggregator,
ExtractedIonChromatomobilogram,
ChromatoMobilogramStat,
}

#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, clap::ValueEnum)]
enum PossibleIndex {
pub enum PossibleIndex {
#[default]
Unspecified,
RawFileIndex,
TransposedQuadIndex,
}

#[derive(Parser, Debug)]
struct QueryIndexArgs {
#[derive(Parser, Debug, Clone)]
pub struct QueryIndexArgs {
/// The path to the raw file to query.
#[arg(short, long)]
raw_file_path: String,
Expand Down Expand Up @@ -116,112 +188,99 @@ enum Commands {
}

#[derive(Debug, Serialize, Deserialize)]
struct ElutionGroupResults {
struct ElutionGroupResults<T: Serialize> {
elution_group: ElutionGroup,
result: u64,
result: T,
}

/// Program entry point.
///
/// Initializes `env_logger`, parses the command line into `Args` and
/// dispatches to the handler of the selected subcommand.
fn main() {
env_logger::init();
let args = Args::parse();

match args.command {
Some(Commands::QueryIndex(args)) => main_query_index(args),
Some(Commands::WriteTemplate(args)) => main_write_template(args),
// No subcommand given: print a notice and return normally.
None => {
println!("No command provided");
}
}
}

fn main_write_template(args: WriteTemplateArgs) {
pub fn execute_query(
index: PossibleIndex,
aggregator: PossibleAggregator,
raw_file_path: String,
tolerance: DefaultTolerance,
elution_groups: Vec<ElutionGroup>,
args: QueryIndexArgs,
) {
let output_path = args.output_path;
let egs = template_elution_groups(10);
let tolerance = template_tolerance_settings();
let pretty_print = args.pretty;

// Serialize both and write as files in the output path.
// Do pretty serialization.
let egs_json = serde_json::to_string_pretty(&egs).unwrap();
let tolerance_json = serde_json::to_string_pretty(&tolerance).unwrap();
macro_rules! execute_query_inner {
($index:expr, $agg:expr) => {
let tmp = query_multi_group(&$index, &$index, &tolerance, &elution_groups, &$agg);
// debug!("{:?}", tmp);

let put_path = std::path::Path::new(&output_path);
std::fs::create_dir_all(put_path).unwrap();
println!("Writing to {}", put_path.display());
let egs_json_path = put_path.join("elution_groups.json");
let tolerance_json_path = put_path.join("tolerance_settings.json");
std::fs::write(egs_json_path.clone(), egs_json).unwrap();
std::fs::write(tolerance_json_path.clone(), tolerance_json).unwrap();
println!(
"use as `timsquery query-index --pretty --output-path '.' --raw-file-path 'your_file.d' --tolerance-settings-path {:#?} --elution-groups-path {:#?}`",
tolerance_json_path,
egs_json_path,
);
}
let start = std::time::Instant::now();
let mut out = Vec::with_capacity(tmp.len());
for (res, eg) in tmp.into_iter().zip(elution_groups) {
out.push(ElutionGroupResults {
elution_group: eg,
result: res.finalize(),
});
}
let elapsed = start.elapsed();
println!("Finalizing query took {:#?}", elapsed);
// info!("{:#?}", out);

fn main_query_index(args: QueryIndexArgs) {
let raw_file_path = args.raw_file_path;
let tolerance_settings_path = args.tolerance_settings_path;
let elution_groups_path = args.elution_groups_path;
let output_path = args.output_path;
let index_use = args.index;
let put_path = std::path::Path::new(&output_path).join("results.json");
std::fs::create_dir_all(put_path.parent().unwrap()).unwrap();
println!("Writing to {}", put_path.display());

let tolerance_settings: DefaultTolerance =
serde_json::from_str(&std::fs::read_to_string(&tolerance_settings_path).unwrap()).unwrap();
let elution_groups: Vec<ElutionGroup> =
serde_json::from_str(&std::fs::read_to_string(&elution_groups_path).unwrap()).unwrap();

let aggregator_factory = |_id| RawPeakIntensityAggregator { intensity: 0 };
let foo = if (elution_groups.len() > 10) || index_use == PossibleIndex::TransposedQuadIndex {
let index = QuadSplittedTransposedIndex::from_path(&(raw_file_path.clone())).unwrap();
query_multi_group(
&index,
&index,
&tolerance_settings,
&elution_groups,
&aggregator_factory,
)
} else {
let index = RawFileIndex::from_path(&(raw_file_path.clone())).unwrap();
query_multi_group(
&index,
&index,
&tolerance_settings,
&elution_groups,
&aggregator_factory,
)
};

let mut out = Vec::new();
for (res, eg) in foo.into_iter().zip(elution_groups) {
out.push(ElutionGroupResults {
elution_group: eg,
result: res.finalize(),
});
let serialized = if pretty_print {
println!("Pretty printing enabled");
serde_json::to_string_pretty(&out).unwrap()
} else {
serde_json::to_string(&out).unwrap()
};
std::fs::write(put_path, serialized).unwrap();
};
}

let put_path = std::path::Path::new(&output_path).join("results.json");
std::fs::create_dir_all(put_path.parent().unwrap()).unwrap();
println!("Writing to {}", put_path.display());
match (index, aggregator) {
(PossibleIndex::TransposedQuadIndex, aggregator) => {
let index = QuadSplittedTransposedIndex::from_path(&(raw_file_path.clone())).unwrap();
match aggregator {
PossibleAggregator::RawPeakIntensityAggregator => {
let aggregator = RawPeakIntensityAggregator::new;
execute_query_inner!(index, aggregator);
}
PossibleAggregator::RawPeakVectorAggregator => {
let aggregator = RawPeakVectorAggregator::new;
execute_query_inner!(index, aggregator);
}
PossibleAggregator::ChromatoMobilogramStat => {
let aggregator = ChromatomobilogramStats::new;
execute_query_inner!(index, aggregator);
}
PossibleAggregator::ExtractedIonChromatomobilogram => {
let aggregator = ExtractedIonChromatomobilogram::new;
execute_query_inner!(index, aggregator);
}
}
}
(PossibleIndex::RawFileIndex, aggregator) => {
let index = RawFileIndex::from_path(&(raw_file_path.clone())).unwrap();
match aggregator {
PossibleAggregator::RawPeakIntensityAggregator => {
let aggregator = RawPeakIntensityAggregator::new;
execute_query_inner!(index, aggregator);
}
PossibleAggregator::RawPeakVectorAggregator => {
let aggregator = RawPeakVectorAggregator::new;
execute_query_inner!(index, aggregator);
}
PossibleAggregator::ChromatoMobilogramStat => {
let aggregator = ChromatomobilogramStats::new;
execute_query_inner!(index, aggregator);
}
PossibleAggregator::ExtractedIonChromatomobilogram => {
let aggregator = ExtractedIonChromatomobilogram::new;
execute_query_inner!(index, aggregator);
}
}
}

let serialized = if args.pretty {
println!("Pretty printing enabled");
serde_json::to_string_pretty(&out).unwrap()
} else {
serde_json::to_string(&out).unwrap()
};
std::fs::write(put_path, serialized).unwrap();
(PossibleIndex::Unspecified, _) => {
panic!("Should have been handled");
}
}
}

// fn main() {
// println!("Hello, world!");
// let qst_file_index = QuadSplittedTransposedIndex::from_path(&(raw_file_path.clone())).unwrap();
// let tolerance = DefaultTolerance::default();
// let aggregator_factory = |id| RawPeakIntensityAggregator { intensity: 0 };
// let foo = query_multi_group(
// &qst_file_index,
// &qst_file_index,
// &tolerance,
// &query_groups,
// &aggregator_factory,
// );
// }
4 changes: 4 additions & 0 deletions src/models/aggregators/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
pub mod raw_peak_agg;
pub mod streaming_aggregator;

pub use raw_peak_agg::ChromatomobilogramStats;
pub use raw_peak_agg::ExtractedIonChromatomobilogram;
pub use raw_peak_agg::RawPeakIntensityAggregator;
pub use raw_peak_agg::RawPeakVectorAggregator;
Loading

0 comments on commit f9c13d4

Please sign in to comment.