Skip to content

Commit

Permalink
fix uppercasing + validation
Browse files Browse the repository at this point in the history
  • Loading branch information
bluegenes committed Dec 12, 2024
1 parent 6c18180 commit af2af55
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 99 deletions.
193 changes: 95 additions & 98 deletions src/core/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::path::Path;
use std::str;

use cfg_if::cfg_if;
use itertools::Itertools;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -171,7 +172,7 @@ pub enum ReadingFrame {
len: usize, // len gives max_index for kmer iterator
},
Protein {
fw: Vec<u8>, // Only forward frame
fw: Vec<u8>,
len: usize,
},
}
Expand All @@ -198,58 +199,64 @@ impl std::fmt::Display for ReadingFrame {

impl ReadingFrame {
pub fn new_dna(sequence: &[u8]) -> Self {
let fw = sequence.to_vec();
let rc = revcomp(sequence);
let fw = sequence.to_ascii_uppercase();
let rc = revcomp(&fw);
let len = sequence.len();
ReadingFrame::DNA { fw, rc, len }
}

pub fn new_protein(sequence: &[u8], dayhoff: bool, hp: bool) -> Self {
let seq = sequence.to_ascii_uppercase();
let fw: Vec<u8> = if dayhoff {
sequence.iter().map(|&aa| aa_to_dayhoff(aa)).collect()
seq.iter().map(|&aa| aa_to_dayhoff(aa)).collect()
} else if hp {
sequence.iter().map(|&aa| aa_to_hp(aa)).collect()
seq.iter().map(|&aa| aa_to_hp(aa)).collect()
} else {
sequence.to_vec() // protein, as-is.
seq
};

let len = fw.len();
ReadingFrame::Protein { fw, len }
}

pub fn new_skipmer(seq: &[u8], start: usize, m: usize, n: usize) -> Self {
pub fn new_skipmer(sequence: &[u8], start: usize, m: usize, n: usize) -> Self {
let seq = sequence.to_ascii_uppercase();
if start >= n {
panic!("Skipmer frame number must be < n ({})", n);
}
// Generate forward skipmer frame
let fw: Vec<u8> = seq
.iter()
.skip(start)
.enumerate()
.filter_map(|(i, &base)| if i % n < m { Some(base) } else { None })
.collect();
// do we need to round up? (+1)
let mut fw = Vec::with_capacity(((seq.len() * m) + 1) / n);
seq.iter().skip(start).enumerate().for_each(|(i, &base)| {
if i % n < m {
fw.push(base.to_ascii_uppercase());
}
});

let len = fw.len();
let rc = revcomp(&fw);
ReadingFrame::DNA { fw, rc, len }
}

// NOTE: unlike the other constructors, this one does not uppercase its input.
// Callers must pass an already-uppercased sequence; uppercasing once externally is more efficient.
pub fn new_translated(sequence: &[u8], frame_number: usize, dayhoff: bool, hp: bool) -> Self {
if frame_number > 2 {
panic!("Frame number must be 0, 1, or 2");
}

// translate sequence
let fw: Vec<u8> = sequence
// Translate sequence into amino acids
let mut fw = Vec::with_capacity(sequence.len() / 3);
// NOTE: b/c of chunks(3), we only process full codons and ignore leftover bases (e.g. 1 or 2 at end of frame)
sequence
.iter()
.cloned()
.skip(frame_number) // skip the initial bases for the frame
.take(sequence.len() - frame_number) // adjust length based on skipped bases
.collect::<Vec<u8>>() // collect the DNA subsequence
.chunks(3) // group into codons (triplets)
.filter_map(|codon| to_aa(codon, dayhoff, hp).ok()) // translate each codon
.flatten() // flatten the nested results into a single sequence
.collect();
.skip(frame_number) // Skip the initial bases for the frame
.take(sequence.len() - frame_number) // Adjust length based on skipped bases
.chunks(3) // Group into codons (triplets) using itertools
.into_iter()
.filter_map(|chunk| {
let codon: Vec<u8> = chunk.cloned().collect(); // Collect the chunk into a Vec<u8>
to_aa(&codon, dayhoff, hp).ok() // Translate the codon
})
.for_each(|aa| fw.extend(aa)); // Extend `fw` with amino acids

let len = fw.len();

Expand All @@ -258,6 +265,7 @@ impl ReadingFrame {
}

/// Get the forward sequence.
#[inline]
pub fn fw(&self) -> &[u8] {
match self {
ReadingFrame::DNA { fw, .. } => fw,
Expand All @@ -266,13 +274,15 @@ impl ReadingFrame {
}

/// Borrow the reverse-complement bytes stored in a DNA frame.
///
/// # Panics
///
/// Panics when called on any variant other than `ReadingFrame::DNA`.
#[inline]
pub fn rc(&self) -> &[u8] {
    if let ReadingFrame::DNA { rc, .. } = self {
        rc
    } else {
        panic!("Reverse complement is only available for DNA frames")
    }
}

#[inline]
pub fn length(&self) -> usize {
match self {
ReadingFrame::DNA { len, .. } => *len,
Expand All @@ -294,8 +304,9 @@ pub struct SeqToHashes {
force: bool,
seed: u64,
frames: Vec<ReadingFrame>,
frame_index: usize, // Index of the current frame
kmer_index: usize, // Current k-mer index within the frame
frame_index: usize, // Index of the current frame
kmer_index: usize, // Current k-mer index within the frame
last_position_check: usize, // Index of last base we validated
}

impl SeqToHashes {
Expand All @@ -314,19 +325,17 @@ impl SeqToHashes {
ksize = k_size / 3;
}

// uppercase the sequence. this clones the data bc &[u8] is immutable?
// TODO: could we avoid this by changing revcomp/VALID/etc?
let sequence = seq.to_ascii_uppercase();

// Generate frames based on sequence type and hash function
let frames = if is_protein {
Self::protein_frames(&sequence, &hash_function)
let frames = if hash_function.dna() {
Self::dna_frames(&seq)

Check failure on line 330 in src/core/src/signature.rs

View workflow job for this annotation

GitHub Actions / Lints (stable)

this expression creates a reference which is immediately dereferenced by the compiler

Check failure on line 330 in src/core/src/signature.rs

View workflow job for this annotation

GitHub Actions / Lints (beta)

this expression creates a reference which is immediately dereferenced by the compiler
} else if is_protein {
Self::protein_frames(&seq, &hash_function)

Check failure on line 332 in src/core/src/signature.rs

View workflow job for this annotation

GitHub Actions / Lints (stable)

this expression creates a reference which is immediately dereferenced by the compiler

Check failure on line 332 in src/core/src/signature.rs

View workflow job for this annotation

GitHub Actions / Lints (beta)

this expression creates a reference which is immediately dereferenced by the compiler
} else if hash_function.protein() || hash_function.dayhoff() || hash_function.hp() {
Self::translated_frames(&sequence, &hash_function)
Self::translated_frames(&seq, &hash_function)

Check failure on line 334 in src/core/src/signature.rs

View workflow job for this annotation

GitHub Actions / Lints (stable)

this expression creates a reference which is immediately dereferenced by the compiler

Check failure on line 334 in src/core/src/signature.rs

View workflow job for this annotation

GitHub Actions / Lints (beta)

this expression creates a reference which is immediately dereferenced by the compiler
} else if hash_function.skipm1n3() || hash_function.skipm2n3() {
Self::skipmer_frames(&sequence, &hash_function, ksize)
Self::skipmer_frames(&seq, &hash_function, ksize)

Check failure on line 336 in src/core/src/signature.rs

View workflow job for this annotation

GitHub Actions / Lints (stable)

this expression creates a reference which is immediately dereferenced by the compiler

Check failure on line 336 in src/core/src/signature.rs

View workflow job for this annotation

GitHub Actions / Lints (beta)

this expression creates a reference which is immediately dereferenced by the compiler
} else {
Self::dna_frames(&sequence)
unimplemented!();
};

SeqToHashes {
Expand All @@ -336,6 +345,7 @@ impl SeqToHashes {
frames,
frame_index: 0,
kmer_index: 0,
last_position_check: 0,
}
}

Expand All @@ -355,12 +365,14 @@ impl SeqToHashes {

/// generate translated frames: 6 protein frames
fn translated_frames(seq: &[u8], hash_function: &HashFunctions) -> Vec<ReadingFrame> {
let revcomp_sequence = revcomp(seq);
// since we need to revcomp BEFORE making ReadingFrames, uppercase the sequence here
let sequence = seq.to_ascii_uppercase();
let revcomp_sequence = revcomp(&sequence);
(0..3)
.flat_map(|frame_number| {
vec![
ReadingFrame::new_translated(
seq,
&sequence,
frame_number,
hash_function.dayhoff(),
hash_function.hp(),
Expand Down Expand Up @@ -398,59 +410,6 @@ impl SeqToHashes {
fn out_of_bounds(&self, frame: &ReadingFrame) -> bool {
    // One past the last index the k-mer starting at `kmer_index` would cover.
    let window_end = self.kmer_index + self.k_size;
    window_end > frame.length()
}

// check all bases are valid
fn validate_dna_kmer(&self, kmer: &[u8]) -> Result<bool, Error> {
for &nt in kmer {
if !VALID[nt as usize] {
if self.force {
// Return `false` to indicate invalid k-mer, but do not error out
return Ok(false);
} else {
return Err(Error::InvalidDNA {
message: String::from_utf8_lossy(kmer).to_string(),
});
}
}
}
Ok(true) // All bases are valid
}

/// Hash the DNA k-mer at the current `kmer_index` of `frame`.
///
/// The forward k-mer is validated first; an invalid k-mer yields `Ok(0)` when
/// `validate_dna_kmer` reports it as skippable, or propagates its error.
/// Otherwise the smaller of the forward k-mer and its
/// reverse-complement counterpart is hashed with `self.seed`.
fn dna_hash(&self, frame: &ReadingFrame) -> Result<u64, Error> {
    let start = self.kmer_index;
    let end = start + self.k_size;
    let forward = &frame.fw()[start..end];
    let rc_seq = frame.rc();

    let is_valid = self.validate_dna_kmer(forward)?;
    if !is_valid {
        // Invalid k-mer that we were told to skip.
        return Ok(0);
    }

    // Locating the matching window in the reverse complement.
    // For a ksize = 3, and a sequence AGTCGT (len = 6):
    //                   +-+---------+---------------+-------+
    //   seq      RC     |i|i + ksize|len - ksize - i|len - i|
    //  AGTCGT   ACGACT  +-+---------+---------------+-------+
    //  +->         +->  |0|    2    |       3       |   6   |
    //   +->       +->   |1|    3    |       2       |   5   |
    //    +->     +->    |2|    4    |       1       |   4   |
    //     +->   +->     |3|    5    |       0       |   3   |
    //                   +-+---------+---------------+-------+
    // (table kept from the original author, who drew it to get the
    // indices right)
    let rc_start = frame.length() - self.k_size - self.kmer_index;
    let reverse = &rc_seq[rc_start..rc_start + self.k_size];

    // Canonical form: whichever of the two strands compares smaller.
    let canonical = if reverse < forward { reverse } else { forward };
    Ok(crate::_hash_murmur(canonical, self.seed))
}

// Hash the k-mer at the current `kmer_index` of the frame's forward sequence.
fn protein_hash(&self, frame: &ReadingFrame) -> u64 {
    let start = self.kmer_index;
    let window = &frame.fw()[start..start + self.k_size];
    crate::_hash_murmur(window, self.seed)
}
}

impl Iterator for SeqToHashes {
Expand All @@ -464,22 +423,60 @@ impl Iterator for SeqToHashes {
if self.out_of_bounds(frame) {
self.frame_index += 1;
self.kmer_index = 0; // Reset for the next frame
self.last_position_check = 0;
continue;
}

// Delegate to DNA or protein processing
let result = match frame {
ReadingFrame::DNA { .. } => match self.dna_hash(frame) {
Ok(hash) => Ok(hash), // Valid hash
Err(err) => Err(err), // Error
},
ReadingFrame::Protein { .. } => Ok(self.protein_hash(frame)),
ReadingFrame::DNA { .. } => {
let kmer = &frame.fw()[self.kmer_index..self.kmer_index + self.k_size];
let rc = frame.rc();

// Validate k-mer bases
for j in std::cmp::max(self.kmer_index, self.last_position_check)
..self.kmer_index + self.k_size
{
if !VALID[frame.fw()[j] as usize] {
if !self.force {
// Return an error if force is false
return Some(Err(Error::InvalidDNA {
message: String::from_utf8(kmer.to_vec()).unwrap(),
}));
} else {
// Skip the invalid k-mer
self.kmer_index += 1;
return Some(Ok(0));
}
}
self.last_position_check += 1;
}

// Compute canonical hash
// For a ksize = 3, and a sequence AGTCGT (len = 6):
// +-+---------+---------------+-------+
// seq RC |i|i + ksize|len - ksize - i|len - i|
// AGTCGT ACGACT +-+---------+---------------+-------+
// +-> +-> |0| 2 | 3 | 6 |
// +-> +-> |1| 3 | 2 | 5 |
// +-> +-> |2| 4 | 1 | 4 |
// +-> +-> |3| 5 | 0 | 3 |
// +-+---------+---------------+-------+
// (leaving this table here because I had to draw to
// get the indices correctly)
let krc = &rc[frame.length() - self.k_size - self.kmer_index
..frame.length() - self.kmer_index];
let hash = crate::_hash_murmur(std::cmp::min(kmer, krc), self.seed);
Ok(hash)
}
ReadingFrame::Protein { .. } => {
let kmer = &frame.fw()[self.kmer_index..self.kmer_index + self.k_size];
Ok(crate::_hash_murmur(kmer, self.seed))
}
};

self.kmer_index += 1; // Advance k-mer index
self.kmer_index += 1; // Advance k-mer index for valid k-mers
return Some(result);
}

None // No more frames or k-mers
}
}
Expand Down
8 changes: 7 additions & 1 deletion tests/test_sourmash_sketch.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ def test_protein_override_bad_rust_foo():
siglist = factory()
assert len(siglist) == 1
sig = siglist[0]
print(sig.minhash.ksize)

# try adding something
testdata1 = utils.get_test_data("ecoli.faa")
Expand All @@ -354,7 +355,12 @@ def test_protein_override_bad_rust_foo():
with pytest.raises(ValueError) as exc:
sig.add_protein(record.sequence)

assert 'Invalid hash function: "DNA"' in str(exc)
# assert 'Invalid hash function: "DNA"' in str(exc)

# This case now takes the "DNA" branch of SeqToHashes, so it fails with
# the invalid k-mer error rather than the invalid hash function error.
assert "invalid DNA character in input k-mer: MRVLKFGGTS" in str(exc)


def test_dayhoff_defaults():
Expand Down

0 comments on commit af2af55

Please sign in to comment.