diff --git a/pyproject.toml b/pyproject.toml index c523e919c9..2583debf52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ authors = [ { name="Colton Baumler", orcid="0000-0002-5926-7792" }, { name="Olga Botvinnik", orcid="0000-0003-4412-7970" }, { name="Phillip Brooks", orcid="0000-0003-3987-244X" }, + { name="Luca Cappelletti", orcid="0000-0002-1269-2038" }, { name="Peter Cock", orcid="0000-0001-9513-9993" }, { name="Daniel Dsouza", orcid="0000-0001-7843-8596" }, { name="Jade Gardner", orcid="0009-0005-0787-5752" }, diff --git a/src/core/src/sketch/hyperloglog/estimators.rs b/src/core/src/sketch/hyperloglog/estimators.rs index 9a2d7994ef..4f2a01449c 100644 --- a/src/core/src/sketch/hyperloglog/estimators.rs +++ b/src/core/src/sketch/hyperloglog/estimators.rs @@ -1,38 +1,91 @@ -use std::cmp; +use core::{ + cmp, + ops::{Add, AddAssign, Shl, Sub, SubAssign}, +}; pub type CounterType = u8; -pub fn counts(registers: &[CounterType], q: usize) -> Vec { - let mut counts = vec![0; q + 2]; +/// Trait for types that can be used as multiplicity integers. +pub trait MultiplicityInteger: + Shl + + Copy + + AddAssign + + SubAssign + + Eq + + Sub + + Add + + TryFrom + + Ord +{ + /// The zero value. + const ZERO: Self; + /// The one value. + const ONE: Self; + + /// Convert the value to a `f64`. + fn to_f64(self) -> f64; +} + +macro_rules! impl_multiplicity_integer { + ($($t:ty),*) => { + $( + impl MultiplicityInteger for $t { + const ONE: Self = 1; + const ZERO: Self = 0; + + fn to_f64(self) -> f64 { + self as f64 + } + } + )* + }; +} + +impl_multiplicity_integer!(u8, u16, u32); + +pub fn counts(registers: &[CounterType], q: usize) -> Vec { + let mut counts = vec![M::ZERO; q + 2]; for k in registers { - counts[*k as usize] += 1; + counts[*k as usize] += M::ONE; } counts } #[allow(clippy::many_single_char_names)] -pub fn mle(counts: &[u16], p: usize, q: usize, relerr: f64) -> f64 { - let m = 1 << p; +pub fn mle(counts: &[M], p: usize, q: usize, relerr: f64) -> f64 { + let m: M = M::ONE << p; + + // If all of the registers are equal to zero, then we return zero. + if counts[0] == m { + return 0.0; + } + + // If all of the registers are equal to the maximal possible value + // that a register may have, then we return infinity. if counts[q + 1] == m { return f64::INFINITY; } - let (k_min, _) = counts.iter().enumerate().find(|(_, v)| **v != 0).unwrap(); + let (k_min, _) = counts + .iter() + .enumerate() + .find(|(_, v)| **v != M::ZERO) + .unwrap(); let k_min_prime = cmp::max(1, k_min); let (k_max, _) = counts .iter() .enumerate() .rev() - .find(|(_, v)| **v != 0) + .find(|(_, v)| **v != M::ZERO) .unwrap(); let k_max_prime = cmp::min(q, k_max); let mut z = 0.; for i in num_iter::range_step_inclusive(k_max_prime as i32, k_min_prime as i32, -1) { - z = 0.5 * z + counts[i as usize] as f64; + z = 0.5 * z + counts[i as usize].to_f64(); } // ldexp(x, i) = x * (2 ** i) @@ -44,9 +97,9 @@ pub fn mle(counts: &[u16], p: usize, q: usize, relerr: f64) -> f64 { } let mut g_prev = 0.; - let a = z + (counts[0] as f64); - let b = z + (counts[q + 1] as f64) * 2f64.powi(-(q as i32)); - let m_prime = (m - counts[0]) as f64; + let a = z + (counts[0].to_f64()); + let b = z + (counts[q + 1].to_f64()) * 2f64.powi(-(q as i32)); + let m_prime = (m - counts[0]).to_f64(); let mut x = if b <= 1.5 * a { // weak lower bound (47) @@ -57,7 +110,7 @@ pub fn mle(counts: &[u16], p: usize, q: usize, relerr: f64) -> f64 { }; let mut delta_x = x; - let del = relerr / (m as f64).sqrt(); + let del = relerr / m.to_f64().sqrt(); while delta_x > x * del { // secant method iteration @@ -78,13 +131,13 @@ pub fn mle(counts: &[u16], p: usize, q: usize, relerr: f64) -> f64 { } // compare (53) - let mut g = c_prime as f64 * h; + let mut g = c_prime.to_f64() * h; for k in num_iter::range_step_inclusive(k_max_prime as i32 - 1, k_min_prime as i32, -1) { let h_prime = 1. - h; // Calculate h(x/2^k), see (56), at this point x_prime = x / (2^(k+2)) h = (x_prime + h * h_prime) / (x_prime + h_prime); - g += counts[k as usize] as f64 * h; + g += counts[k as usize].to_f64() * h; x_prime += x_prime; } @@ -100,7 +153,7 @@ pub fn mle(counts: &[u16], p: usize, q: usize, relerr: f64) -> f64 { g_prev = g } - m as f64 * x + m.to_f64() * x } /// Calculate the joint maximum likelihood of A and B. @@ -112,35 +165,57 @@ pub fn joint_mle( p: usize, q: usize, ) -> (usize, usize, usize) { - let mut c1 = vec![0; q + 2]; - let mut c2 = vec![0; q + 2]; - let mut cu = vec![0; q + 2]; - let mut cg1 = vec![0; q + 2]; - let mut cg2 = vec![0; q + 2]; - let mut ceq = vec![0; q + 2]; + if p < 8 { + joint_mle_dispatch::(k1, k2, p, q) + } else if p < 16 { + joint_mle_dispatch::(k1, k2, p, q) + } else { + assert!(p == 16 || p == 17 || p == 18); + joint_mle_dispatch::(k1, k2, p, q) + } +} + +/// Calculate the joint maximum likelihood of A and B. +/// +/// Returns a tuple (only in A, only in B, intersection) +fn joint_mle_dispatch( + k1: &[CounterType], + k2: &[CounterType], + p: usize, + q: usize, +) -> (usize, usize, usize) +where + >::Error: std::fmt::Debug, +{ + let mut c1 = vec![M::ZERO; q + 2]; + let mut c2 = vec![M::ZERO; q + 2]; + let mut cu = vec![M::ZERO; q + 2]; + let mut cg1 = vec![M::ZERO; q + 2]; + let mut cg2 = vec![M::ZERO; q + 2]; + let mut ceq = vec![M::ZERO; q + 2]; for (k1_, k2_) in k1.iter().zip(k2.iter()) { match k1_.cmp(k2_) { cmp::Ordering::Less => { - c1[*k1_ as usize] += 1; - cg2[*k2_ as usize] += 1; + c1[*k1_ as usize] += M::ONE; + cg2[*k2_ as usize] += M::ONE; } cmp::Ordering::Greater => { - cg1[*k1_ as usize] += 1; - c2[*k2_ as usize] += 1; + cg1[*k1_ as usize] += M::ONE; + c2[*k2_ as usize] += M::ONE; } cmp::Ordering::Equal => { - ceq[*k1_ as usize] += 1; + ceq[*k1_ as usize] += M::ONE; } } - cu[*cmp::max(k1_, k2_) as usize] += 1; + cu[*cmp::max(k1_, k2_) as usize] += M::ONE; } - for (i, (v, u)) in cg1.iter().zip(ceq.iter()).enumerate() { + for (i, (&v, &u)) in cg1.iter().zip(ceq.iter()).enumerate() { c1[i] += v + u; } - for (i, (v, u)) in cg2.iter().zip(ceq.iter()).enumerate() { + for (i, (&v, &u)) in cg2.iter().zip(ceq.iter()).enumerate() { c2[i] += v + u; } @@ -148,20 +223,22 @@ pub fn joint_mle( let c_bx = mle(&c2, p, q, 0.01); let c_abx = mle(&cu, p, q, 0.01); - let mut counts_axb_half = vec![0u16; q + 2]; - let mut counts_bxa_half = vec![0u16; q + 2]; + let mut counts_axb_half = vec![M::ZERO; q + 2]; + let mut counts_bxa_half = vec![M::ZERO; q + 2]; - counts_axb_half[q] = k1.len() as u16; - counts_bxa_half[q] = k2.len() as u16; + counts_axb_half[q] = M::try_from(k1.len()).unwrap(); + counts_bxa_half[q] = M::try_from(k2.len()).unwrap(); for _q in 0..q { counts_axb_half[_q] = cg1[_q] + ceq[_q] + cg2[_q + 1]; debug_assert!(counts_axb_half[q] >= counts_axb_half[_q]); - counts_axb_half[q] -= counts_axb_half[_q]; + let multiplicity_q = counts_axb_half[_q]; + counts_axb_half[q] -= multiplicity_q; counts_bxa_half[_q] = cg2[_q] + ceq[_q] + cg1[_q + 1]; debug_assert!(counts_bxa_half[q] >= counts_bxa_half[_q]); - counts_bxa_half[q] -= counts_bxa_half[_q]; + let multiplicity_q = counts_bxa_half[_q]; + counts_bxa_half[q] -= multiplicity_q; } let c_axb_half = mle(&counts_axb_half, p, q - 1, 0.01); diff --git a/src/core/src/sketch/hyperloglog/mod.rs b/src/core/src/sketch/hyperloglog/mod.rs index ee09caa6e5..ab26a78adf 100644 --- a/src/core/src/sketch/hyperloglog/mod.rs +++ b/src/core/src/sketch/hyperloglog/mod.rs @@ -81,9 +81,36 @@ impl HyperLogLog { } pub fn cardinality(&self) -> usize { - let counts = estimators::counts(&self.registers, self.q); + if self.p < 8 { + estimators::mle( + &estimators::counts::(&self.registers, self.q), + self.p, + self.q, + 0.01, + ) as usize + } else if self.p < 16 { + estimators::mle( + &estimators::counts::(&self.registers, self.q), + self.p, + self.q, + 0.05, + ) as usize + } else { + assert!(self.p == 16 || self.p == 17 || self.p == 18); + estimators::mle( + &estimators::counts::(&self.registers, self.q), + self.p, + self.q, + 0.1, + ) as usize + } + } - estimators::mle(&counts, self.p, self.q, 0.01) as usize + pub fn union(&self, other: &HyperLogLog) -> usize { + let (only_a, only_b, intersection) = + estimators::joint_mle(&self.registers, &other.registers, self.p, self.q); + + only_a + only_b + intersection } pub fn similarity(&self, other: &HyperLogLog) -> f64 { @@ -224,6 +251,8 @@ impl Update for KmerMinHash { #[cfg(test)] mod test { use std::collections::HashSet; + use std::hash::Hasher; + use std::hash::{DefaultHasher, Hash}; use std::io::{BufReader, BufWriter, Read}; use std::path::PathBuf; @@ -272,13 +301,12 @@ mod test { const N_UNIQUE_H1: usize = 500741; const N_UNIQUE_H2: usize = 995845; const N_UNIQUE_U: usize = 995845; + const INTERSECTION: usize = 500838; const SIMILARITY: f64 = 0.502783; const CONTAINMENT_H1: f64 = 1.; const CONTAINMENT_H2: f64 = 0.502783; - const INTERSECTION: usize = 500838; - let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); filename.push("../../tests/test-data/genome-s10.fa.gz"); @@ -321,6 +349,9 @@ mod test { let abs_error = (1. - (hll2.cardinality() as f64 / N_UNIQUE_H2 as f64)).abs(); assert!(abs_error < ERR_RATE, "{}", abs_error); + let abs_error = (1. - (hll1.union(&hll2) as f64 / N_UNIQUE_U as f64)).abs(); + assert!(abs_error < ERR_RATE, "{}", abs_error); + let similarity = hll1.similarity(&hll2); let abs_error = (1. - (similarity / SIMILARITY)).abs(); assert!(abs_error < ERR_RATE, "{} {}", similarity, SIMILARITY); @@ -374,4 +405,53 @@ mod test { assert_eq!(hll_new.registers, hll.registers); assert_eq!(hll_new.ksize, hll.ksize); } + + #[test] + /// Test to cover corner cases in the MLE calculation + /// that may happen at resolutions 16, 17 or 18, i.e. + /// cases with 2^16 == 65536, 2^17 == 131072, 2^18 == 262144. + /// + /// In such cases, the MLE multiplicities which were earlier + /// implemented always using a u16 type, may overflow. + fn test_mle_corner_cases() { + for precision in [16, 17, 18] { + let mut hll = HyperLogLog::new(precision, 21).unwrap(); + for i in 1..5000 { + let mut hasher = DefaultHasher::new(); + i.hash(&mut hasher); + let hash = hasher.finish(); + hll.add_hash(hash) + } + + let cardinality = hll.cardinality(); + + assert!(cardinality > 4500 && cardinality < 5500); + + // We build a second hll to check whether the union of the two + // hlls is consistent with the cardinality of the union. + let mut hll2 = HyperLogLog::new(precision, 21).unwrap(); + + for i in 5000..10000 { + let mut hasher = DefaultHasher::new(); + i.hash(&mut hasher); + let hash = hasher.finish(); + hll2.add_hash(hash) + } + + let mut hll_union = hll.clone(); + hll_union.merge(&hll2).unwrap(); + let cardinality_union = hll_union.cardinality(); + + assert!( + cardinality_union > 9500 && cardinality_union < 10500, + "precision: {}, cardinality_union: {}", + precision, + cardinality_union + ); + + let intersection = hll.intersection(&hll2); + + assert!(intersection < 500); + } + } }