diff --git a/src/core/src/encodings.rs b/src/core/src/encodings.rs index 5a215a3e9a..671f58310c 100644 --- a/src/core/src/encodings.rs +++ b/src/core/src/encodings.rs @@ -95,20 +95,165 @@ impl TryFrom<&str> for HashFunctions { } #[derive(Debug)] -pub struct ReadingFrames { - forward: [Vec; 3], - revcomp: [Vec; 3], +pub struct ReadingFrame { + fw: Vec, // Forward frame + rc: Option>, // Reverse complement (optional, not used for protein input) } +impl ReadingFrame { + /// Create a k-mer iterator for this reading frame + pub fn kmer_iter(&self, ksize: usize, seed: u64, force: bool) -> KmerIterator { + KmerIterator::new(&self.fw, self.rc.as_deref(), ksize, seed, force) + } +} + +#[derive(Debug)] +pub struct ReadingFrames(Vec); + impl ReadingFrames { - pub fn new(forward: [Vec; 3], rc: [Vec; 3]) -> Self { - ReadingFrames { - forward, - revcomp: rc, + /// Create ReadingFrames based on the input sequence, moltype, and protein flag + pub fn new(sequence: &[u8], is_protein: bool, hash_function: &HashFunctions) -> Self { + if is_protein { + // for protein input, return one forward frame + let frames = vec![ReadingFrame { + fw: sequence.to_vec(), + rc: None, + }]; + Self(frames) + } else if hash_function.dna() { + // DNA: just forward + rc + let dna_rc = revcomp(sequence); + let frames = vec![ReadingFrame { + fw: sequence.to_vec(), + rc: Some(dna_rc), + }]; + Self(frames) + } else if hash_function.protein() || hash_function.dayhoff() || hash_function.hp() { + // translation: build 6 frames + let dna_rc = revcomp(sequence); // Compute reverse complement for translation + let dayhoff = hash_function.dayhoff(); + let hp = hash_function.hp(); + Self::translate_frames(sequence, &dna_rc, dayhoff, hp) + } else if hash_function.skipm1n3() || hash_function.skipm2n3() { + // Skipmers: build 6 frames, following skip pattern + let (m, n) = if hash_function.skipm1n3() { + (1, 3) + } else { + (2, 3) + }; + Self::skipmer_frames(sequence, n, m) + } else { + panic!("Unsupported moltype: {}", hash_function); + } + } + + /// Generate translated frames + fn translate_frames(sequence: &[u8], dna_rc: &[u8], dayhoff: bool, hp: bool) -> Self { + let frames: Vec = (0..3) + .map(|frame_number| ReadingFrame { + fw: translated_frame(sequence, frame_number, dayhoff, hp), + rc: Some(translated_frame(dna_rc, frame_number, dayhoff, hp)), + }) + .collect(); + Self(frames) + } + + /// Generate skipmer frames + fn skipmer_frames(sequence: &[u8], n: usize, m: usize) -> Self { + let frames: Vec = (0..3) + .map(|frame_number| { + let fw = skipmer_frame(sequence, frame_number, n, m); + ReadingFrame { + fw: fw.clone(), + rc: Some(revcomp(&fw)), + } + }) + .collect(); + Self(frames) + } + + /// Access the frames + pub fn frames(&self) -> &Vec { + &self.0 + } +} + +pub struct KmerIterator<'a> { + fw: &'a [u8], + rc: Option<&'a [u8]>, + ksize: usize, + index: usize, + len: usize, + seed: u64, + force: bool, +} + +impl<'a> KmerIterator<'a> { + pub fn new(fw: &'a [u8], rc: Option<&'a [u8]>, ksize: usize, seed: u64, force: bool) -> Self { + Self { + fw, + rc, + ksize, + index: 0, + len: fw.len(), + seed, + force, } } } +impl<'a> Iterator for KmerIterator<'a> { + type Item = Result; + + fn next(&mut self) -> Option { + if self.index + self.ksize > self.len { + return None; // End of iteration + } + + // Forward k-mer + let kmer = &self.fw[self.index..self.index + self.ksize]; + + // Validate the k-mer + for j in self.index..self.index + self.ksize { + if !VALID[self.fw[j] as usize] { + if !self.force { + return Some(Err(Error::InvalidDNA { + message: String::from_utf8(kmer.to_vec()).unwrap(), + })); + } else { + self.index += 1; + return Some(Ok(0)); // Skip invalid k-mer + } + } + } + + // Reverse complement k-mer (if rc exists) + + // ... and then while moving the k-mer window forward for the sequence + // we move another window backwards for the RC. + // For a ksize = 3, and a sequence AGTCGT (len = 6): + // +-+---------+---------------+-------+ + // seq RC |i|i + ksize|len - ksize - i|len - i| + // AGTCGT ACGACT +-+---------+---------------+-------+ + // +-> +-> |0| 2 | 3 | 6 | + // +-> +-> |1| 3 | 2 | 5 | + // +-> +-> |2| 4 | 1 | 4 | + // +-> +-> |3| 5 | 0 | 3 | + // +-+---------+---------------+-------+ + // (leaving this table here because I had to draw to + // get the indices correctly) + let hash = if let Some(rc) = self.rc { + let krc = &rc[self.len - self.ksize - self.index..self.len - self.index]; + crate::_hash_murmur(std::cmp::min(kmer, krc), self.seed) + } else { + crate::_hash_murmur(kmer, self.seed) // Use only forward k-mer if rc is None + }; + + self.index += 1; + Some(Ok(hash)) + } +} + const COMPLEMENT: [u8; 256] = { let mut lookup = [0; 256]; lookup[b'A' as usize] = b'T'; @@ -152,30 +297,6 @@ pub fn translated_frame(sequence: &[u8], frame_number: usize, dayhoff: bool, hp: .collect() } -pub fn make_translated_frames( - sequence: &[u8], - dna_rc: &[u8], - dayhoff: bool, - hp: bool, -) -> ReadingFrames { - // Generate forward frames - let forward = [ - translated_frame(sequence, 0, dayhoff, hp), - translated_frame(sequence, 1, dayhoff, hp), - translated_frame(sequence, 2, dayhoff, hp), - ]; - - // Generate reverse complement frames - let revcomp = [ - translated_frame(dna_rc, 0, dayhoff, hp), - translated_frame(dna_rc, 1, dayhoff, hp), - translated_frame(dna_rc, 2, dayhoff, hp), - ]; - - // Return a ReadingFrames object - ReadingFrames::new(forward, revcomp) -} - fn skipmer_frame(seq: &[u8], start: usize, n: usize, m: usize) -> Vec { seq.iter() .skip(start) @@ -184,27 +305,6 @@ fn skipmer_frame(seq: &[u8], start: usize, n: usize, m: usize) -> Vec { .collect() } -pub fn make_skipmer_frames(seq: &[u8], n: usize, m: usize) -> ReadingFrames { - if m >= n { - panic!("m must be less than n"); - } - // Generate the first three forward frames - let forward: [Vec; 3] = [ - skipmer_frame(seq, 0, n, m), - skipmer_frame(seq, 1, n, m), - skipmer_frame(seq, 2, n, m), - ]; - // Generate the reverse complement frames - let reverse_complement: [Vec; 3] = [ - revcomp(&forward[0]), - revcomp(&forward[1]), - revcomp(&forward[2]), - ]; - - // Return the frames in a structured format - ReadingFrames::new(forward, reverse_complement) -} - static CODONTABLE: Lazy> = Lazy::new(|| { [ // F diff --git a/src/core/src/signature.rs b/src/core/src/signature.rs index ac580ed0a7..c0ea0c4c3c 100644 --- a/src/core/src/signature.rs +++ b/src/core/src/signature.rs @@ -15,7 +15,7 @@ use rayon::prelude::*; use serde::{Deserialize, Serialize}; use typed_builder::TypedBuilder; -use crate::encodings::{aa_to_dayhoff, aa_to_hp, revcomp, to_aa, HashFunctions, VALID}; +use crate::encodings::{HashFunctions, ReadingFrames}; use crate::prelude::*; use crate::sketch::minhash::KmerMinHash; use crate::sketch::Sketch; @@ -163,7 +163,6 @@ impl SigsTrait for Sketch { } } -// Iterator for converting sequence to hashes pub struct SeqToHashes { sequence: Vec, kmer_index: usize, @@ -174,21 +173,7 @@ pub struct SeqToHashes { hash_function: HashFunctions, seed: u64, hashes_buffer: Vec, - - dna_configured: bool, - dna_rc: Vec, - dna_ksize: usize, - dna_len: usize, - dna_last_position_check: usize, - - prot_configured: bool, - aa_seq: Vec, - translate_iter_step: usize, - - skipmer_configured: bool, - skip_m: usize, - skip_n: usize, - skip_len: usize, + reading_frames: ReadingFrames, } impl SeqToHashes { @@ -200,298 +185,79 @@ impl SeqToHashes { hash_function: HashFunctions, seed: u64, ) -> SeqToHashes { - let mut ksize: usize = k_size; - - // Divide the kmer size by 3 if protein - if is_protein || hash_function.protein() || hash_function.dayhoff() || hash_function.hp() { - ksize = k_size / 3; - } + // Adjust kmer size for protein-based hash functions + let adjusted_k_size = + if hash_function.protein() || hash_function.dayhoff() || hash_function.hp() { + k_size / 3 + } else { + k_size + }; - // By setting _max_index to 0, the iterator will return None and exit - let _max_index = if seq.len() >= ksize { - seq.len() - ksize + 1 + // Determine the maximum index for k-mer generation + let max_index = if seq.len() >= adjusted_k_size { + seq.len() - adjusted_k_size + 1 } else { 0 }; + // Initialize ReadingFrames based on the sequence and hash function + let reading_frames = + ReadingFrames::new(&seq.to_ascii_uppercase(), is_protein, &hash_function); + SeqToHashes { - // Here we convert the sequence to upper case sequence: seq.to_ascii_uppercase(), - k_size: ksize, kmer_index: 0, - max_index: _max_index, + k_size: adjusted_k_size, + max_index, force, is_protein, hash_function, seed, hashes_buffer: Vec::with_capacity(1000), - dna_configured: false, - dna_rc: Vec::with_capacity(1000), - dna_ksize: 0, - dna_len: 0, - dna_last_position_check: 0, - prot_configured: false, - aa_seq: Vec::new(), - translate_iter_step: 0, - skipmer_configured: false, - skip_m: 2, - skip_n: 3, - skip_len: 0, + reading_frames, } } + // some helper functions. If we remove, we could probably just rm + // these fields from SeqToHashes, since ReadingFrames handles this now + pub fn get_sequence(&self) -> &[u8] { + &self.sequence + } - fn validate_base(&self, base: u8, kmer: &[u8]) -> Option> { - if !VALID[base as usize] { - if !self.force { - return Some(Err(Error::InvalidDNA { - message: String::from_utf8(kmer.to_owned()).unwrap_or_default(), - })); - } else { - return Some(Ok(0)); // Skip this position if forced - } - } - None // Base is valid, so return None to continue + pub fn get_hash_function(&self) -> &HashFunctions { + &self.hash_function } -} -/* -Iterator that return a kmer hash for all modes except translate. -In translate mode: - - all the frames are processed at once and converted to hashes. - - all the hashes are stored in `hashes_buffer` - - after processing all the kmers, `translate_iter_step` is incremented - per iteration to iterate over all the indeces of the `hashes_buffer`. - - the iterator will die once `translate_iter_step` == length(hashes_buffer) -More info https://github.com/sourmash-bio/sourmash/pull/1946 -*/ + pub fn is_protein(&self) -> bool { + self.is_protein + } +} impl Iterator for SeqToHashes { type Item = Result; fn next(&mut self) -> Option { - if (self.kmer_index < self.max_index) || !self.hashes_buffer.is_empty() { - // Processing DNA or Translated DNA - if !self.is_protein { - // Setting the parameters only in the first iteration - if !self.dna_configured { - self.dna_ksize = self.k_size; - self.dna_len = self.sequence.len(); - - if self.hash_function.skipm1n3() - || self.hash_function.skipm2n3() && !self.skipmer_configured - { - if self.hash_function.skipm1n3() { - self.skip_m = 1 - }; - if self.hash_function.skipm2n3() { - self.skip_m = 2 - }; - // eqn from skipmer paper. might want to add one to dna_ksize to round up? - // to do - check if we need to enforce k = multiple of n for revcomp k-mers to work - // , or if we can keep the round up trick I'm using here. - eprintln!("setting skipmer extended length"); - self.skip_len = (self.skip_n * (((self.dna_ksize + 1) / self.skip_m) - 1)) - + self.skip_m; - eprintln!("skipmer extended length: {}", self.skip_len); - // my prior eqn - // self.skip_len = self.dna_ksize + ((self.dna_ksize + 1) / self.skip_m) - 1; // add 1 to round up rather than down - - // check that we can actually build skipmers - if self.k_size < self.skip_n { - unimplemented!() - } - self.skipmer_configured = true; - } - // have enough sequence to kmerize? - eprintln!("checking seq len"); - if self.dna_len < self.dna_ksize - || (self.hash_function.protein() && self.dna_len < self.k_size * 3) - || (self.hash_function.dayhoff() && self.dna_len < self.k_size * 3) - || (self.hash_function.hp() && self.dna_len < self.k_size * 3) - || (self.skipmer_configured && self.dna_len < self.skip_len) - { - return None; - } - // pre-calculate the reverse complement for the full sequence... - eprintln!("precalculating revcomp"); - // NOTE: Shall we precalc skipmer seq here too? + maybe translated frames? - self.dna_rc = revcomp(&self.sequence); - self.dna_configured = true; - } - - // Processing DNA - if self.hash_function.dna() { - eprintln!("processing DNA"); - let kmer = &self.sequence[self.kmer_index..self.kmer_index + self.dna_ksize]; - - // validate the bases - for j in std::cmp::max(self.kmer_index, self.dna_last_position_check) - ..self.kmer_index + self.dna_ksize - { - if !VALID[self.sequence[j] as usize] { - if !self.force { - return Some(Err(Error::InvalidDNA { - message: String::from_utf8(kmer.to_vec()).unwrap(), - })); - } else { - self.kmer_index += 1; - // Move the iterator to the next step - return Some(Ok(0)); - } - } - self.dna_last_position_check += 1; - } - - // ... and then while moving the k-mer window forward for the sequence - // we move another window backwards for the RC. - // For a ksize = 3, and a sequence AGTCGT (len = 6): - // +-+---------+---------------+-------+ - // seq RC |i|i + ksize|len - ksize - i|len - i| - // AGTCGT ACGACT +-+---------+---------------+-------+ - // +-> +-> |0| 2 | 3 | 6 | - // +-> +-> |1| 3 | 2 | 5 | - // +-> +-> |2| 4 | 1 | 4 | - // +-> +-> |3| 5 | 0 | 3 | - // +-+---------+---------------+-------+ - // (leaving this table here because I had to draw to - // get the indices correctly) - - let krc = &self.dna_rc[self.dna_len - self.dna_ksize - self.kmer_index - ..self.dna_len - self.kmer_index]; - let hash = crate::_hash_murmur(std::cmp::min(kmer, krc), self.seed); - self.kmer_index += 1; - Some(Ok(hash)) - } else if self.skipmer_configured { - eprintln!("processing skipmer"); - // Check bounds to ensure we don't exceed the sequence length - if self.kmer_index + self.skip_len > self.sequence.len() { - return None; - } + // Check if we've processed all k-mers + if self.kmer_index >= self.max_index { + return None; + } - // Build skipmer with DNA base validation - let mut kmer: Vec = Vec::with_capacity(self.dna_ksize); - for (_i, &base) in self.sequence - [self.kmer_index..self.kmer_index + self.skip_len] - .iter() - .enumerate() - .filter(|&(i, _)| i % self.skip_n < self.skip_m) - .take(self.dna_ksize) - { - // Use the validate_base method to check the base - if let Some(result) = self.validate_base(base, &kmer) { - self.kmer_index += 1; // Move to the next position if skipping is forced - return Some(result); + // Iterate over the frames + for frame in self.reading_frames.frames().iter() { + let kmer_iter = frame.kmer_iter(self.k_size, self.seed, self.force); + for result in kmer_iter { + match result { + Ok(hash) => self.hashes_buffer.push(hash), + Err(e) => { + if !self.force { + return Some(Err(e)); } - eprintln!("base {}", base); - kmer.push(base); - } - // eprintln!("skipmer kmer: {:?}", kmer); - - // Generate reverse complement skipmer - let krc: Vec = self.dna_rc[self.dna_len - self.skip_len - self.kmer_index - ..self.dna_len - self.kmer_index] - .iter() - .enumerate() - .filter(|&(i, _)| i % self.skip_n < self.skip_m) - .take(self.dna_ksize) - .map(|(_, &base)| base) - .collect(); - - let hash = crate::_hash_murmur(std::cmp::min(&kmer, &krc), self.seed); - self.kmer_index += 1; - eprintln!("built skipmer hash"); - Some(Ok(hash)) - } else if self.hashes_buffer.is_empty() && self.translate_iter_step == 0 { - // Processing protein by translating DNA - // TODO: Implement iterator over frames instead of hashes_buffer. - - for frame_number in 0..3 { - let substr: Vec = self - .sequence - .iter() - .cloned() - .skip(frame_number) - .take(self.sequence.len() - frame_number) - .collect(); - - let aa = to_aa( - &substr, - self.hash_function.dayhoff(), - self.hash_function.hp(), - ) - .unwrap(); - - aa.windows(self.k_size).for_each(|n| { - let hash = crate::_hash_murmur(n, self.seed); - self.hashes_buffer.push(hash); - }); - - let rc_substr: Vec = self - .dna_rc - .iter() - .cloned() - .skip(frame_number) - .take(self.dna_rc.len() - frame_number) - .collect(); - let aa_rc = to_aa( - &rc_substr, - self.hash_function.dayhoff(), - self.hash_function.hp(), - ) - .unwrap(); - - aa_rc.windows(self.k_size).for_each(|n| { - let hash = crate::_hash_murmur(n, self.seed); - self.hashes_buffer.push(hash); - }); } - Some(Ok(0)) - } else { - if self.translate_iter_step == self.hashes_buffer.len() { - self.hashes_buffer.clear(); - self.kmer_index = self.max_index; - return Some(Ok(0)); - } - let curr_idx = self.translate_iter_step; - self.translate_iter_step += 1; - Some(Ok(self.hashes_buffer[curr_idx])) - } - } else { - // Processing protein - // The kmer size is already divided by 3 - - if self.hash_function.protein() { - let aa_kmer = &self.sequence[self.kmer_index..self.kmer_index + self.k_size]; - let hash = crate::_hash_murmur(aa_kmer, self.seed); - self.kmer_index += 1; - Some(Ok(hash)) - } else { - if !self.prot_configured { - self.aa_seq = match &self.hash_function { - HashFunctions::Murmur64Dayhoff => { - self.sequence.iter().cloned().map(aa_to_dayhoff).collect() - } - HashFunctions::Murmur64Hp => { - self.sequence.iter().cloned().map(aa_to_hp).collect() - } - invalid => { - return Some(Err(Error::InvalidHashFunction { - function: format!("{}", invalid), - })); - } - }; - } - - let aa_kmer = &self.aa_seq[self.kmer_index..self.kmer_index + self.k_size]; - let hash = crate::_hash_murmur(aa_kmer, self.seed); - self.kmer_index += 1; - Some(Ok(hash)) } } - } else { - // End the iterator - None } + + self.kmer_index += 1; + Some(Ok(0)) // Return 0 to indicate progress; actual hashes are stored in `hashes_buffer` } } @@ -1450,13 +1216,13 @@ mod test { let k_size = 7; let seed = 42; let force = true; // Force skip over invalid bases if needed - + let is_protein = false; // Initialize SeqToHashes iterator using the new constructor let mut seq_to_hashes = SeqToHashes::new( sequence, k_size, force, - false, + is_protein, HashFunctions::Murmur64Dna, seed, ); @@ -1489,13 +1255,14 @@ mod test { let k_size = 5; let seed = 42; let force = true; // Force skip over invalid bases if needed + let is_protein = false; // Initialize SeqToHashes iterator using the new constructor let mut seq_to_hashes = SeqToHashes::new( sequence, k_size, force, - false, + is_protein, HashFunctions::Murmur64Skipm2n3, seed, );