Skip to content

Commit

Permalink
cANI Rust Implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
mr-eyes committed Jan 18, 2024
1 parent 8730072 commit 3b1c1b8
Showing 1 changed file with 245 additions and 0 deletions.
245 changes: 245 additions & 0 deletions src/c_ani.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
use roots::{find_root_brent, SimpleConvergency};
use statrs::distribution::{ContinuousCDF, Normal};
use std::error::Error;

#[derive(Debug)]
#[allow(dead_code)]
struct CiAniResult {
point_estimate: f64,
prob_nothing_in_common: f64,
dist_low: Option<f64>,
dist_high: Option<f64>,
p_threshold: f64,
}

fn exp_n_mutated(l: u64, k: u32, r1: f64) -> f64 {
let q = r1_to_q(k, r1);
l as f64 * q
}

fn var_n_mutated(l: u64, k: u32, r1: f64, q: Option<f64>) -> Result<f64, Box<dyn Error>> {
if r1 == 0.0 {
return Ok(0.0);
}

let q = q.unwrap_or_else(|| r1_to_q(k, r1));

let var_n = l as f64 * (1.0 - q) * (q * (2.0 * k as f64 + (2.0 / r1) - 1.0) - 2.0 * k as f64)
+ k as f64 * (k as f64 - 1.0) * (1.0 - q).powi(2)
+ (2.0 * (1.0 - q) / (r1.powi(2))) * ((1.0 + (k as f64 - 1.0) * (1.0 - q)) * r1 - q);

if var_n < 0.0 {
Err("Error: varN < 0.0!".into())
} else {
Ok(var_n)
}
}

fn exp_n_mutated_squared(l: u64, k: u32, p: f64) -> Result<f64, Box<dyn Error>> {
let var_n = var_n_mutated(l, k, p, None)?;
Ok(var_n + exp_n_mutated(l, k, p).powi(2))
}

fn probit(p: f64) -> f64 {
Normal::new(0.0, 1.0).unwrap().inverse_cdf(p)
}

fn r1_to_q(k: u32, r1: f64) -> f64 {
1.0 - (1.0 - r1).powi(k as i32)
}

fn get_exp_probability_nothing_common(
mutation_rate: f64,
kmer_size: u32,
scaled: u64,
n_unique_kmers: u64,
) -> f64 {
// Inverse of the scale factor, used in probability calculation.
let inverse_scaled = 1.0 / scaled as f64;

// Handle special cases for mutation rate.
if mutation_rate == 1.0 {
1.0
} else if mutation_rate == 0.0 {
0.0
} else {
// Calculate the expected log probability.
let expected_log_probability = get_expected_log_probability(
n_unique_kmers,
kmer_size,
mutation_rate,
inverse_scaled
);
// Return the exponential of the expected log probability.
expected_log_probability.exp()
}
}

fn get_expected_log_probability(
n_unique_kmers: u64,
ksize: u32,
mutation_rate: f64,
scaled_fraction: f64,
) -> f64 {
let exp_nmut = exp_n_mutated(n_unique_kmers, ksize, mutation_rate);
let result = (n_unique_kmers as f64 - exp_nmut) * (1.0 - scaled_fraction).ln();

if result.is_infinite() {
f64::NEG_INFINITY
} else {
result
}
}

fn handle_seqlen_nkmers(
ksize: u32,
sequence_len_bp: Option<u64>,
n_unique_kmers: Option<u64>,
) -> Result<u64, Box<dyn Error>> {
match n_unique_kmers {
Some(n) => Ok(n),
None => match sequence_len_bp {
Some(len) => Ok(len.saturating_sub(ksize as u64 - 1)),
None => Err(
"Error: distance estimation requires 'sequence_len_bp' or 'n_unique_kmers'".into(),
),
},
}
}

fn check_distance(dist: f64) -> Result<f64, Box<dyn Error>> {
if dist >= 0.0 && dist <= 1.0 {
Ok(dist)
} else {
Err(format!("Error: distance value {:.4} is not between 0 and 1!", dist).into())
}
}

impl CiAniResult {
fn new(
point_estimate: f64,
prob_nothing_in_common: f64,
dist_low: Option<f64>,
dist_high: Option<f64>,
p_threshold: f64,
) -> Result<Self, Box<dyn Error>> {
let dist_low_checked = dist_low.map_or(Ok(None), |d| check_distance(d).map(Some))?;
let dist_high_checked = dist_high.map_or(Ok(None), |d| check_distance(d).map(Some))?;

Ok(Self {
point_estimate,
prob_nothing_in_common,
dist_low: dist_low_checked,
dist_high: dist_high_checked,
p_threshold,
})
}
}

fn containment_to_distance(
containment: f64,
ksize: u32,
scaled: u64,
n_unique_kmers: Option<u64>,
sequence_len_bp: Option<u64>,
confidence: Option<f64>,
estimate_ci: Option<bool>,
prob_threshold: Option<f64>,
) -> Result<CiAniResult, Box<dyn Error>> {
let scaled_f64 = scaled as f64;
let n_unique_kmers = handle_seqlen_nkmers(ksize, sequence_len_bp, n_unique_kmers)?;

let confidence = confidence.unwrap_or(0.95);
let estimate_ci = estimate_ci.unwrap_or(false);
let prob_threshold = prob_threshold.unwrap_or(1e-3);

let point_estimate = if containment == 0.0 {
1.0
} else if containment == 1.0 {
0.0
} else {
1.0 - containment.powf(1.0 / ksize as f64)
};

let mut sol1 = None;
let mut sol2 = None;

if estimate_ci {
let alpha = 1.0 - confidence;
let z_alpha = probit(1.0 - alpha / 2.0);
let f_scaled = 1.0 / scaled_f64;
let bias_factor = 1.0 - (1.0 - f_scaled).powi(n_unique_kmers as i32);
let term_1 =
(1.0 - f_scaled) / (f_scaled * (n_unique_kmers as f64).powi(3) * bias_factor.powi(2));
let term_2 = |pest: f64| {
n_unique_kmers as f64 * exp_n_mutated(n_unique_kmers, ksize, pest)
- exp_n_mutated_squared(n_unique_kmers, ksize, pest).unwrap_or(0.0)
};
let term_3 = |pest: f64| {
var_n_mutated(n_unique_kmers, ksize, pest, None).unwrap_or(0.0)
/ (n_unique_kmers as f64).powi(2)
};

let var_direct = |pest: f64| term_1 * term_2(pest) + term_3(pest);

let f1 = |pest: f64| {
(1.0 - pest).powi(ksize as i32) + z_alpha * var_direct(pest).sqrt() - containment
};
let f2 = |pest: f64| {
(1.0 - pest).powi(ksize as i32) - z_alpha * var_direct(pest).sqrt() - containment
};

let mut convergency = SimpleConvergency {
eps: 1e-15,
max_iter: 1000,
};
sol1 = find_root_brent(0.0000001, 0.9999999, &f1, &mut convergency).ok();
sol2 = find_root_brent(0.0000001, 0.9999999, &f2, &mut convergency).ok();
}

let prob_nothing_in_common =
get_exp_probability_nothing_common(point_estimate, ksize, scaled, n_unique_kmers);

CiAniResult::new(
point_estimate,
prob_nothing_in_common,
sol1,
sol2,
prob_threshold,
)
}

/* I calculate the distance, then I do 1 - distance to get the ANI identity.
fn main() {
let containment = 0.1;
let ksize = 31;
let scaled = 1000; // u64
let n_unique_kmers = 1000;
let sequence_len_bp = n_unique_kmers * scaled;
let confidence = Some(0.95);
let estimate_ci = Some(true);
let prob_threshold = Some(1e-3);
let result = containment_to_distance(
containment,
ksize,
scaled,
Some(n_unique_kmers),
Some(sequence_len_bp),
confidence,
estimate_ci,
prob_threshold,
);
match result {
Ok(ci_result) => {
let ani_result = 1.0 - ci_result.point_estimate;
println!("ANI: {}", ani_result);
}
Err(e) => println!("Error occurred: {:?}", e),
}
}
*/

0 comments on commit 3b1c1b8

Please sign in to comment.