From fef279db623a90e0cdd13762856762c2492e02ca Mon Sep 17 00:00:00 2001 From: "N. Tessa Pierce-Ward" Date: Mon, 30 Sep 2024 17:46:27 -0700 Subject: [PATCH] test param hashing; turns out it was fine - test sigs were abund! --- src/directsketch.rs | 26 ++- src/utils.rs | 358 +++++++++++++++++++++++++++++++++++++++-- tests/test_gbsketch.py | 64 +++++++- 3 files changed, 414 insertions(+), 34 deletions(-) diff --git a/src/directsketch.rs b/src/directsketch.rs index 0e1f6db..2fdb220 100644 --- a/src/directsketch.rs +++ b/src/directsketch.rs @@ -17,9 +17,9 @@ use tokio_util::compat::Compat; use pyo3::prelude::*; use crate::utils::{ - build_params_hashset_from_records, load_accession_info, load_gbassembly_info, parse_params_str, - AccessionData, BuildCollection, BuildManifest, GBAssemblyData, GenBankFileType, InputMolType, - MultiBuildCollection, MultiCollection, + load_accession_info, load_gbassembly_info, parse_params_str, AccessionData, BuildCollection, + BuildManifest, GBAssemblyData, GenBankFileType, InputMolType, MultiBuildCollection, + MultiCollection, }; use reqwest::Url; @@ -844,24 +844,17 @@ pub async fn gbsketch( // Check if there are any existing batches to process if !existing_batches.is_empty() { let existing_sigs = MultiCollection::from_zipfiles(&existing_batches)?; - - // Iterate through existing sigs and produce HashMap of fasta filename or name:: params hashsets - for (_collection, _idx, record) in existing_sigs.iter() { - // Get the record's name or fasta filename - let record_name = record.name().clone(); - - // Build the params hashset for this record - let params_hashset = build_params_hashset_from_records(&[record]); - - // Insert into the HashMap - name_params_map.insert(record_name, params_hashset); - } + name_params_map = existing_sigs.build_params_hashmap(); batch_index = max_existing_batch_index + 1; + eprintln!( + "Found {} existing zip batches. 
Starting new sig writing at batch {}", + max_existing_batch_index, batch_index + ); filter = true; } else { // No existing batches, skipping signature filtering - eprintln!("No existing signature batches found, skipping filter step."); + eprintln!("No existing signature batches found; building all signatures."); } } @@ -963,7 +956,6 @@ pub async fn gbsketch( // filter template sigs based on existing sigs if filter { if let Some(existing_paramset) = name_params_map.get(&accinfo.name) { - eprintln!("filtering!"); // If the key exists, filter template sigs dna_sigs.filter(&existing_paramset); prot_sigs.filter(&existing_paramset); diff --git a/src/utils.rs b/src/utils.rs index 53da11a..91c9b91 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -318,6 +318,12 @@ impl Hash for Params { } impl Params { + pub fn calculate_hash(&self) -> u64 { + let mut hasher = DefaultHasher::new(); + self.hash(&mut hasher); // Use the Hash trait implementation + hasher.finish() // Return the final u64 hash value + } + pub fn from_record(record: &Record) -> Self { let moltype = record.moltype(); // Get the moltype (HashFunctions enum) @@ -335,19 +341,6 @@ impl Params { } } -pub fn build_params_hashset_from_records(records: &[&Record]) -> HashSet { - records - .iter() - .map(|record| { - let params = Params::from_record(record); - // Create a hasher to hash the Params into a u64 - let mut hasher = DefaultHasher::new(); - params.hash(&mut hasher); // Hash the Params struct - hasher.finish() // Get the hashed value (u64) - }) - .collect() -} - #[derive(Debug, Default, Clone, Getters, Setters, Serialize)] pub struct BuildRecord { // fields are ordered the same as Record to allow serialization to manifest @@ -501,6 +494,24 @@ impl BuildManifest { } } +impl<'a> IntoIterator for &'a BuildManifest { + type Item = &'a BuildRecord; + type IntoIter = std::slice::Iter<'a, BuildRecord>; + + fn into_iter(self) -> Self::IntoIter { + self.records.iter() + } +} + +impl<'a> IntoIterator for &'a mut 
BuildManifest { + type Item = &'a mut BuildRecord; + type IntoIter = std::slice::IterMut<'a, BuildRecord>; + + fn into_iter(self) -> Self::IntoIter { + self.records.iter_mut() + } +} + #[derive(Debug, Clone)] pub struct BuildCollection { pub manifest: BuildManifest, @@ -874,13 +885,13 @@ pub struct MultiCollection { } impl MultiCollection { - fn new(collections: Vec, contains_revindex: bool) -> Self { + fn new(collections: Vec) -> Self { Self { collections } } pub fn from_zipfile(sigpath: &Utf8PathBuf) -> Result { match Collection::from_zipfile(sigpath) { - Ok(collection) => Ok(MultiCollection::new(vec![collection], false)), + Ok(collection) => Ok(MultiCollection::new(vec![collection])), Err(_) => bail!("failed to load zipfile: '{}'", sigpath), } } @@ -896,7 +907,7 @@ impl MultiCollection { } } - Ok(MultiCollection::new(collections, false)) + Ok(MultiCollection::new(collections)) } pub fn stream_iter(&self) -> impl Stream + '_ { @@ -916,4 +927,319 @@ impl MultiCollection { .map(move |(idx, record)| (collection, idx, record)) }) } + + pub fn build_params_hashmap(&self) -> HashMap> { + let mut name_params_map = HashMap::new(); + + // Iterate over all collections in MultiCollection + for collection in &self.collections { + // Iterate over all records in the current collection + for (_, record) in collection.iter() { + // Get the record's name or fasta filename + let record_name = record.name().clone(); + + // Calculate the hash of the Params for the current record + let params_hash = Params::from_record(record).calculate_hash(); + + // If the name is already in the HashMap, extend the existing HashSet + // Otherwise, create a new HashSet and insert the hashed Params + name_params_map + .entry(record_name) + .or_insert_with(HashSet::new) // Create a new HashSet if the key doesn't exist + .insert(params_hash); // Insert the hashed Params into the HashSet + } + } + + name_params_map + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn 
test_same_params_produce_same_hash() { + let params1 = Params { + ksize: 31, + track_abundance: true, + num: 0, + scaled: 1000, + seed: 42, + is_protein: false, + is_dayhoff: false, + is_hp: false, + is_dna: true, + }; + + let params2 = Params { + ksize: 31, + track_abundance: true, + num: 0, + scaled: 1000, + seed: 42, + is_protein: false, + is_dayhoff: false, + is_hp: false, + is_dna: true, + }; + + let hash1 = params1.calculate_hash(); + let hash2 = params2.calculate_hash(); + + // Check that the hash for two identical Params is the same + assert_eq!(hash1, hash2, "Hashes for identical Params should be equal"); + } + + #[test] + fn test_different_params_produce_different_hashes() { + let params1 = Params { + ksize: 31, + track_abundance: true, + num: 0, + scaled: 1000, + seed: 42, + is_protein: false, + is_dayhoff: false, + is_hp: false, + is_dna: true, + }; + + let params2 = Params { + ksize: 21, // Changed ksize + track_abundance: true, + num: 0, + scaled: 1000, + seed: 42, + is_protein: false, + is_dayhoff: false, + is_hp: false, + is_dna: true, + }; + + let hash1 = params1.calculate_hash(); + let hash2 = params2.calculate_hash(); + + // Check that the hash for different Params is different + assert_ne!( + hash1, hash2, + "Hashes for different Params should not be equal" + ); + } + + #[test] + fn test_consistent_hashing_across_multiple_calls() { + let params = Params { + ksize: 31, + track_abundance: true, + num: 0, + scaled: 1000, + seed: 42, + is_protein: false, + is_dayhoff: false, + is_hp: false, + is_dna: true, + }; + + let hash1 = params.calculate_hash(); + let hash2 = params.calculate_hash(); + + // Check that calling the hash function multiple times returns the same result + assert_eq!( + hash1, hash2, + "Hashes for the same Params should be consistent across multiple calls" + ); + } + + #[test] + fn test_params_generated_from_record() { + // load signature + build record + let mut filename = Utf8PathBuf::from(env!("CARGO_MANIFEST_DIR")); + 
filename.push("tests/test-data/GCA_000175535.1.sig.gz"); + let path = filename.clone(); + + let file = std::fs::File::open(filename).unwrap(); + let mut reader = std::io::BufReader::new(file); + let sigs = Signature::load_signatures( + &mut reader, + Some(31), + Some("DNA".try_into().unwrap()), + None, + ) + .unwrap(); + + assert_eq!(sigs.len(), 1); + + let sig = sigs.get(0).unwrap(); + let record = Record::from_sig(sig, path.as_str()); + + // create the expected Params based on the Record data + let expected_params = Params { + ksize: 31, + track_abundance: true, + num: 0, + scaled: 1000, + seed: 42, + is_protein: false, + is_dayhoff: false, + is_hp: false, + is_dna: true, + }; + + // // Generate the Params from the Record using the from_record method + let generated_params = Params::from_record(&record[0]); + + // // Assert that the generated Params match the expected Params + assert_eq!( + generated_params, expected_params, + "Generated Params did not match the expected Params" + ); + + // // Calculate the hash for the expected Params + let expected_hash = expected_params.calculate_hash(); + + // // Calculate the hash for the generated Params + let generated_hash = generated_params.calculate_hash(); + + // // Assert that the hash for the generated Params matches the expected Params hash + assert_eq!( + generated_hash, expected_hash, + "Hash of generated Params did not match the hash of expected Params" + ); + } + + #[test] + fn test_filter_removes_matching_params() { + let params1 = Params { + ksize: 31, + track_abundance: true, + num: 0, + scaled: 1000, + seed: 42, + is_protein: false, + is_dayhoff: false, + is_hp: false, + is_dna: true, + }; + + let params2 = Params { + ksize: 21, + track_abundance: true, + num: 0, + scaled: 1000, + seed: 42, + is_protein: false, + is_dayhoff: false, + is_hp: false, + is_dna: true, + }; + + let params3 = Params { + ksize: 31, + track_abundance: true, + num: 0, + scaled: 2000, + seed: 42, + is_protein: false, + is_dayhoff: 
false, + is_hp: false, + is_dna: true, + }; + + let params_list = [params1.clone(), params2.clone(), params3.clone()]; + let mut build_collection = BuildCollection::from_params(¶ms_list, "DNA"); + + let mut params_set = HashSet::new(); + params_set.insert(params1.calculate_hash()); + params_set.insert(params3.calculate_hash()); + + // Call the filter method + build_collection.filter(¶ms_set); + + // Check that the records and signatures with matching params are removed + assert_eq!( + build_collection.manifest.records.len(), + 1, + "Only one record should remain after filtering" + ); + assert_eq!( + build_collection.sigs.len(), + 1, + "Only one signature should remain after filtering" + ); + + // Check that the remaining record is the one with hashed_params = 456 + let h2 = params2.calculate_hash(); + assert_eq!( + build_collection.manifest.records[0].hashed_params, h2, + "The remaining record should have hashed_params {}", + h2 + ); + } + + #[test] + fn test_build_params_hashmap() { + // read in zipfiles to build a MultiCollection + // load signature + build record + let mut filename = Utf8PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("tests/test-data/GCA_000961135.2.sig.zip"); + let path = filename.clone(); + + let mc = MultiCollection::from_zipfiles(&[path]).unwrap(); + + // Call build_params_hashmap + let name_params_map = mc.build_params_hashmap(); + + // Check that the HashMap contains the correct names + assert_eq!( + name_params_map.len(), + 1, + "There should be 1 unique names in the map" + ); + + let mut hashed_params = Vec::new(); + for (name, params_set) in name_params_map.iter() { + eprintln!("Name: {}", name); + for param_hash in params_set { + eprintln!(" Param Hash: {}", param_hash); + hashed_params.push(param_hash); + } + } + + let expected_params1 = Params { + ksize: 31, + track_abundance: true, + num: 0, + scaled: 1000, + seed: 42, + is_protein: false, + is_dayhoff: false, + is_hp: false, + is_dna: true, + }; + + let expected_params2 
= Params { + ksize: 21, + track_abundance: true, + num: 0, + scaled: 1000, + seed: 42, + is_protein: false, + is_dayhoff: false, + is_hp: false, + is_dna: true, + }; + + let expected_hash1 = expected_params1.calculate_hash(); + let expected_hash2 = expected_params2.calculate_hash(); + + assert!( + hashed_params.contains(&&expected_hash1), + "Expected hash1 should be in the hashed_params" + ); + assert!( + hashed_params.contains(&&expected_hash2), + "Expected hash2 should be in the hashed_params" + ); + } } diff --git a/tests/test_gbsketch.py b/tests/test_gbsketch.py index ee9c969..60324f6 100644 --- a/tests/test_gbsketch.py +++ b/tests/test_gbsketch.py @@ -622,4 +622,66 @@ def test_gbsketch_simple_batched(runtmp, capfd): if sig.minhash.moltype == 'DNA': assert sig.md5sum() == ss2.md5sum() else: - assert sig.md5sum() == ss3.md5sum() \ No newline at end of file + assert sig.md5sum() == ss3.md5sum() + + +def test_gbsketch_simple_batch_restart(runtmp, capfd): + acc_csv = get_test_data('acc.csv') + output = runtmp.output('simple.zip') + failed = runtmp.output('failed.csv') + ch_fail = runtmp.output('checksum_dl_failed.csv') + + out1 = runtmp.output('simple.1.zip') + out2 = runtmp.output('simple.2.zip') + out3 = runtmp.output('simple.3.zip') + + + sig1 = get_test_data('GCA_000175535.1.sig.gz') + sig2 = get_test_data('GCA_000961135.2.sig.gz') + sig3 = get_test_data('GCA_000961135.2.protein.sig.gz') + ss1 = sourmash.load_one_signature(sig1, ksize=31) + ss2 = sourmash.load_one_signature(sig2, ksize=31) + ss3 = sourmash.load_one_signature(sig2, ksize=21) + # why does this need ksize =30 and not ksize = 10!??? + ss4 = sourmash.load_one_signature(sig3, ksize=30, select_moltype='protein') + + # first, cat sig2 into an output file that will trick gbsketch into thinking it's a prior batch + runtmp.sourmash('sig', 'cat', sig2, '-o', out1) + assert os.path.exists(out1) + + # now run. 
This should mean that out1 is sig2, out2 is sig1, out3 is sig3
+    runtmp.sourmash('scripts', 'gbsketch', acc_csv, '-o', output,
+                    '--failed', failed, '-r', '1', '--checksum-fail', ch_fail,
+                    '--param-str', "dna,k=31,scaled=1000,abund", '-p', "protein,k=10,scaled=200",
+                    '--batch-size', '1')
+
+    assert os.path.exists(out1)
+    assert os.path.exists(out2)
+    assert os.path.exists(out3)
+    assert not os.path.exists(output) # for now, orig output file should be empty.
+    captured = capfd.readouterr()
+    print(captured.err)
+
+    # we created this one with sig cat
+    idx = sourmash.load_file_as_index(out1)
+    sigs = list(idx.signatures())
+    assert len(sigs) == 2
+    for sig in sigs:
+        assert sig.name == ss2.name
+        assert sig.md5sum() in [ss2.md5sum(), ss3.md5sum()]
+
+    # these were created with gbsketch
+    idx = sourmash.load_file_as_index(out2)
+    sigs = list(idx.signatures())
+    assert len(sigs) == 1
+    for sig in sigs:
+        assert sig.name == ss1.name
+        assert sig.md5sum() == ss1.md5sum()
+
+    idx = sourmash.load_file_as_index(out3)
+    sigs = list(idx.signatures())
+    assert len(sigs) == 1
+    for sig in sigs:
+        assert sig.name == ss4.name
+        assert sig.md5sum() == ss4.md5sum()
+        assert sig.minhash.moltype == 'protein'