From 5edbfc9a95f87a61a31a269e8403ca427fb9d1b1 Mon Sep 17 00:00:00 2001 From: Mitchell Robert Vollger Date: Mon, 9 Sep 2024 13:18:46 -0700 Subject: [PATCH] feat multithreaded acf --- src/subcommands/qc.rs | 11 ++++++-- src/utils/acf.rs | 65 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/src/subcommands/qc.rs b/src/subcommands/qc.rs index b10ef769..abf88ec5 100644 --- a/src/subcommands/qc.rs +++ b/src/subcommands/qc.rs @@ -213,11 +213,16 @@ impl<'a> QcStats<'a> { } log::info!("Calculating m6A auto-correlation."); - let acf = - crate::utils::acf::acf(&self.m6a_acf_starts, Some(self.qc_opts.acf_max_lag), false)?; + let acf = crate::utils::acf::acf_par( + &self.m6a_acf_starts, + Some(self.qc_opts.acf_max_lag), + false, + )?; log::info!("Done calculating m6A auto-correlation!"); for (i, val) in acf.iter().enumerate() { - out.write_all(format!("m6a_acf\t{}\t{}\n", i, val).as_bytes())?; + out.write_all( + format!("m6a_acf\t{}\t{}\n", i, my_ordered_float(*val as f32)).as_bytes(), + )?; } Ok(()) } diff --git a/src/utils/acf.rs b/src/utils/acf.rs index c0b32831..4547adef 100644 --- a/src/utils/acf.rs +++ b/src/utils/acf.rs @@ -2,6 +2,7 @@ // https://github.com/krfricke/arima use anyhow::Result; use num::Float; +use rayon::prelude::*; use std::cmp; use std::convert::From; use std::ops::{Add, AddAssign, Div}; @@ -73,3 +74,67 @@ pub fn acf + From + Copy + Add + AddAssign + Div>( } Ok(y) } + +/// Calculate the auto-correlation function of a time series of length n. +/// but this version is multithreaded over m using rayon. +pub fn acf_par< + T: Float + + From + + From + + Copy + + Add + + AddAssign + + Div + + Send + + std::iter::Sum + + std::marker::Sync, +>( + x: &[T], + max_lag: Option, + covariance: bool, +) -> Result> { + let max_lag = match max_lag { + // if upper bound for max_lag is n-1 + Some(max_lag) => cmp::min(max_lag, x.len() - 1), + None => x.len() - 1, + }; + if x.len() <= max_lag { + return Err(anyhow::anyhow!( + "acf-max-lag ({}) must be less than the number of m6A observations ({}).", + max_lag, + x.len() + )); + } + + let m = max_lag + 1; + + let len_x_usize = x.len(); + let len_x: T = From::from(len_x_usize as u32); + let sum: T = From::from(0.0); + + let sum_x: T = x.iter().fold(sum, |sum, &xi| sum + xi); + let mean_x: T = sum_x / len_x; + + let mut y: Vec = vec![From::from(0.0); m]; + + for t in 0..m { + y[t] = x + .into_par_iter() + .enumerate() + .take(len_x_usize - t) + .map(|(i, xi)| { + let xi = *xi - mean_x; + let xi_t = x[i + t] - mean_x; + (xi * xi_t) / len_x + }) + .sum(); + // we need y[0] to calculate the correlations, so we set it to 1.0 at the end + if !covariance && t > 0 { + y[t] = y[t] / y[0]; + } + } + if !covariance { + y[0] = From::from(1.0); + } + Ok(y) +}