From eb53f7c01d4cadbce3210b92ca692ee003a1f67d Mon Sep 17 00:00:00 2001 From: "N. Tessa Pierce-Ward" Date: Sat, 24 Feb 2024 09:45:22 -0800 Subject: [PATCH] split ANI addition off from add-cluster --- src/multisearch.rs | 16 +- src/pairwise.rs | 12 ++ src/python/tests/test_multisearch.py | 89 +++++++++- src/python/tests/test_pairwise.py | 248 +++++++++++++-------------- src/utils.rs | 8 +- 5 files changed, 237 insertions(+), 136 deletions(-) diff --git a/src/multisearch.rs b/src/multisearch.rs index c4f33843..23211f48 100644 --- a/src/multisearch.rs +++ b/src/multisearch.rs @@ -9,6 +9,7 @@ use std::sync::atomic::AtomicUsize; use crate::utils::{ csvwriter_thread, load_collection, load_sketches, MultiSearchResult, ReportType, }; +use sourmash::ani_utils::ani_from_containment; /// Search many queries against a list of signatures. /// @@ -56,6 +57,7 @@ pub fn multisearch( // let processed_cmp = AtomicUsize::new(0); + let ksize = selection.ksize().unwrap() as f64; let send = against .par_iter() @@ -74,10 +76,16 @@ pub fn multisearch( let target_size = against.minhash.size() as f64; let containment_query_in_target = overlap / query_size; - let containment_in_target = overlap / target_size; - let max_containment = containment_query_in_target.max(containment_in_target); + let containment_target_in_query = overlap / target_size; + let max_containment = containment_query_in_target.max(containment_target_in_query); let jaccard = overlap / (target_size + query_size - overlap); + // estimate ANI values + let query_ani = ani_from_containment(containment_query_in_target, ksize) * 100.0; + let match_ani = ani_from_containment(containment_target_in_query, ksize) * 100.0; + let average_containment_ani = (query_ani + match_ani) / 2.; + let max_containment_ani = f64::max(query_ani, match_ani); + if containment_query_in_target > threshold { results.push(MultiSearchResult { query_name: query.name.clone(), @@ -88,6 +96,10 @@ pub fn multisearch( max_containment, jaccard, intersect_hashes: overlap, + query_ani, + match_ani, + average_containment_ani, + max_containment_ani, }) } } diff --git a/src/pairwise.rs b/src/pairwise.rs index aca9f797..5dd64b4b 100644 --- a/src/pairwise.rs +++ b/src/pairwise.rs @@ -7,6 +7,7 @@ use std::sync::atomic::AtomicUsize; use crate::utils::{ csvwriter_thread, load_collection, load_sketches, MultiSearchResult, ReportType, }; +use sourmash::ani_utils::ani_from_containment; use sourmash::selection::Selection; use sourmash::signature::SigsTrait; @@ -49,6 +50,7 @@ pub fn pairwise( // Results written to the writer thread above. let processed_cmp = AtomicUsize::new(0); + let ksize = selection.ksize().unwrap() as f64; sketches.par_iter().enumerate().for_each(|(idx, query)| { for against in sketches.iter().skip(idx + 1) { @@ -61,6 +63,12 @@ pub fn pairwise( let max_containment = containment_q1_in_q2.max(containment_q2_in_q1); let jaccard = overlap / (query1_size + query2_size - overlap); + // estimate ANI values + let query_ani = ani_from_containment(containment_q1_in_q2, ksize) * 100.0; + let match_ani = ani_from_containment(containment_q2_in_q1, ksize) * 100.0; + let average_containment_ani = (query_ani + match_ani) / 2.; + let max_containment_ani = f64::max(query_ani, match_ani); + if containment_q1_in_q2 > threshold || containment_q2_in_q1 > threshold { send.send(MultiSearchResult { query_name: query.name.clone(), @@ -71,6 +79,10 @@ pub fn pairwise( max_containment, jaccard, intersect_hashes: overlap, + query_ani, + match_ani, + average_containment_ani, + max_containment_ani, }) .unwrap(); } diff --git a/src/python/tests/test_multisearch.py b/src/python/tests/test_multisearch.py index 4cb1fd8a..694c2335 100644 --- a/src/python/tests/test_multisearch.py +++ b/src/python/tests/test_multisearch.py @@ -66,6 +66,10 @@ def test_simple(runtmp, zip_query, zip_db): assert float(row['containment'] == 1.0) assert float(row['jaccard'] == 1.0) assert float(row['max_containment'] == 1.0) + assert float(row['query_ani'] == 100.0) + assert float(row['match_ani'] == 100.0) + assert float(row['average_containment_ani'] == 100.0) + assert float(row['max_containment_ani'] == 100.0) else: # confirm hand-checked numbers @@ -75,23 +79,40 @@ def test_simple(runtmp, zip_query, zip_db): jaccard = float(row['jaccard']) maxcont = float(row['max_containment']) intersect_hashes = int(row['intersect_hashes']) + q1_ani = float(row['query_ani']) + q2_ani = float(row['match_ani']) + avg_ani = float(row['average_containment_ani']) + max_ani = float(row['max_containment_ani']) + jaccard = round(jaccard, 4) cont = round(cont, 4) maxcont = round(maxcont, 4) - print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}") + q1_ani = round(q1_ani, 2) + q2_ani = round(q2_ani, 2) + avg_ani = round(avg_ani, 2) + max_ani = round(max_ani, 2) + print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", f"{q1_ani:.04}", f"{q2_ani:.04}", f"{avg_ani:.04}", f"{max_ani:.04}") if q == 'NC_011665.1' and m == 'NC_009661.1': assert jaccard == 0.3207 assert cont == 0.4828 assert maxcont == 0.4885 assert intersect_hashes == 2529 + assert q1_ani == 97.68 + assert q2_ani == 97.72 + assert avg_ani == 97.7 + assert max_ani == 97.72 if q == 'NC_009661.1' and m == 'NC_011665.1': assert jaccard == 0.3207 assert cont == 0.4885 assert maxcont == 0.4885 assert intersect_hashes == 2529 + assert q1_ani == 97.72 + assert q2_ani == 97.68 + assert avg_ani == 97.7 + assert max_ani == 97.72 @pytest.mark.parametrize("zip_query", [False, True]) @@ -512,6 +533,10 @@ def test_simple_prot(runtmp): assert float(row['containment'] == 1.0) assert float(row['jaccard'] == 1.0) assert float(row['max_containment'] == 1.0) + assert float(row['query_ani'] == 100.0) + assert float(row['match_ani'] == 100.0) + assert float(row['average_containment_ani'] == 100.0) + assert float(row['max_containment_ani'] == 100.0) else: # confirm hand-checked numbers @@ -521,23 +546,39 @@ def test_simple_prot(runtmp): jaccard = float(row['jaccard']) maxcont = float(row['max_containment']) intersect_hashes = int(row['intersect_hashes']) + q1_ani = float(row['query_ani']) + q2_ani = float(row['match_ani']) + avg_ani = float(row['average_containment_ani']) + max_ani = float(row['max_containment_ani']) jaccard = round(jaccard, 4) cont = round(cont, 4) maxcont = round(maxcont, 4) - print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes) + q1_ani = round(q1_ani, 2) + q2_ani = round(q2_ani, 2) + avg_ani = round(avg_ani, 2) + max_ani = round(max_ani, 2) + print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes, f"{q1_ani:.04}", f"{q2_ani:.04}", f"{avg_ani:.04}", f"{max_ani:.04}") if q == 'GCA_001593925' and m == 'GCA_001593935': assert jaccard == 0.0434 assert cont == 0.1003 assert maxcont == 0.1003 assert intersect_hashes == 342 + assert q1_ani == 88.6 + assert q2_ani == 87.02 + assert avg_ani == 87.81 + assert max_ani == 88.6 if q == 'GCA_001593935' and m == 'GCA_001593925': assert jaccard == 0.0434 assert cont == 0.0712 assert maxcont == 0.1003 assert intersect_hashes == 342 + assert q1_ani == 87.02 + assert q2_ani == 88.6 + assert avg_ani == 87.81 + assert max_ani == 88.6 def test_simple_dayhoff(runtmp): @@ -564,6 +605,10 @@ def test_simple_dayhoff(runtmp): assert float(row['containment'] == 1.0) assert float(row['jaccard'] == 1.0) assert float(row['max_containment'] == 1.0) + assert float(row['query_ani'] == 100.0) + assert float(row['match_ani'] == 100.0) + assert float(row['average_containment_ani'] == 100.0) + assert float(row['max_containment_ani'] == 100.0) else: # confirm hand-checked numbers @@ -573,23 +618,39 @@ def test_simple_dayhoff(runtmp): jaccard = float(row['jaccard']) maxcont = float(row['max_containment']) intersect_hashes = int(row['intersect_hashes']) + q1_ani = float(row['query_ani']) + q2_ani = float(row['match_ani']) + avg_ani = float(row['average_containment_ani']) + max_ani = float(row['max_containment_ani']) jaccard = round(jaccard, 4) cont = round(cont, 4) maxcont = round(maxcont, 4) - print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes) + q1_ani = round(q1_ani, 2) + q2_ani = round(q2_ani, 2) + avg_ani = round(avg_ani, 2) + max_ani = round(max_ani, 2) + print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes, f"{q1_ani:.04}", f"{q2_ani:.04}", f"{avg_ani:.04}", f"{max_ani:.04}") if q == 'GCA_001593925' and m == 'GCA_001593935': assert jaccard == 0.1326 assert cont == 0.2815 assert maxcont == 0.2815 assert intersect_hashes == 930 + assert q1_ani == 93.55 + assert q2_ani == 91.89 + assert avg_ani == 92.72 + assert max_ani == 93.55 if q == 'GCA_001593935' and m == 'GCA_001593925': assert jaccard == 0.1326 assert cont == 0.2004 assert maxcont == 0.2815 assert intersect_hashes == 930 + assert q1_ani == 91.89 + assert q2_ani == 93.55 + assert avg_ani == 92.72 + assert max_ani == 93.55 def test_simple_hp(runtmp): @@ -616,6 +677,10 @@ def test_simple_hp(runtmp): assert float(row['containment'] == 1.0) assert float(row['jaccard'] == 1.0) assert float(row['max_containment'] == 1.0) + assert float(row['query_ani'] == 100.0) + assert float(row['match_ani'] == 100.0) + assert float(row['average_containment_ani'] == 100.0) + assert float(row['max_containment_ani'] == 100.0) else: # confirm hand-checked numbers @@ -625,20 +690,36 @@ def test_simple_hp(runtmp): jaccard = float(row['jaccard']) maxcont = float(row['max_containment']) intersect_hashes = int(row['intersect_hashes']) + q1_ani = float(row['query_ani']) + q2_ani = float(row['match_ani']) + avg_ani = float(row['average_containment_ani']) + max_ani = float(row['max_containment_ani']) jaccard = round(jaccard, 4) cont = round(cont, 4) maxcont = round(maxcont, 4) - print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes) + q1_ani = round(q1_ani, 2) + q2_ani = round(q2_ani, 2) + avg_ani = round(avg_ani, 2) + max_ani = round(max_ani, 2) + print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes, f"{q1_ani:.04}", f"{q2_ani:.04}", f"{avg_ani:.04}", f"{max_ani:.04}") if q == 'GCA_001593925' and m == 'GCA_001593935': assert jaccard == 0.4983 assert cont == 0.747 assert maxcont == 0.747 assert intersect_hashes == 1724 + assert q1_ani == 98.48 + assert q2_ani == 97.34 + assert avg_ani == 97.91 + assert max_ani == 98.48 if q == 'GCA_001593935' and m == 'GCA_001593925': assert jaccard == 0.4983 assert cont == 0.5994 assert maxcont == 0.747 assert intersect_hashes == 1724 + assert q1_ani == 97.34 + assert q2_ani == 98.48 + assert avg_ani == 97.91 + assert max_ani == 98.48 diff --git a/src/python/tests/test_pairwise.py b/src/python/tests/test_pairwise.py index a1991c64..5fd4ec88 100644 --- a/src/python/tests/test_pairwise.py +++ b/src/python/tests/test_pairwise.py @@ -55,38 +55,36 @@ def test_simple(runtmp, zip_query): print(dd) for idx, row in dd.items(): - # identical? - if row['match_name'] == row['query_name']: - assert row['query_md5'] == row['match_md5'], row - assert float(row['containment'] == 1.0) - assert float(row['jaccard'] == 1.0) - assert float(row['max_containment'] == 1.0) - - else: - # confirm hand-checked numbers - q = row['query_name'].split()[0] - m = row['match_name'].split()[0] - cont = float(row['containment']) - jaccard = float(row['jaccard']) - maxcont = float(row['max_containment']) - intersect_hashes = int(row['intersect_hashes']) - - jaccard = round(jaccard, 4) - cont = round(cont, 4) - maxcont = round(maxcont, 4) - print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}") - - if q == 'NC_011665.1' and m == 'NC_009661.1': - assert jaccard == 0.3207 - assert cont == 0.4828 - assert maxcont == 0.4885 - assert intersect_hashes == 2529 - - if q == 'NC_009661.1' and m == 'NC_011665.1': - assert jaccard == 0.3207 - assert cont == 0.4885 - assert maxcont == 0.4885 - assert intersect_hashes == 2529 + # confirm hand-checked numbers + q = row['query_name'].split()[0] + m = row['match_name'].split()[0] + cont = float(row['containment']) + jaccard = float(row['jaccard']) + maxcont = float(row['max_containment']) + intersect_hashes = int(row['intersect_hashes']) + q1_ani = float(row['query_ani']) + q2_ani = float(row['match_ani']) + avg_ani = float(row['average_containment_ani']) + max_ani = float(row['max_containment_ani']) + + jaccard = round(jaccard, 4) + cont = round(cont, 4) + maxcont = round(maxcont, 4) + q1_ani = round(q1_ani, 2) + q2_ani = round(q2_ani, 2) + avg_ani = round(avg_ani, 2) + max_ani = round(max_ani, 2) + print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", f"{q1_ani:.04}", f"{q2_ani:.04}", f"{avg_ani:.04}", f"{max_ani:.04}") + + if q == 'NC_011665.1' and m == 'NC_009661.1': + assert jaccard == 0.3207 + assert cont == 0.4828 + assert maxcont == 0.4885 + assert intersect_hashes == 2529 + assert q1_ani == 97.68 + assert q2_ani == 97.72 + assert avg_ani == 97.7 + assert max_ani == 97.72 @pytest.mark.parametrize("zip_query", [False, True]) @@ -349,38 +347,36 @@ def test_simple_prot(runtmp): print(dd) for idx, row in dd.items(): - # identical? - if row['match_name'] == row['query_name']: - assert row['query_md5'] == row['match_md5'], row - assert float(row['containment'] == 1.0) - assert float(row['jaccard'] == 1.0) - assert float(row['max_containment'] == 1.0) - - else: - # confirm hand-checked numbers - q = row['query_name'].split()[0] - m = row['match_name'].split()[0] - cont = float(row['containment']) - jaccard = float(row['jaccard']) - maxcont = float(row['max_containment']) - intersect_hashes = int(row['intersect_hashes']) - - jaccard = round(jaccard, 4) - cont = round(cont, 4) - maxcont = round(maxcont, 4) - print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes) - - if q == 'GCA_001593925' and m == 'GCA_001593935': - assert jaccard == 0.0434 - assert cont == 0.1003 - assert maxcont == 0.1003 - assert intersect_hashes == 342 - - if q == 'GCA_001593935' and m == 'GCA_001593925': - assert jaccard == 0.0434 - assert cont == 0.0712 - assert maxcont == 0.1003 - assert intersect_hashes == 342 + # confirm hand-checked numbers + q = row['query_name'].split()[0] + m = row['match_name'].split()[0] + cont = float(row['containment']) + jaccard = float(row['jaccard']) + maxcont = float(row['max_containment']) + intersect_hashes = int(row['intersect_hashes']) + q1_ani = float(row['query_ani']) + q2_ani = float(row['match_ani']) + avg_ani = float(row['average_containment_ani']) + max_ani = float(row['max_containment_ani']) + + jaccard = round(jaccard, 4) + cont = round(cont, 4) + maxcont = round(maxcont, 4) + q1_ani = round(q1_ani, 2) + q2_ani = round(q2_ani, 2) + avg_ani = round(avg_ani, 2) + max_ani = round(max_ani, 2) + print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes, f"{q1_ani:.04}", f"{q2_ani:.04}", f"{avg_ani:.04}", f"{max_ani:.04}") + + if q == 'GCA_001593925' and m == 'GCA_001593935': + assert jaccard == 0.0434 + assert cont == 0.1003 + assert maxcont == 0.1003 + assert intersect_hashes == 342 + assert q1_ani == 88.60 + assert q2_ani == 87.02 + assert avg_ani == 87.81 + assert max_ani == 88.6 def test_simple_dayhoff(runtmp): @@ -401,38 +397,36 @@ def test_simple_dayhoff(runtmp): print(dd) for idx, row in dd.items(): - # identical? - if row['match_name'] == row['query_name']: - assert row['query_md5'] == row['match_md5'], row - assert float(row['containment'] == 1.0) - assert float(row['jaccard'] == 1.0) - assert float(row['max_containment'] == 1.0) - - else: - # confirm hand-checked numbers - q = row['query_name'].split()[0] - m = row['match_name'].split()[0] - cont = float(row['containment']) - jaccard = float(row['jaccard']) - maxcont = float(row['max_containment']) - intersect_hashes = int(row['intersect_hashes']) - - jaccard = round(jaccard, 4) - cont = round(cont, 4) - maxcont = round(maxcont, 4) - print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes) - - if q == 'GCA_001593925' and m == 'GCA_001593935': - assert jaccard == 0.1326 - assert cont == 0.2815 - assert maxcont == 0.2815 - assert intersect_hashes == 930 - - if q == 'GCA_001593935' and m == 'GCA_001593925': - assert jaccard == 0.1326 - assert cont == 0.2004 - assert maxcont == 0.2815 - assert intersect_hashes == 930 + # confirm hand-checked numbers + q = row['query_name'].split()[0] + m = row['match_name'].split()[0] + cont = float(row['containment']) + jaccard = float(row['jaccard']) + maxcont = float(row['max_containment']) + intersect_hashes = int(row['intersect_hashes']) + q1_ani = float(row['query_ani']) + q2_ani = float(row['match_ani']) + avg_ani = float(row['average_containment_ani']) + max_ani = float(row['max_containment_ani']) + + jaccard = round(jaccard, 4) + cont = round(cont, 4) + maxcont = round(maxcont, 4) + q1_ani = round(q1_ani, 2) + q2_ani = round(q2_ani, 2) + avg_ani = round(avg_ani, 2) + max_ani = round(max_ani, 2) + print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes, f"{q1_ani:.04}", f"{q2_ani:.04}", f"{avg_ani:.04}", f"{max_ani:.04}") + + if q == 'GCA_001593925' and m == 'GCA_001593935': + assert jaccard == 0.1326 + assert cont == 0.2815 + assert maxcont == 0.2815 + assert intersect_hashes == 930 + assert q1_ani == 93.55 + assert q2_ani == 91.89 + assert avg_ani == 92.72 + assert max_ani == 93.55 def test_simple_hp(runtmp): @@ -453,35 +447,33 @@ def test_simple_hp(runtmp): print(dd) for idx, row in dd.items(): - # identical? - if row['match_name'] == row['query_name']: - assert row['query_md5'] == row['match_md5'], row - assert float(row['containment'] == 1.0) - assert float(row['jaccard'] == 1.0) - assert float(row['max_containment'] == 1.0) - - else: - # confirm hand-checked numbers - q = row['query_name'].split()[0] - m = row['match_name'].split()[0] - cont = float(row['containment']) - jaccard = float(row['jaccard']) - maxcont = float(row['max_containment']) - intersect_hashes = int(row['intersect_hashes']) - - jaccard = round(jaccard, 4) - cont = round(cont, 4) - maxcont = round(maxcont, 4) - print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes) - - if q == 'GCA_001593925' and m == 'GCA_001593935': - assert jaccard == 0.4983 - assert cont == 0.747 - assert maxcont == 0.747 - assert intersect_hashes == 1724 - - if q == 'GCA_001593935' and m == 'GCA_001593925': - assert jaccard == 0.4983 - assert cont == 0.5994 - assert maxcont == 0.747 - assert intersect_hashes == 1724 + # confirm hand-checked numbers + q = row['query_name'].split()[0] + m = row['match_name'].split()[0] + cont = float(row['containment']) + jaccard = float(row['jaccard']) + maxcont = float(row['max_containment']) + intersect_hashes = int(row['intersect_hashes']) + q1_ani = float(row['query_ani']) + q2_ani = float(row['match_ani']) + avg_ani = float(row['average_containment_ani']) + max_ani = float(row['max_containment_ani']) + + jaccard = round(jaccard, 4) + cont = round(cont, 4) + maxcont = round(maxcont, 4) + q1_ani = round(q1_ani, 2) + q2_ani = round(q2_ani, 2) + avg_ani = round(avg_ani, 2) + max_ani = round(max_ani, 2) + print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes, f"{q1_ani:.04}", f"{q2_ani:.04}", f"{avg_ani:.04}", f"{max_ani:.04}") + + if q == 'GCA_001593925' and m == 'GCA_001593935': + assert jaccard == 0.4983 + assert cont == 0.747 + assert maxcont == 0.747 + assert intersect_hashes == 1724 + assert q1_ani == 98.48 + assert q2_ani == 97.34 + assert avg_ani == 97.91 + assert max_ani == 98.48 diff --git a/src/utils.rs b/src/utils.rs index e0d01b71..7b10377a 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -8,7 +8,7 @@ use camino::Utf8Path as Path; use camino::Utf8PathBuf as PathBuf; use csv::Writer; use serde::ser::Serializer; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use std::cmp::{Ordering, PartialOrd}; use std::collections::BinaryHeap; use std::fs::{create_dir_all, File}; @@ -716,7 +716,7 @@ pub struct BranchwaterGatherResult { pub intersect_bp: usize, } -#[derive(Serialize)] +#[derive(Serialize, Deserialize)] pub struct MultiSearchResult { pub query_name: String, pub query_md5: String, @@ -726,6 +726,10 @@ pub struct MultiSearchResult { pub max_containment: f64, pub jaccard: f64, pub intersect_hashes: f64, + pub query_ani: f64, + pub match_ani: f64, + pub average_containment_ani: f64, + pub max_containment_ani: f64, } #[derive(Serialize)]