diff --git a/Cargo.lock b/Cargo.lock index 1e7bcf06..33939daa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1426,6 +1426,7 @@ dependencies = [ "num", "ordered-float", "predicates", + "rand", "rayon", "regex", "rust-htslib", diff --git a/Cargo.toml b/Cargo.toml index f9d1f655..5bd3d940 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,6 +52,7 @@ gzp = "0.11.3" niffler = {version = "2.5.0", default-features = false, features = ["gz"]} burn = { version = "0.12", optional = true, features = ["candle"] } # "wgpu", num = "0.4.3" +rand = "0.8.5" #polars = "0.38" [build-dependencies] diff --git a/src/cli/qc_opts.rs b/src/cli/qc_opts.rs index f21127e3..1b68369a 100644 --- a/src/cli/qc_opts.rs +++ b/src/cli/qc_opts.rs @@ -21,6 +21,9 @@ pub struct QcOpts { /// maximum number of reads to use in the ACF calculation #[clap(long, default_value = "10000")] pub acf_max_reads: usize, + /// After sampling the first "acf-max-reads" randomly sample one of every "acf-sample-rate" reads and replace one of the previous reads at random. + #[clap(long, default_value = "100")] + pub acf_sample_rate: f32, /// In the output include a measure of the number of m6A events per MSPs of a given size. /// The output format is: "m6a_per_msp_size\t{m6A count},{MSP size},{is a FIRE}\t{count}" /// e.g. "m6a_per_msp_size\t35,100,false\t100" diff --git a/src/subcommands/qc.rs b/src/subcommands/qc.rs index 688b7dc4..450d558d 100644 --- a/src/subcommands/qc.rs +++ b/src/subcommands/qc.rs @@ -4,7 +4,9 @@ use crate::utils::bio_io; use anyhow::Result; use itertools::Itertools; use ordered_float::OrderedFloat; +use rand::prelude::*; use std::collections::HashMap; +use std::collections::VecDeque; use std::io::Write; // set the precision of the floats to be saved and printed @@ -56,13 +58,11 @@ pub struct QcStats<'a> { pub rq: HashMap, i64>, // m6a per msp size: (msp size, m6a count, is a FIRE element), number of times seen pub m6a_per_msp_size: HashMap, - // add the m6a acf - pub m6a_acf: HashMap, - // pub - pub acf_read_count: usize, // m6a starts for acf - m6a_acf_starts: Vec, - // + m6a_acf_starts: VecDeque>, + // times m6as have been sampled at random for acf + sampled: usize, + // the qc options for printing qc_opts: &'a QcOpts, // phasing information phased_reads: HashMap, @@ -85,9 +85,8 @@ impl<'a> QcStats<'a> { cpg_count: HashMap::new(), m6a_per_msp_size: HashMap::new(), rq: HashMap::new(), - m6a_acf: HashMap::new(), - acf_read_count: 0, - m6a_acf_starts: Vec::new(), + m6a_acf_starts: VecDeque::new(), + sampled: 0, qc_opts, phased_reads: HashMap::new(), phased_bp: HashMap::new(), @@ -96,13 +95,7 @@ impl<'a> QcStats<'a> { pub fn add_read_to_stats(&mut self, fiber: &fiber::FiberseqData) { // add auto-correlation of m6a - if self.qc_opts.acf - && self.acf_read_count < self.qc_opts.acf_max_reads - && fiber.m6a.starts.len() >= self.qc_opts.acf_min_m6a - { - self.add_m6a_starts_for_acf(fiber); - self.acf_read_count += 1; - } + self.add_m6a_starts_for_acf(fiber); self.full_read_stats(fiber); self.add_basemod_stats(fiber); @@ -114,11 +107,47 @@ impl<'a> QcStats<'a> { /// converts the m6A calls into a boolean vector for the ACF calculation fn add_m6a_starts_for_acf(&mut self, fiber: &fiber::FiberseqData) { + // skip conditions + if !self.qc_opts.acf || fiber.m6a.starts.len() < self.qc_opts.acf_min_m6a { + return; + } + + // test if we should skip or not based on length and random sampling + let rand_float: f32 = random(); + let sample = rand_float < 1.0 / self.qc_opts.acf_sample_rate; + if !(self.m6a_acf_starts.len() <= self.qc_opts.acf_max_reads || sample) { + return; + }; + + // note how many times we have sampled + if sample { + self.sampled += 1; + } + + // add the m6a to the working queue let mut m6a_vec: Vec = vec![0.0; fiber.record.seq_len()]; for m6a in fiber.m6a.starts.iter().flatten() { m6a_vec[*m6a as usize] = 1.0; } - self.m6a_acf_starts.extend(m6a_vec); + + // if we have sampled enough that all reads are random replace + // a random previous read with the current read + if sample && self.sampled > self.qc_opts.acf_max_reads { + let idx = thread_rng().gen_range(0..self.m6a_acf_starts.len()); + self.m6a_acf_starts[idx] = m6a_vec; + log::trace!( + "Replaced read at index {} after the {}th sample", + idx, + self.sampled + ); + return; + } + + // otherwise add to the end while constraining the size of the queue + self.m6a_acf_starts.push_back(m6a_vec); + if self.m6a_acf_starts.len() > self.qc_opts.acf_max_reads { + self.m6a_acf_starts.pop_front(); + } } fn add_ranges(&mut self, fiber: &fiber::FiberseqData) { @@ -233,10 +262,14 @@ impl<'a> QcStats<'a> { if !self.qc_opts.acf { return Ok(()); } - log::info!("Calculating m6A auto-correlation."); let acf = crate::utils::acf::acf_par( - &self.m6a_acf_starts, + &self + .m6a_acf_starts + .iter() + .flatten() + .map(|f| *f) + .collect::>(), Some(self.qc_opts.acf_max_lag), false, )?;