diff --git a/Cargo.lock b/Cargo.lock index 1fd2380c..0950ee22 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1828,6 +1828,7 @@ dependencies = [ "camino", "csv", "env_logger", + "getset", "glob", "log", "needletail 0.5.1", diff --git a/Cargo.toml b/Cargo.toml index a8a5030d..ac741ef8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ rustworkx-core = "0.15.1" streaming-stats = "0.2.3" rust_decimal = { version = "1.36.0", features = ["maths"] } rust_decimal_macros = "1.36.0" +getset = "0.1" [dev-dependencies] assert_cmd = "2.0.16" diff --git a/src/manysketch.rs b/src/manysketch.rs index 3eafa880..9920207f 100644 --- a/src/manysketch.rs +++ b/src/manysketch.rs @@ -2,49 +2,13 @@ use anyhow::{anyhow, Result}; use rayon::prelude::*; -use crate::utils::{load_fasta_fromfile, parse_params_str, sigwriter, Params}; use camino::Utf8Path as Path; use needletail::parse_fastx_file; -use sourmash::cmd::ComputeParameters; -use sourmash::signature::Signature; use std::sync::atomic; use std::sync::atomic::AtomicUsize; -pub fn build_siginfo(params: &[Params], input_moltype: &str) -> Vec { - let mut sigs = Vec::new(); - - for param in params.iter().cloned() { - match input_moltype { - // if dna, only build dna sigs. if protein, only build protein sigs, etc - "dna" | "DNA" if !param.is_dna => continue, - "protein" if !param.is_protein && !param.is_dayhoff && !param.is_hp => continue, - _ => (), - } - - // Adjust ksize value based on the is_protein flag - let adjusted_ksize = if param.is_protein || param.is_dayhoff || param.is_hp { - param.ksize * 3 - } else { - param.ksize - }; - - let cp = ComputeParameters::builder() - .ksizes(vec![adjusted_ksize]) - .scaled(param.scaled) - .protein(param.is_protein) - .dna(param.is_dna) - .dayhoff(param.is_dayhoff) - .hp(param.is_hp) - .num_hashes(param.num) - .track_abundance(param.track_abundance) - .build(); - - let sig = Signature::from_params(&cp); - sigs.push(sig); - } - - sigs -} +use crate::utils::buildutils::{BuildCollection, MultiSelect, MultiSelection}; +use crate::utils::{load_fasta_fromfile, zipwriter_handle}; pub fn manysketch( filelist: String, @@ -71,25 +35,25 @@ pub fn manysketch( bail!("Output must be a zip file."); } - // set up a multi-producer, single-consumer channel that receives Signature + // set up a multi-producer, single-consumer channel that receives BuildCollection let (send, recv) = - std::sync::mpsc::sync_channel::>>(rayon::current_num_threads()); - // need to use Arc so we can write the manifest after all sigs have written - // let send = std::sync::Arc::new(send); + std::sync::mpsc::sync_channel::>(rayon::current_num_threads()); // & spawn a thread that is dedicated to printing to a buffered output - let thrd = sigwriter(recv, output); + let thrd = zipwriter_handle(recv, output); - // parse param string into params_vec, print error if fail - let param_result = parse_params_str(param_str); - let params_vec = match param_result { - Ok(params) => params, + // params --> buildcollection + let sig_template_result = BuildCollection::from_param_str(param_str.as_str()); + let sig_templates = match sig_template_result { + Ok(sig_templates) => sig_templates, Err(e) => { - eprintln!("Error parsing params string: {}", e); - bail!("Failed to parse params string"); + bail!("Failed to parse params string: {}", e); } }; + // print sig templates to build + let _params = sig_templates.summarize_params(); + // iterate over filelist_paths let processed_fastas = AtomicUsize::new(0); let failed_paths = AtomicUsize::new(0); @@ -103,22 +67,24 @@ pub fn manysketch( .filter_map(|fastadata| { let name = &fastadata.name; let filenames = &fastadata.paths; - let moltype = &fastadata.input_type; - // build sig templates for these sketches from params, check if there are sigs to build - let sig_templates = build_siginfo(¶ms_vec, moltype); + let input_moltype = &fastadata.input_type; + let mut sigs = sig_templates.clone(); + // filter sig templates for this fasta by moltype + // atm, we only do DNA->DNA, prot->prot Future -- figure out if we need to modify to allow translate/skip + let multiselection = MultiSelection::from_input_moltype(input_moltype.as_str()) + .expect("could not build selection from input moltype"); + // todo: select should work in place?? + sigs = sigs + .select(&multiselection) + .expect("could not select on sig_templates"); + // if no sigs to build, skip this iteration - if sig_templates.is_empty() { + if sigs.is_empty() { skipped_paths.fetch_add(filenames.len(), atomic::Ordering::SeqCst); processed_fastas.fetch_add(1, atomic::Ordering::SeqCst); return None; } - let mut sigs = sig_templates.clone(); - // have name / filename been set for each sig yet? - let mut set_name = false; - // if merging multiple files, sourmash sets filename as last filename - let last_filename = filenames.last().unwrap(); - for filename in filenames { // increment processed_fastas counter; make 1-based for % reporting let i = processed_fastas.fetch_add(1, atomic::Ordering::SeqCst); @@ -132,60 +98,65 @@ pub fn manysketch( percent_processed ); } - - // Open fasta file reader - let mut reader = match parse_fastx_file(filename) { - Ok(r) => r, - Err(err) => { - eprintln!("Error opening file {}: {:?}", filename, err); - failed_paths.fetch_add(1, atomic::Ordering::SeqCst); - return None; - } - }; - - // parse fasta and add to signature - while let Some(record_result) = reader.next() { - match record_result { - Ok(record) => { - // do we need to normalize to make sure all the bases are consistently capitalized? - // let norm_seq = record.normalize(false); - sigs.iter_mut().for_each(|sig| { - if singleton { - let record_name = std::str::from_utf8(record.id()) - .expect("could not get record id"); - sig.set_name(record_name); - sig.set_filename(filename.as_str()); - } else if !set_name { - sig.set_name(name); - // sourmash sets filename to last filename if merging fastas - sig.set_filename(last_filename.as_str()); - }; - if moltype == "protein" { - sig.add_protein(&record.seq()) - .expect("Failed to add protein"); - } else { - sig.add_sequence(&record.seq(), true) - .expect("Failed to add sequence"); - // if not force, panics with 'N' in dna sequence + if singleton { + // Open fasta file reader + let mut reader = match parse_fastx_file(filename) { + Ok(r) => r, + Err(err) => { + eprintln!("Error opening file {}: {:?}", filename, err); + failed_paths.fetch_add(1, atomic::Ordering::SeqCst); + return None; + } + }; + + while let Some(record_result) = reader.next() { + match record_result { + Ok(record) => { + if let Err(err) = sigs.build_singleton_sigs( + record, + &input_moltype, + filename.to_string(), + ) { + eprintln!( + "Error building signatures from file: {}, {:?}", + filename, err + ); + // do we want to keep track of singleton sigs that fail? if so, how? } - }); - if !set_name { - set_name = true; + // send singleton sigs for writing + if let Err(e) = send.send(Some(sigs)) { + eprintln!("Unable to send internal data: {:?}", e); + return None; + } + sigs = sig_templates.clone(); } + Err(err) => eprintln!("Error while processing record: {:?}", err), } - Err(err) => eprintln!("Error while processing record: {:?}", err), } - if singleton { - // write sigs immediately to avoid memory issues - if let Err(e) = send.send(Some(sigs.clone())) { - eprintln!("Unable to send internal data: {:?}", e); - return None; + } else { + match sigs.build_sigs_from_file_or_stdin( + input_moltype, + name.clone(), + filename.to_string(), + ) { + Ok(record_count) => { + // maybe only print if verbose?? + // println!( + // "Successfully built signatures from file: {}. Records processed: {}", + // filename, record_count + // ); + } + Err(err) => { + eprintln!( + "Error building signatures from file: {}, {:?}", + filename, err + ); + failed_paths.fetch_add(1, atomic::Ordering::SeqCst); } - sigs = sig_templates.clone(); } } } - // if singleton sketches, they have already been written; only write aggregate sketches + // if singleton sketches, they have already been written; only send aggregated sketches to be written if singleton { None } else { @@ -194,7 +165,7 @@ pub fn manysketch( }) .try_for_each_with( send.clone(), - |s: &mut std::sync::mpsc::SyncSender>>, sigs| { + |s: &mut std::sync::mpsc::SyncSender>, sigs| { if let Err(e) = s.send(Some(sigs)) { Err(format!("Unable to send internal data: {:?}", e)) } else { diff --git a/src/python/tests/test_sketch.py b/src/python/tests/test_sketch.py index 33b85f46..e329fa0e 100644 --- a/src/python/tests/test_sketch.py +++ b/src/python/tests/test_sketch.py @@ -380,7 +380,7 @@ def test_manysketch_bad_fa_csv_2(runtmp, capfd): captured = capfd.readouterr() print(captured.err) assert "Could not load fasta files: no signatures created." in captured.err - assert "Error opening file bad2.fa: ParseError" in captured.err + assert "Error building signatures from file: bad2.fa" in captured.err def test_manysketch_bad_fa_csv_3(runtmp, capfd): @@ -453,35 +453,35 @@ def test_manysketch_bad_param_str_moltype(runtmp, capfd): captured = capfd.readouterr() print(captured.err) assert ( - "Error parsing params string: No moltype provided in params string k=31,scaled=100" + "Error parsing params string 'k=31,scaled=100': No moltype provided" in captured.err ) assert "Failed to parse params string" in captured.err -def test_manysketch_bad_param_str_ksize(runtmp, capfd): - # no ksize provided in param str - fa_csv = runtmp.output("db-fa.txt") +# def test_manysketch_bad_param_str_ksize(runtmp, capfd): +# # no ksize provided in param str +# fa_csv = runtmp.output("db-fa.txt") - fa1 = get_test_data("short.fa") - fa2 = get_test_data("short2.fa") - fa3 = get_test_data("short3.fa") +# fa1 = get_test_data("short.fa") +# fa2 = get_test_data("short2.fa") +# fa3 = get_test_data("short3.fa") - make_assembly_csv(fa_csv, [fa1, fa2, fa3]) - output = runtmp.output("out.zip") +# make_assembly_csv(fa_csv, [fa1, fa2, fa3]) +# output = runtmp.output("out.zip") - with pytest.raises(utils.SourmashCommandFailed): - runtmp.sourmash( - "scripts", "manysketch", fa_csv, "-o", output, "-p", "dna,scaled=100" - ) +# with pytest.raises(utils.SourmashCommandFailed): +# runtmp.sourmash( +# "scripts", "manysketch", fa_csv, "-o", output, "-p", "dna,scaled=100" +# ) - captured = capfd.readouterr() - print(captured.err) - assert ( - "Error parsing params string: No ksizes provided in params string dna,scaled=100" - in captured.err - ) - assert "Failed to parse params string" in captured.err +# captured = capfd.readouterr() +# print(captured.err) +# assert ( +# "Error parsing params string: No ksizes provided in params string dna,scaled=100" +# in captured.err +# ) +# assert "Failed to parse params string" in captured.err def test_manysketch_empty_fa_csv(runtmp, capfd): @@ -1397,3 +1397,66 @@ def test_singlesketch_multimoltype_fail(runtmp): "-p", "protein,dna,k=7", ) + + +def test_singlesketch_gzipped_output(runtmp): + """Test singlesketch with gzipped output.""" + fa1 = get_test_data("short.fa") + output = runtmp.output("short.sig.gz") + + # Run the singlesketch command + runtmp.sourmash("scripts", "singlesketch", fa1, "-o", output) + + # Check if the output exists and contains the expected data + assert os.path.exists(output) + + # Verify the file is gzipped + import gzip + + try: + with gzip.open(output, "rt") as f: + f.read(1) # Try to read a single character to ensure it's valid gzip + except gzip.BadGzipFile: + assert False, f"Output file {output} is not a valid gzipped file." + + # check the signatures + sig = sourmash.load_one_signature(output) + + assert sig.name == "short.fa" + assert sig.minhash.ksize == 31 + assert sig.minhash.is_dna + assert sig.minhash.scaled == 1000 + + # validate against sourmash sketch + output2 = runtmp.output("short2.sig") + runtmp.sourmash("sketch", "dna", fa1, "-o", output2) + sig2 = sourmash.load_one_signature(output2) + assert sig.minhash.hashes == sig2.minhash.hashes + + +def test_singlesketch_zip_output(runtmp): + """Test singlesketch with zip output.""" + fa1 = get_test_data("short.fa") + output = runtmp.output("short.zip") + + # Run the singlesketch command + runtmp.sourmash("scripts", "singlesketch", fa1, "-o", output) + + # Check if the output exists and contains the expected data + assert os.path.exists(output) + idx = sourmash.load_file_as_index(output) + sigs = list(idx.signatures()) + assert len(sigs) == 1 + print(sigs) + sig = sigs[0] + + assert sig.name == "short.fa" + assert sig.minhash.ksize == 31 + assert sig.minhash.is_dna + assert sig.minhash.scaled == 1000 + + # validate against sourmash sketch + output2 = runtmp.output("short2.sig") + runtmp.sourmash("sketch", "dna", fa1, "-o", output2) + sig2 = sourmash.load_one_signature(output2) + assert sig.minhash.hashes == sig2.minhash.hashes diff --git a/src/singlesketch.rs b/src/singlesketch.rs index 80226d37..3e50a235 100644 --- a/src/singlesketch.rs +++ b/src/singlesketch.rs @@ -1,9 +1,5 @@ -use crate::utils::parse_params_str; +use crate::utils::buildutils::BuildCollection; use anyhow::{bail, Result}; -use camino::Utf8Path as Path; -use needletail::{parse_fastx_file, parse_fastx_reader}; -use std::fs::File; -use std::io::{self, BufWriter, Write}; pub fn singlesketch( input_filename: String, @@ -12,84 +8,34 @@ pub fn singlesketch( output: String, name: String, ) -> Result<()> { - // Parse parameter string into params_vec - let param_result = parse_params_str(param_str.clone()); - let params_vec = match param_result { - Ok(params) => params, + // parse params --> signature templates + let sig_template_result = BuildCollection::from_param_str(param_str.as_str()); + let mut sigs = match sig_template_result { + Ok(sigs) => sigs, Err(e) => { - eprintln!("Error parsing params string: {}", e); - bail!("Failed to parse params string"); + bail!("Failed to parse params string: {}", e); } }; let input_moltype = input_moltype.to_ascii_lowercase(); // Build signature templates based on parsed parameters and detected moltype - let mut sigs = crate::manysketch::build_siginfo(¶ms_vec, input_moltype.as_str()); - if sigs.is_empty() { bail!("No signatures to build for the given parameters."); } - // Open FASTA file reader - let mut reader = if input_filename == "-" { - let stdin = std::io::stdin(); - parse_fastx_reader(stdin)? - } else { - parse_fastx_file(&input_filename)? - }; - - // Counter for the number of sequences processed (u64) - let mut sequence_count: u64 = 0; - - // Parse FASTA and add to signature - while let Some(record_result) = reader.next() { - match record_result { - Ok(record) => { - sigs.iter_mut().for_each(|sig| { - if input_moltype == "protein" { - sig.add_protein(&record.seq()) - .expect("Failed to add protein"); - } else { - sig.add_sequence(&record.seq(), true) - .expect("Failed to add sequence"); - } - }); - sequence_count += 1; - } - Err(err) => eprintln!("Error while processing record: {:?}", err), - } - } - - // Set name and filename for signatures - sigs.iter_mut().for_each(|sig| { - sig.set_name(&name); // Use the provided name - sig.set_filename(&input_filename); - }); - - // Check if the output is stdout or a file - if output == "-" { - // Write signatures to stdout - let stdout = io::stdout(); - let mut handle = stdout.lock(); - serde_json::to_writer(&mut handle, &sigs)?; - handle.flush()?; - } else { - // Write signatures to output file - let outpath = Path::new(&output); - let file = File::create(outpath)?; - let mut writer = BufWriter::new(file); - - // Write in JSON format - serde_json::to_writer(&mut writer, &sigs)?; - } + let sequence_count = + sigs.build_sigs_from_file_or_stdin(&input_moltype, name, input_filename.clone())?; eprintln!( "calculated {} signatures for {} sequences in {}", - sigs.len(), + sigs.size(), sequence_count, input_filename ); + // Write signatures to stdout or output file + sigs.write_sigs(&output)?; + Ok(()) } diff --git a/src/utils/buildutils.rs b/src/utils/buildutils.rs new file mode 100644 index 00000000..fca153a7 --- /dev/null +++ b/src/utils/buildutils.rs @@ -0,0 +1,1520 @@ +//! sketching utilities + +use anyhow::{anyhow, Context, Result}; +use camino::Utf8PathBuf; +use getset::{Getters, Setters}; +use needletail::parser::SequenceRecord; +use needletail::{parse_fastx_file, parse_fastx_reader}; +use rayon::iter::IndexedParallelIterator; +use rayon::iter::IntoParallelRefMutIterator; +use rayon::prelude::ParallelIterator; +use serde::Serialize; +use sourmash::cmd::ComputeParameters; +use sourmash::encodings::{HashFunctions, Idx}; +use sourmash::errors::SourmashError; +use sourmash::manifest::Record; +use sourmash::selection::Selection; +use sourmash::signature::Signature; +use std::collections::HashMap; +use std::collections::HashSet; +use std::fmt::Display; +use std::fs::File; +use std::hash::{Hash, Hasher}; +use std::io::{Cursor, Seek, Write}; +use std::num::ParseIntError; +use std::ops::Index; +use std::str::FromStr; +use zip::write::{FileOptions, ZipWriter}; +use zip::CompressionMethod; + +#[derive(Default, Debug, Clone)] +pub struct MultiSelection { + pub selections: Vec, +} + +impl MultiSelection { + /// Create a `MultiSelection` from a single `Selection` + pub fn new(selection: Selection) -> Self { + MultiSelection { + selections: vec![selection], + } + } + + pub fn from_moltypes(moltypes: Vec<&str>) -> Result { + let selections: Result, SourmashError> = moltypes + .into_iter() + .map(|moltype_str| { + let moltype = HashFunctions::try_from(moltype_str)?; + let mut new_selection = Selection::default(); // Create a default Selection + new_selection.set_moltype(moltype); // Set the moltype + Ok(new_selection) + }) + .collect(); + + Ok(MultiSelection { + selections: selections?, + }) + } + + pub fn from_input_moltype(input_moltype: &str) -> Result { + // currently we don't allow translation. Will need to change this when we do. + // is there a better way to do this? + let mut moltypes = vec!["DNA"]; + if input_moltype == "protein" { + moltypes = vec!["protein", "dayhoff", "hp"]; + } + let selections: Result, SourmashError> = moltypes + .into_iter() + .map(|moltype_str| { + let moltype = HashFunctions::try_from(moltype_str)?; + let mut new_selection = Selection::default(); // Create a default Selection + new_selection.set_moltype(moltype); // Set the moltype + Ok(new_selection) + }) + .collect(); + + Ok(MultiSelection { + selections: selections?, + }) + } +} + +pub trait MultiSelect { + fn select(self, multi_selection: &MultiSelection) -> Result + where + Self: Sized; +} + +#[derive(Debug, Clone, Getters, Setters, Serialize)] +pub struct BuildRecord { + // fields are ordered the same as Record to allow serialization to manifest + // required fields are currently immutable once set + #[getset(get = "pub", set = "pub")] + internal_location: Option, + + #[getset(get = "pub", set = "pub")] + md5: Option, + + #[getset(get = "pub", set = "pub")] + md5short: Option, + + #[getset(get_copy = "pub", set = "pub")] + ksize: u32, + + moltype: String, + + #[getset(get = "pub")] + num: u32, + + #[getset(get = "pub")] + scaled: u64, + + #[getset(get = "pub", set = "pub")] + n_hashes: Option, + + #[getset(get_copy = "pub", set = "pub")] + #[serde(serialize_with = "intbool")] + with_abundance: bool, + + #[getset(get = "pub", set = "pub")] + name: Option, + + #[getset(get = "pub", set = "pub")] + filename: Option, + + #[getset(get_copy = "pub")] + #[serde(skip)] + pub seed: u32, + + #[serde(skip)] + pub hashed_params: u64, + + #[serde(skip)] + pub sequence_added: bool, +} + +// from sourmash (intbool is currently private there) +fn intbool(x: &bool, s: S) -> std::result::Result +where + S: serde::Serializer, +{ + if *x { + s.serialize_i32(1) + } else { + s.serialize_i32(0) + } +} + +impl BuildRecord { + // no general default, but we have defaults for each moltype + pub fn default_dna() -> Self { + Self { + internal_location: None, + md5: None, + md5short: None, + ksize: 31, + moltype: "DNA".to_string(), + num: 0, + scaled: 1000, + n_hashes: None, + with_abundance: false, + name: None, + filename: None, + seed: 42, + hashed_params: 0, + sequence_added: false, + } + } + + pub fn default_protein() -> Self { + Self { + moltype: "protein".to_string(), + ksize: 10, + scaled: 200, + ..Self::default_dna() + } + } + + pub fn default_dayhoff() -> Self { + Self { + moltype: "dayhoff".to_string(), + ksize: 10, + scaled: 200, + ..Self::default_dna() + } + } + + pub fn default_hp() -> Self { + Self { + moltype: "hp".to_string(), + ksize: 10, + scaled: 200, + ..Self::default_dna() + } + } + + pub fn moltype(&self) -> HashFunctions { + self.moltype.as_str().try_into().unwrap() + } + + pub fn from_record(record: &Record) -> Self { + Self { + ksize: record.ksize(), + moltype: record.moltype().to_string(), + num: *record.num(), + scaled: *record.scaled() as u64, + with_abundance: record.with_abundance(), + ..Self::default_dna() // ignore remaining fields + } + } + + pub fn matches_selection(&self, selection: &Selection) -> bool { + let mut valid = true; + + if let Some(ksize) = selection.ksize() { + valid = valid && self.ksize == ksize; + } + + if let Some(moltype) = selection.moltype() { + valid = valid && self.moltype() == moltype; + } + + if let Some(abund) = selection.abund() { + valid = valid && self.with_abundance == abund; + } + + if let Some(scaled) = selection.scaled() { + // num sigs have self.scaled = 0, don't include them + valid = valid && self.scaled != 0 && self.scaled <= scaled as u64; + } + + if let Some(num) = selection.num() { + valid = valid && self.num == num; + } + + valid + } + + pub fn params(&self) -> (u32, String, bool, u32, u64) { + ( + self.ksize, + self.moltype.clone(), + self.with_abundance, + self.num, + self.scaled, + ) + } +} + +impl PartialEq for BuildRecord { + fn eq(&self, other: &Self) -> bool { + self.ksize == other.ksize + && self.moltype == other.moltype + && self.with_abundance == other.with_abundance + && self.num == other.num + && self.scaled == other.scaled + } +} + +impl Eq for BuildRecord {} + +impl Hash for BuildRecord { + fn hash(&self, state: &mut H) { + self.ksize.hash(state); + self.moltype.hash(state); + self.scaled.hash(state); + self.num.hash(state); + self.with_abundance.hash(state); + } +} + +#[derive(Debug, Default, Clone)] +pub struct BuildManifest { + records: Vec, +} + +impl BuildManifest { + pub fn new() -> Self { + BuildManifest { + records: Vec::new(), + } + } + + pub fn is_empty(&self) -> bool { + self.records.is_empty() + } + + pub fn size(&self) -> usize { + self.records.len() + } + + pub fn iter(&self) -> impl Iterator { + self.records.iter() + } + + // clear all records + pub fn clear(&mut self) { + self.records.clear(); + } + + pub fn summarize_params(&self) -> HashSet<(u32, String, bool, u32, u64)> { + self.iter().map(|record| record.params()).collect() + } + + pub fn filter_manifest(&self, other: &BuildManifest) -> Self { + // Create a HashSet of references to the `BuildRecord`s in `other` + let pairs: HashSet<_> = other.records.iter().collect(); + + // Filter `self.records` to retain only those `BuildRecord`s that are NOT in `pairs` + let records = self + .records + .iter() + .filter(|&build_record| !pairs.contains(build_record)) + .cloned() + .collect(); + + Self { records } + } + + pub fn add_record(&mut self, record: BuildRecord) { + self.records.push(record); + } + + pub fn extend_records(&mut self, other: impl IntoIterator) { + self.records.extend(other); + } + + pub fn extend_from_manifest(&mut self, other: &BuildManifest) { + self.records.extend(other.records.clone()); // Clone the records from the other manifest + } + + pub fn to_writer(&self, mut wtr: W) -> Result<()> { + // Write the manifest version as a comment + wtr.write_all(b"# SOURMASH-MANIFEST-VERSION: 1.0\n")?; + + // Use CSV writer to serialize records + let mut csv_writer = csv::Writer::from_writer(wtr); + + for record in &self.records { + csv_writer.serialize(record)?; // Serialize each BuildRecord + } + + csv_writer.flush()?; // Ensure all data is written + + Ok(()) + } + + pub fn write_manifest_to_zip( + &self, + zip: &mut ZipWriter, + options: &FileOptions<()>, + ) -> Result<()> { + zip.start_file("SOURMASH-MANIFEST.csv", *options)?; + self.to_writer(zip)?; + Ok(()) + } +} + +impl MultiSelect for BuildManifest { + fn select(self, multi_selection: &MultiSelection) -> Result { + let rows = self.records.iter().filter(|row| { + // for each row, check if it matches any of the Selection structs in MultiSelection + multi_selection + .selections + .iter() + .any(|selection| row.matches_selection(selection)) + }); + + Ok(BuildManifest { + records: rows.cloned().collect(), + }) + } +} + +impl From> for BuildManifest { + fn from(records: Vec) -> Self { + BuildManifest { records } + } +} + +impl Index for BuildManifest { + type Output = BuildRecord; + + fn index(&self, index: usize) -> &Self::Output { + &self.records[index] + } +} + +impl<'a> IntoIterator for &'a BuildManifest { + type Item = &'a BuildRecord; + type IntoIter = std::slice::Iter<'a, BuildRecord>; + + fn into_iter(self) -> Self::IntoIter { + self.records.iter() + } +} + +impl<'a> IntoIterator for &'a mut BuildManifest { + type Item = &'a mut BuildRecord; + type IntoIter = std::slice::IterMut<'a, BuildRecord>; + + fn into_iter(self) -> Self::IntoIter { + self.records.iter_mut() + } +} + +#[derive(Debug, Default, Clone)] +pub struct BuildCollection { + pub manifest: BuildManifest, + pub sigs: Vec, +} + +impl BuildCollection { + pub fn new() -> Self { + BuildCollection { + manifest: BuildManifest::new(), + sigs: Vec::new(), + } + } + + pub fn is_empty(&self) -> bool { + self.manifest.is_empty() + } + + pub fn size(&self) -> usize { + self.manifest.size() + } + + pub fn dna_size(&self) -> Result { + let multiselection = MultiSelection::from_moltypes(vec!["dna"])?; + let selected_manifest = self.manifest.clone().select(&multiselection)?; + + Ok(selected_manifest.records.len()) + } + + pub fn protein_size(&self) -> Result { + let multiselection = MultiSelection::from_moltypes(vec!["protein"])?; + let selected_manifest = self.manifest.clone().select(&multiselection)?; + + Ok(selected_manifest.records.len()) + } + + pub fn anyprotein_size(&self) -> Result { + let multiselection = MultiSelection::from_moltypes(vec!["protein", "dayhoff", "hp"])?; + let selected_manifest = self.manifest.clone().select(&multiselection)?; + + Ok(selected_manifest.records.len()) + } + + pub fn parse_ksize(value: &str) -> Result { + value + .parse::() + .map_err(|_| format!("cannot parse k='{}' as a valid integer", value)) + } + + pub fn parse_int_once( + value: &str, + field: &str, + current: &mut Option, + ) -> Result<(), String> + where + T: FromStr + Display + Copy, + { + let parsed_value = value + .parse::() + .map_err(|_| format!("cannot parse {}='{}' as a valid integer", field, value))?; + + // Check for conflicts; we don't allow multiple values for the same field. + if let Some(old_value) = *current { + return Err(format!( + "Conflicting values for '{}': {} and {}", + field, old_value, parsed_value + )); + } + + *current = Some(parsed_value); + Ok(()) + } + + pub fn parse_moltype(item: &str, current: &mut Option) -> Result { + let new_moltype = match item { + "protein" | "dna" | "dayhoff" | "hp" => item.to_string(), + _ => return Err(format!("unknown moltype '{}'", item)), + }; + + // Check for conflicts and update the moltype. + if let Some(existing) = current { + if *existing != new_moltype { + return Err(format!( + "Conflicting moltype settings in param string: '{}' and '{}'", + existing, new_moltype + )); + } + } + + *current = Some(new_moltype.clone()); + Ok(new_moltype) + } + + pub fn parse_abundance(item: &str, current: &mut Option) -> Result<(), String> { + let new_abundance = item == "abund"; + + if let Some(existing) = *current { + if existing != new_abundance { + return Err(format!( + "Conflicting abundance settings in param string: '{}'", + item + )); + } + } + + *current = Some(new_abundance); + Ok(()) + } + + pub fn summarize_params(&self) -> HashSet<(u32, String, bool, u32, u64)> { + let params: HashSet<_> = self.manifest.iter().map(|record| record.params()).collect(); + + // Print a description of the summary + eprintln!("Building {} sketch types:", params.len()); + + for (ksize, moltype, with_abundance, num, scaled) in ¶ms { + eprintln!( + " {},k={},scaled={},num={},abund={}", + moltype, ksize, scaled, num, with_abundance + ); + } + params + } + + pub fn parse_params(p_str: &str) -> Result<(BuildRecord, Vec), String> { + let mut ksizes = Vec::new(); + let mut moltype: Option = None; + let mut track_abundance: Option = None; + let mut num: Option = None; + let mut scaled: Option = None; + let mut seed: Option = None; + + for item in p_str.split(',') { + match item { + _ if item.starts_with("k=") => { + ksizes.push(Self::parse_ksize(&item[2..])?); + } + "abund" | "noabund" => { + Self::parse_abundance(item, &mut track_abundance)?; + } + "protein" | "dna" | "DNA" | "dayhoff" | "hp" => { + Self::parse_moltype(item, &mut moltype)?; + } + _ if item.starts_with("num=") => { + Self::parse_int_once(&item[4..], "num", &mut num)?; + } + _ if item.starts_with("scaled=") => { + Self::parse_int_once(&item[7..], "scaled", &mut scaled)?; + } + _ if item.starts_with("seed=") => { + Self::parse_int_once(&item[5..], "seed", &mut seed)?; + } + _ => { + return Err(format!( + "Error parsing params string '{}': Unknown component '{}'", + p_str, item + )); + } + } + } + + // Ensure that moltype was set + let moltype = moltype.ok_or_else(|| { + format!( + "Error parsing params string '{}': No moltype provided", + p_str + ) + })?; + + // Create a moltype-specific default BuildRecord or return an error if unsupported. + let mut base_record = match moltype.as_str() { + "dna" | "DNA" => BuildRecord::default_dna(), + "protein" => BuildRecord::default_protein(), + "dayhoff" => BuildRecord::default_dayhoff(), + "hp" => BuildRecord::default_hp(), + _ => { + return Err(format!( + "Error parsing params string '{}': Unsupported moltype '{}'", + p_str, moltype + )); + } + }; + + // Apply parsed values + if let Some(track_abund) = track_abundance { + base_record.with_abundance = track_abund; + } + if let Some(n) = num { + base_record.num = n; + } + if let Some(s) = scaled { + base_record.scaled = s; + } + if let Some(s) = seed { + base_record.seed = s; + } + + // Use the default ksize if none were specified. + if ksizes.is_empty() { + ksizes.push(base_record.ksize); + } + + // Ensure that num and scaled are mutually exclusive unless num is 0. + if let (Some(n), Some(_)) = (num, scaled) { + if n != 0 { + return Err(format!( + "Error parsing params string '{}': Cannot specify both 'num' (non-zero) and 'scaled' in the same parameter string", + p_str + )); + } + } + + Ok((base_record, ksizes)) + } + + pub fn from_param_str(params_str: &str) -> Result { + if params_str.trim().is_empty() { + return Err("Parameter string cannot be empty.".to_string()); + } + + let mut coll = BuildCollection::new(); + let mut seen_records = HashSet::new(); + + for p_str in params_str.split('_') { + // Use `parse_params` to get the base record and ksizes. + let (base_record, ksizes) = Self::parse_params(p_str)?; + + // Iterate over each ksize and add a signature to the collection. + for k in ksizes { + let mut record = base_record.clone(); + record.ksize = k; + + // Check if the record is already in the set. + if seen_records.insert(record.clone()) { + // Add the record and its associated signature to the collection. + // coll.add_template_sig_from_record(&record, &record.moltype); + coll.add_template_sig_from_record(&record); + } + } + } + Ok(coll) + } + + pub fn from_manifest(manifest: &BuildManifest) -> Self { + let mut collection = BuildCollection::new(); + + // Iterate over each `BuildRecord` in the provided `BuildManifest`. + for record in &manifest.records { + // Add a signature to the collection using the `BuildRecord` and `input_moltype`. + collection.add_template_sig_from_record(record); + } + + collection + } + + pub fn add_template_sig_from_record(&mut self, record: &BuildRecord) { + // Adjust ksize for protein, dayhoff, or hp, which require tripling the k-mer size. + let adjusted_ksize = match record.moltype.as_str() { + "protein" | "dayhoff" | "hp" => record.ksize * 3, + _ => record.ksize, + }; + + // Construct ComputeParameters. + let cp = ComputeParameters::builder() + .ksizes(vec![adjusted_ksize]) + .scaled(record.scaled as u32) + .protein(record.moltype == "protein") + .dna(record.moltype == "DNA") + .dayhoff(record.moltype == "dayhoff") + .hp(record.moltype == "hp") + .num_hashes(record.num) + .track_abundance(record.with_abundance) + .build(); + + // Create a Signature from the ComputeParameters. + let sig = Signature::from_params(&cp); + + // Clone the `BuildRecord` and use it directly. + let template_record = record.clone(); + + // Add the record and signature to the collection. + self.manifest.records.push(template_record); + self.sigs.push(sig); + } + + pub fn filter_manifest(&mut self, other: &BuildManifest) { + self.manifest = self.manifest.filter_manifest(other) + } + + pub fn filter_by_manifest(&mut self, other: &BuildManifest) { + // Create a HashSet for efficient filtering based on the `BuildRecord`s in `other`. + let other_records: HashSet<_> = other.records.iter().collect(); + + // Retain only the records that are not in `other_records`, filtering in place. + let mut sig_index = 0; + self.manifest.records.retain(|record| { + let keep = !other_records.contains(record); + if !keep { + // Remove the corresponding signature at the same index. + self.sigs.remove(sig_index); + } else { + sig_index += 1; // Only increment if we keep the record and signature. + } + keep + }); + } + + // filter template signatures that had no sequence added + // suggested use right before writing signatures + pub fn filter_empty(&mut self) { + let mut sig_index = 0; + + self.manifest.records.retain(|record| { + // Keep only records where `sequence_added` is `true`. + let keep = record.sequence_added; + + if !keep { + // Remove the corresponding signature at the same index if the record is not kept. + self.sigs.remove(sig_index); + } else { + sig_index += 1; // Only increment if we keep the record and signature. + } + + keep + }); + } + + pub fn filter(&mut self, params_set: &HashSet) { + let mut index = 0; + while index < self.manifest.records.len() { + let record = &self.manifest.records[index]; + + // filter records with matching Params + if params_set.contains(&record.hashed_params) { + self.manifest.records.remove(index); + self.sigs.remove(index); + } else { + index += 1; + } + } + } + + pub fn iter(&self) -> impl Iterator { + self.manifest.iter().enumerate().map(|(i, r)| (i as Idx, r)) + } + + pub fn record_for_dataset(&self, dataset_id: Idx) -> Result<&BuildRecord> { + Ok(&self.manifest[dataset_id as usize]) + } + + pub fn sigs_iter_mut(&mut self) -> impl Iterator { + self.sigs.iter_mut() + } + + pub fn iter_mut(&mut self) -> impl Iterator { + // zip together mutable iterators over records and sigs + self.manifest.records.iter_mut().zip(self.sigs.iter_mut()) + } + + pub fn par_iter_mut( + &mut self, + ) -> impl ParallelIterator { + self.manifest + .records + .par_iter_mut() // Parallel mutable iterator over records + .zip(self.sigs.par_iter_mut()) // Parallel mutable iterator over sigs + } + + fn build_sigs_from_record( + &mut self, + input_moltype: &str, + record: &SequenceRecord, + ) -> Result<()> { + self.par_iter_mut().try_for_each(|(rec, sig)| { + if input_moltype == "protein" + && (rec.moltype() == HashFunctions::Murmur64Protein + || rec.moltype() == HashFunctions::Murmur64Dayhoff + || rec.moltype() == HashFunctions::Murmur64Hp) + { + sig.add_protein(&record.seq()) + .context("Failed to add protein")?; + if !rec.sequence_added { + rec.sequence_added = true; + } + } else if (input_moltype == "DNA" || input_moltype == "dna") + && rec.moltype() == HashFunctions::Murmur64Dna + { + sig.add_sequence(&record.seq(), true) + .context("Failed to add sequence")?; + if !rec.sequence_added { + rec.sequence_added = true; + } + } + Ok(()) + }) + } + + pub fn build_sigs_from_data( + &mut self, + data: Vec, + input_moltype: &str, + name: String, + filename: String, + ) -> Result<()> { + let cursor = Cursor::new(data); + let mut fastx_reader = + parse_fastx_reader(cursor).context("Failed to parse FASTA/FASTQ data")?; + + // Iterate over FASTA records and add sequences/proteins to sigs + while let Some(record) = fastx_reader.next() { + let record = record.context("Failed to read record")?; + self.build_sigs_from_record(input_moltype, &record)?; + } + + // After processing sequences, update sig, record information + self.update_info(name, filename); + + Ok(()) + } + + pub fn build_sigs_from_file_or_stdin( + &mut self, + input_moltype: &str, // "protein" or "DNA" + name: String, + filename: String, + ) -> Result { + // Create a FASTX reader from the file or stdin + let mut fastx_reader = if filename == "-" { + let stdin = std::io::stdin(); + parse_fastx_reader(stdin).context("Failed to parse FASTA/FASTQ data from stdin")? + } else { + parse_fastx_file(&filename).context("Failed to open file for FASTA/FASTQ data")? + }; + + // Counter for the number of records processed + let mut record_count: u64 = 0; + + // Parse records and add sequences to signatures + while let Some(record_result) = fastx_reader.next() { + let record = record_result.context("Failed to read a record from input")?; + + self.build_sigs_from_record(input_moltype, &record)?; + + record_count += 1; + } + + // Update signature and record metadata + self.update_info(name, filename); + + // Return the count of records parsed + Ok(record_count) + } + + pub fn build_singleton_sigs( + &mut self, + record: SequenceRecord, + input_moltype: &str, // (protein/dna); todo - use hashfns? + filename: String, + ) -> Result<()> { + self.build_sigs_from_record(input_moltype, &record)?; + // After processing sequences, update sig, record information + let record_name = std::str::from_utf8(record.id()) + .expect("could not get record id") + .to_string(); + self.update_info(record_name, filename); + + Ok(()) + } + + pub fn update_info(&mut self, name: String, filename: String) { + // update the records to reflect information the signature; + for (record, sig) in self.iter_mut() { + if record.sequence_added { + // update signature name, filename + sig.set_name(name.as_str()); + sig.set_filename(filename.as_str()); + + // update record: set name, filename, md5sum, n_hashes + record.set_name(Some(name.clone())); + record.set_filename(Some(filename.clone())); + record.set_md5(Some(sig.md5sum())); + record.set_md5short(Some(sig.md5sum()[0..8].into())); + record.set_n_hashes(Some(sig.size())); + + // note, this needs to be set when writing sigs (not here) + // record.set_internal_location("") + } + } + } + + pub fn write_sigs(&mut self, output: &str) -> Result<()> { + let gzip = output.ends_with(".gz"); + if output == "-" { + // Write to stdout + let stdout = std::io::stdout(); + let mut handle = stdout.lock(); + self.write_sigs_as_json(&mut handle, gzip) + .context("Failed to write signatures to stdout")?; + handle.flush().context("Failed to flush stdout")?; + } else if output.ends_with(".zip") { + let options = FileOptions::default() + .compression_method(CompressionMethod::Stored) + .unix_permissions(0o644) + .large_file(true); + // Write to a zip file + let file = + File::create(output).context(format!("Failed to create file: {}", output))?; + let mut zip = ZipWriter::new(file); + let mut md5sum_occurrences: HashMap = HashMap::new(); + self.write_sigs_to_zip(&mut zip, &mut md5sum_occurrences, &options) + .context(format!( + "Failed to write signatures to zip file: {}", + output + ))?; + println!("Writing manifest"); + self.manifest.write_manifest_to_zip(&mut zip, &options)?; + zip.finish()?; + } else { + // Write JSON to output file + let file = + File::create(output).context(format!("Failed to create file: {}", output))?; + let mut writer = std::io::BufWriter::new(file); + self.write_sigs_as_json(&mut writer, gzip) + .context(format!("Failed to write signatures to file: {}", output))?; + } + Ok(()) + } + + pub fn write_sigs_to_zip( + &mut self, // need mutable to update records + zip: &mut ZipWriter, + md5sum_occurrences: &mut HashMap, + options: &FileOptions<()>, + ) -> Result<()> { + // iterate over both records and signatures + for (record, sig) in self.iter_mut() { + // skip any empty sig templates (no sequence added) + // TODO --> test that this is working + if !record.sequence_added { + continue; + } + let md5sum_str = sig.md5sum(); + let count = md5sum_occurrences.entry(md5sum_str.clone()).or_insert(0); + *count += 1; + + // Generate the signature filename + let sig_filename = if *count > 1 { + format!("signatures/{}_{}.sig.gz", md5sum_str, count) + } else { + format!("signatures/{}.sig.gz", md5sum_str) + }; + + // Update record's internal_location with the signature filename + record.internal_location = Some(sig_filename.clone().into()); + + // Serialize signature to JSON + let wrapped_sig = vec![sig.clone()]; + let json_bytes = serde_json::to_vec(&wrapped_sig) + .map_err(|e| anyhow!("Error serializing signature: {}", e))?; + + // Gzip compress the JSON bytes + let gzipped_buffer = { + let mut buffer = Cursor::new(Vec::new()); + { + let mut gz_writer = niffler::get_writer( + Box::new(&mut buffer), + niffler::compression::Format::Gzip, + niffler::compression::Level::Nine, + )?; + gz_writer.write_all(&json_bytes)?; + } + buffer.into_inner() + }; + + zip.start_file(sig_filename, options.clone())?; + zip.write_all(&gzipped_buffer) + .map_err(|e| anyhow!("Error writing zip entry for signature: {}", e))?; + } + + Ok(()) + } + + pub fn write_sigs_as_json( + &mut self, // mutable to update records if needed + writer: &mut W, + gzip: bool, + ) -> Result<()> { + // Create a vector to store all signatures + let mut all_signatures = Vec::new(); + + // Iterate over both records and signatures + for (record, sig) in self.iter_mut() { + // Skip any empty sig templates (no sequence added) + if !record.sequence_added { + continue; + } + + // Add the signature to the collection for JSON serialization + all_signatures.push(sig.clone()); + } + + // Serialize all signatures to JSON + let json_bytes = serde_json::to_vec(&all_signatures) + .map_err(|e| anyhow!("Error serializing signatures to JSON: {}", e))?; + + if gzip { + // Gzip compress the JSON bytes + let mut gz_writer = niffler::get_writer( + Box::new(writer), + niffler::compression::Format::Gzip, + niffler::compression::Level::Nine, + )?; + gz_writer.write_all(&json_bytes)?; + // gz_writer.finish()?; + } else { + // Write uncompressed JSON to the writer + writer.write_all(&json_bytes)?; + } + + Ok(()) + } +} + +impl<'a> IntoIterator for &'a mut BuildCollection { + type Item = (&'a mut BuildRecord, &'a mut Signature); + type IntoIter = + std::iter::Zip, std::slice::IterMut<'a, Signature>>; + + fn into_iter(self) -> Self::IntoIter { + self.manifest.records.iter_mut().zip(self.sigs.iter_mut()) + } +} + +impl MultiSelect for BuildCollection { + // to do --> think through the best/most efficient way to do this + // in sourmash core, we don't need to select sigs themselves. Is this due to the way that Idx/Storage work? + fn select(mut self, multi_selection: &MultiSelection) -> Result { + // Collect indices while retaining matching records + let mut selected_indices = Vec::new(); + let mut current_index = 0; + + self.manifest.records.retain(|record| { + let keep = multi_selection + .selections + .iter() + .any(|selection| record.matches_selection(selection)); + + if keep { + selected_indices.push(current_index); // Collect the index of the retained record + } + + current_index += 1; // Move to the next index + keep // Retain the record if it matches the selection + }); + + // Retain corresponding signatures using the collected indices + let mut sig_index = 0; + self.sigs.retain(|_sig| { + let keep = selected_indices.contains(&sig_index); + sig_index += 1; + keep + }); + + Ok(self) + } +} + +#[derive(Debug, Clone)] +pub struct MultiBuildCollection { + pub collections: Vec, +} + +impl MultiBuildCollection { + pub fn new() -> Self { + MultiBuildCollection { + collections: Vec::new(), + } + } + + pub fn is_empty(&self) -> bool { + self.collections.is_empty() + } + + pub fn add_collection(&mut self, collection: &mut BuildCollection) { + self.collections.push(collection.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_valid_params_str() { + let params_str = "k=31,abund,dna"; + let result = BuildCollection::parse_params(params_str); + + assert!( + result.is_ok(), + "Expected 'k=31,abund,dna' to be valid, but got an error: {:?}", + result + ); + + let (record, ksizes) = result.unwrap(); + + // Verify that the Record, ksizes have the correct settings. + assert_eq!(record.moltype, "DNA"); + assert_eq!(record.with_abundance, true); + assert_eq!(ksizes, vec![31]); + assert_eq!(record.scaled, 1000, "Expected default scaled value of 1000"); + assert_eq!(record.num, 0, "Expected default num value of 0"); + } + + #[test] + fn test_from_param_str() { + let params_str = "k=31,abund,dna_k=21,k=31,k=51,abund_k=10,protein"; + let coll_result = BuildCollection::from_param_str(params_str); + + assert!( + coll_result.is_ok(), + "Param str '{}' is valid, but got an error: {:?}", + params_str, + coll_result + ); + + let coll = coll_result.unwrap(); + + // Ensure the BuildCollection contains the expected number of records. + // Note that "k=31,abund,dna" appears in two different parameter strings, so it should only appear once. + assert_eq!( + coll.manifest.records.len(), + 4, + "Expected 4 unique BuildRecords in the collection, but found {}", + coll.manifest.records.len() + ); + + // Define the expected BuildRecords for comparison. + let expected_records = vec![ + BuildRecord { + ksize: 31, + moltype: "DNA".to_string(), + with_abundance: true, + ..BuildRecord::default_dna() + }, + BuildRecord { + ksize: 21, + moltype: "DNA".to_string(), + with_abundance: true, + ..BuildRecord::default_dna() + }, + BuildRecord { + ksize: 51, + moltype: "DNA".to_string(), + with_abundance: true, + ..BuildRecord::default_dna() + }, + BuildRecord::default_protein(), + ]; + + // Verify that each expected BuildRecord is present in the collection. + for expected_record in expected_records { + assert!( + coll.manifest.records.contains(&expected_record), + "Expected BuildRecord with ksize: {}, moltype: {}, with_abundance: {} not found in the collection", + expected_record.ksize, + expected_record.moltype, + expected_record.with_abundance + ); + } + + // Optionally, check that the corresponding signatures are present. + assert_eq!( + coll.sigs.len(), + 4, + "Expected 4 Signatures in the collection, but found {}", + coll.sigs.len() + ); + } + + #[test] + fn test_invalid_params_str_conflicting_moltypes() { + let params_str = "k=31,abund,dna,protein"; + let result = BuildCollection::from_param_str(params_str); + + assert!( + result.is_err(), + "Expected 'k=31,abund,dna,protein' to be invalid due to conflicting moltypes, but got a successful result" + ); + + // Check if the error message contains the expected conflict text. + if let Err(e) = result { + assert!( + e.contains("Conflicting moltype settings"), + "Expected error to contain 'Conflicting moltype settings', but got: {}", + e + ); + } + } + + #[test] + fn test_unknown_component_error() { + // Test for an unknown component that should trigger an error. + let result = BuildCollection::from_param_str("k=31,notaparam"); + assert!(result.is_err(), "Expected an error but got Ok."); + assert_eq!( + result.unwrap_err(), + "unknown component 'notaparam' in params string" + ); + } + + #[test] + fn test_unknown_component_error2() { + // Test a common param string error (k=31,51 compared with valid k=31,k=51) + let result = BuildCollection::from_param_str("k=31,51,abund"); + assert!(result.is_err(), "Expected an error but got Ok."); + assert_eq!( + result.unwrap_err(), + "unknown component '51' in params string" + ); + } + + #[test] + fn test_conflicting_num_and_scaled() { + // Test for specifying both num and scaled, which should result in an error. + let result = BuildCollection::from_param_str("k=31,num=10,scaled=1000"); + assert!(result.is_err(), "Expected an error but got Ok."); + assert_eq!( + result.unwrap_err(), + "Cannot specify both 'num' (non-zero) and 'scaled' in the same parameter string" + ); + } + + #[test] + fn test_conflicting_abundance() { + // Test for providing conflicting abundance settings, which should result in an error. + let result = BuildCollection::from_param_str("k=31,abund,noabund"); + assert!(result.is_err(), "Expected an error but got Ok."); + assert_eq!( + result.unwrap_err(), + "Conflicting abundance settings in param string: 'noabund'" + ); + } + + #[test] + fn test_invalid_ksize_format() { + // Test for an invalid ksize format that should trigger an error. + let result = BuildCollection::from_param_str("k=abc"); + assert!(result.is_err(), "Expected an error but got Ok."); + assert_eq!( + result.unwrap_err(), + "cannot parse k='abc' as a valid integer" + ); + } + + #[test] + fn test_invalid_num_format() { + // Test for an invalid number format that should trigger an error. + let result = BuildCollection::from_param_str("k=31,num=abc"); + assert!(result.is_err(), "Expected an error but got Ok."); + assert_eq!( + result.unwrap_err(), + "cannot parse num='abc' as a valid integer" + ); + } + + #[test] + fn test_invalid_scaled_format() { + // Test for an invalid scaled format that should trigger an error. + let result = BuildCollection::from_param_str("k=31,scaled=abc"); + assert!(result.is_err(), "Expected an error but got Ok."); + assert_eq!( + result.unwrap_err(), + "cannot parse scaled='abc' as a valid integer" + ); + } + + #[test] + fn test_invalid_seed_format() { + // Test for an invalid seed format that should trigger an error. + let result = BuildCollection::from_param_str("k=31,seed=abc"); + assert!(result.is_err(), "Expected an error but got Ok."); + assert_eq!( + result.unwrap_err(), + "cannot parse seed='abc' as a valid integer" + ); + } + + #[test] + fn test_repeated_values() { + // repeated scaled + let result = BuildCollection::from_param_str("k=31,scaled=1,scaled=1000"); + assert!(result.is_err(), "Expected an error but got Ok."); + assert_eq!( + result.unwrap_err(), + "Conflicting values for 'scaled': 1 and 1000" + ); + + // repeated num + let result = BuildCollection::from_param_str("k=31,num=1,num=1000"); + assert!(result.is_err(), "Expected an error but got Ok."); + assert_eq!( + result.unwrap_err(), + "Conflicting values for 'num': 1 and 1000" + ); + + // repeated seed + let result = BuildCollection::from_param_str("k=31,seed=1,seed=42"); + assert!(result.is_err(), "Expected an error but got Ok."); + assert_eq!( + result.unwrap_err(), + "Conflicting values for 'seed': 1 and 42" + ); + } + + #[test] + fn test_missing_ksize() { + // Test for a missing ksize, using default should not result in an error. + let result = BuildCollection::from_param_str("abund"); + assert!(result.is_ok(), "Expected Ok but got an error."); + } + + #[test] + fn test_repeated_ksize() { + // Repeated ksize settings should not trigger an error since it is valid to have multiple ksizes. + let result = BuildCollection::from_param_str("k=31,k=21"); + assert!(result.is_ok(), "Expected Ok but got an error."); + } + + #[test] + fn test_empty_string() { + // Test for an empty parameter string, which should now result in an error. + let result = BuildCollection::from_param_str(""); + assert!(result.is_err(), "Expected an error but got Ok."); + assert_eq!(result.unwrap_err(), "Parameter string cannot be empty."); + } + + #[test] + fn test_filter_by_manifest_with_matching_records() { + // Create a BuildCollection with some records and signatures. + + let rec1 = BuildRecord::default_dna(); + let rec2 = BuildRecord { + ksize: 21, + moltype: "DNA".to_string(), + scaled: 1000, + ..BuildRecord::default_dna() + }; + let rec3 = BuildRecord { + ksize: 31, + moltype: "DNA".to_string(), + scaled: 1000, + with_abundance: true, + ..BuildRecord::default_dna() + }; + + let bmanifest = BuildManifest { + records: vec![rec1.clone(), rec2.clone(), rec3.clone()], + }; + // let mut dna_build_collection = BuildCollection::from_manifest(&bmanifest, "DNA"); + let mut dna_build_collection = BuildCollection::from_manifest(&bmanifest); + + // Create a BuildManifest with records to filter out. + let filter_manifest = BuildManifest { + records: vec![rec1], + }; + + // Apply the filter. + dna_build_collection.filter_by_manifest(&filter_manifest); + + // check that the default DNA sig remains + assert_eq!(dna_build_collection.manifest.size(), 2); + + let remaining_records = &dna_build_collection.manifest.records; + + assert!(remaining_records.contains(&rec2)); + assert!(remaining_records.contains(&rec3)); + } + + #[test] + fn test_add_template_sig_from_record() { + // Create a BuildCollection. + let mut build_collection = BuildCollection::new(); + + // Create a DNA BuildRecord. + let dna_record = BuildRecord { + ksize: 31, + moltype: "DNA".to_string(), + scaled: 1000, + with_abundance: true, + ..BuildRecord::default_dna() + }; + + // Add the DNA record to the collection with a matching moltype. + // build_collection.add_template_sig_from_record(&dna_record, "DNA"); + build_collection.add_template_sig_from_record(&dna_record); + + // Verify that the record was added. + assert_eq!(build_collection.manifest.records.len(), 1); + assert_eq!(build_collection.sigs.len(), 1); + + let added_record = &build_collection.manifest.records[0]; + assert_eq!(added_record.moltype, "DNA"); + assert_eq!(added_record.ksize, 31); + assert_eq!(added_record.with_abundance, true); + + // Create a protein BuildRecord. + let protein_record = BuildRecord { + ksize: 10, + moltype: "protein".to_string(), + scaled: 200, + with_abundance: false, + ..BuildRecord::default_dna() + }; + + // Add the protein record to the collection with a matching moltype. + // build_collection.add_template_sig_from_record(&protein_record, "protein"); + build_collection.add_template_sig_from_record(&protein_record); + + // Verify that the protein record was added and ksize adjusted. + assert_eq!(build_collection.manifest.records.len(), 2); + assert_eq!(build_collection.sigs.len(), 2); + + let added_protein_record = &build_collection.manifest.records[1]; + assert_eq!(added_protein_record.moltype, "protein"); + assert_eq!(added_protein_record.ksize, 10); + assert_eq!(added_protein_record.with_abundance, false); + + // Create a BuildRecord with a non-matching moltype. + let non_matching_record = BuildRecord { + ksize: 10, + moltype: "dayhoff".to_string(), + scaled: 200, + with_abundance: true, + ..BuildRecord::default_dna() + }; + + // Attempt to add the non-matching record with "DNA" as input moltype. + // this is because we currently don't allow translation + // build_collection.add_template_sig_from_record(&non_matching_record, "DNA"); + + // Verify that the non-matching record was not added. + // assert_eq!(build_collection.manifest.records.len(), 2); + // assert_eq!(build_collection.sigs.len(), 2); + + // Add the same non-matching record with a matching input moltype. + build_collection.add_template_sig_from_record(&non_matching_record); + + // Verify that the record was added. + assert_eq!(build_collection.manifest.records.len(), 3); + assert_eq!(build_collection.sigs.len(), 3); + + let added_dayhoff_record = &build_collection.manifest.records[2]; + assert_eq!(added_dayhoff_record.moltype, "dayhoff"); + assert_eq!(added_dayhoff_record.ksize, 10); + assert_eq!(added_dayhoff_record.with_abundance, true); + } + + #[test] + fn test_filter_empty() { + // Create a parameter string that generates BuildRecords with different `sequence_added` values. + let params_str = "k=31,abund,dna_k=21,protein_k=10,abund"; + + // Use `from_param_str` to build a `BuildCollection`. + let mut build_collection = BuildCollection::from_param_str(params_str) + .expect("Failed to build BuildCollection from params_str"); + + // Manually set `sequence_added` for each record to simulate different conditions. + build_collection.manifest.records[0].sequence_added = true; // Keep this record. + build_collection.manifest.records[1].sequence_added = false; // This record should be removed. + build_collection.manifest.records[2].sequence_added = true; // Keep this record. + + // Check initial sizes before filtering. + assert_eq!( + build_collection.manifest.records.len(), + 3, + "Expected 3 records before filtering, but found {}", + build_collection.manifest.records.len() + ); + assert_eq!( + build_collection.sigs.len(), + 3, + "Expected 3 signatures before filtering, but found {}", + build_collection.sigs.len() + ); + + // Apply the `filter_empty` method. + build_collection.filter_empty(); + + // After filtering, only the records with `sequence_added == true` should remain. + assert_eq!( + build_collection.manifest.records.len(), + 2, + "Expected 2 records after filtering, but found {}", + build_collection.manifest.records.len() + ); + + // Check that the signatures also match the remaining records. + assert_eq!( + build_collection.sigs.len(), + 2, + "Expected 2 signatures after filtering, but found {}", + build_collection.sigs.len() + ); + + // Verify that the remaining records have `sequence_added == true`. + assert!( + build_collection + .manifest + .records + .iter() + .all(|rec| rec.sequence_added), + "All remaining records should have `sequence_added == true`" + ); + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 0473e3de..20472d8e 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -5,7 +5,7 @@ use sourmash::encodings::HashFunctions; use sourmash::selection::Select; use sourmash::ScaledType; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context, Result}; use camino::Utf8Path as Path; use camino::Utf8PathBuf as PathBuf; use csv::Writer; @@ -19,6 +19,8 @@ use std::io::{BufWriter, Write}; use std::panic; use std::sync::atomic; use std::sync::atomic::AtomicUsize; +use std::sync::mpsc::Receiver; +use std::thread::JoinHandle; use zip::write::{ExtendedFileOptions, FileOptions, ZipWriter}; use zip::CompressionMethod; @@ -34,6 +36,9 @@ use std::hash::{Hash, Hasher}; pub mod multicollection; use multicollection::MultiCollection; +pub mod buildutils; +use buildutils::{BuildCollection, BuildManifest, MultiBuildCollection}; + /// Structure to hold overlap information from comparisons. pub struct PrefetchResult { pub name: String, @@ -1183,14 +1188,13 @@ impl Hash for Params { } } -pub fn sigwriter( - recv: std::sync::mpsc::Receiver>>, +pub fn zipwriter_handle( + recv: Receiver>, output: String, -) -> std::thread::JoinHandle> { +) -> JoinHandle> { std::thread::spawn(move || -> Result<()> { - // cast output as PathBuf + // Convert output to PathBuf let outpath: PathBuf = output.into(); - let file_writer = open_output_file(&outpath); let options = FileOptions::default() @@ -1199,34 +1203,33 @@ pub fn sigwriter( .large_file(true); let mut zip = ZipWriter::new(file_writer); - let mut manifest_rows: Vec = Vec::new(); - // keep track of MD5 sum occurrences to prevent overwriting duplicates let mut md5sum_occurrences: HashMap = HashMap::new(); + let mut zip_manifest = BuildManifest::new(); - // Process all incoming signatures + // Process each incoming Option while let Ok(message) = recv.recv() { match message { - Some(sigs) => { - for sig in sigs.iter() { - let md5sum_str = sig.md5sum(); - let count = md5sum_occurrences.entry(md5sum_str.clone()).or_insert(0); - *count += 1; - let sig_filename = if *count > 1 { - format!("signatures/{}_{}.sig.gz", md5sum_str, count) - } else { - format!("signatures/{}.sig.gz", md5sum_str) - }; - write_signature(sig, &mut zip, options.clone(), &sig_filename); - let records: Vec = Record::from_sig(sig, sig_filename.as_str()); - manifest_rows.extend(records); + Some(mut build_collection) => { + // Use BuildCollection's method to write signatures to the zip file + match build_collection.write_sigs_to_zip( + &mut zip, + &mut md5sum_occurrences, + &options, + ) { + Ok(_) => { + zip_manifest.extend_from_manifest(&build_collection.manifest); + } + Err(e) => { + let error = e.context("Error processing signature in BuildCollection"); + eprintln!("Error: {}", error); + return Err(error); + } } } None => { - // Write the manifest and finish the ZIP file + // Finalize and write the manifest when None is received println!("Writing manifest"); - zip.start_file("SOURMASH-MANIFEST.csv", options)?; - let manifest: Manifest = manifest_rows.clone().into(); - manifest.to_writer(&mut zip)?; + zip_manifest.write_manifest_to_zip(&mut zip, &options)?; zip.finish()?; break; } @@ -1254,114 +1257,3 @@ pub fn csvwriter_thread( writer.flush().expect("Failed to flush writer."); }) } - -pub fn write_signature( - sig: &Signature, - zip: &mut zip::ZipWriter>, - zip_options: zip::write::FileOptions, - sig_filename: &str, -) { - let wrapped_sig = vec![sig]; - let json_bytes = serde_json::to_vec(&wrapped_sig).unwrap(); - - let gzipped_buffer = { - let mut buffer = std::io::Cursor::new(Vec::new()); - { - let mut gz_writer = niffler::get_writer( - Box::new(&mut buffer), - niffler::compression::Format::Gzip, - niffler::compression::Level::Nine, - ) - .unwrap(); - gz_writer.write_all(&json_bytes).unwrap(); - } - buffer.into_inner() - }; - - zip.start_file(sig_filename, zip_options).unwrap(); - zip.write_all(&gzipped_buffer).unwrap(); -} - -pub fn parse_params_str(params_strs: String) -> Result, String> { - let mut unique_params: std::collections::HashSet = std::collections::HashSet::new(); - - // split params_strs by _ and iterate over each param - for p_str in params_strs.split('_').collect::>().iter() { - let items: Vec<&str> = p_str.split(',').collect(); - - let mut ksizes = Vec::new(); - let mut track_abundance = false; - let mut num = 0; - let mut scaled = 1000; - let mut seed = 42; - let mut is_dna = false; - let mut is_protein = false; - let mut is_dayhoff = false; - let mut is_hp = false; - - for item in items.iter() { - match *item { - _ if item.starts_with("k=") => { - let k_value = item[2..] - .parse() - .map_err(|_| format!("cannot parse k='{}' as a number", &item[2..]))?; - ksizes.push(k_value); - } - "abund" => track_abundance = true, - "noabund" => track_abundance = false, - _ if item.starts_with("num=") => { - num = item[4..] - .parse() - .map_err(|_| format!("cannot parse num='{}' as a number", &item[4..]))?; - } - _ if item.starts_with("scaled=") => { - scaled = item[7..] - .parse() - .map_err(|_| format!("cannot parse scaled='{}' as a number", &item[7..]))?; - } - _ if item.starts_with("seed=") => { - seed = item[5..] - .parse() - .map_err(|_| format!("cannot parse seed='{}' as a number", &item[5..]))?; - } - "protein" => { - is_protein = true; - } - "dna" => { - is_dna = true; - } - "dayhoff" => { - is_dayhoff = true; - } - "hp" => { - is_hp = true; - } - _ => return Err(format!("unknown component '{}' in params string", item)), - } - } - - if !is_dna && !is_protein && !is_dayhoff && !is_hp { - return Err(format!("No moltype provided in params string {}", p_str)); - } - if ksizes.is_empty() { - return Err(format!("No ksizes provided in params string {}", p_str)); - } - - for &k in &ksizes { - let param = Params { - ksize: k, - track_abundance, - num, - scaled, - seed, - is_protein, - is_dna, - is_dayhoff, - is_hp, - }; - unique_params.insert(param); - } - } - - Ok(unique_params.into_iter().collect()) -}