diff --git a/src/manysearch.rs b/src/manysearch.rs index d7ff7808..6368ef17 100644 --- a/src/manysearch.rs +++ b/src/manysearch.rs @@ -9,6 +9,7 @@ use std::sync::atomic; use std::sync::atomic::AtomicUsize; use crate::utils::{csvwriter_thread, load_collection, load_sketches, ReportType, SearchResult}; +use sourmash::ani_utils::ani_from_containment; use sourmash::selection::Selection; use sourmash::signature::SigsTrait; @@ -70,15 +71,28 @@ pub fn manysearch( if let Some(against_mh) = against_sig.minhash() { for query in query_sketchlist.iter() { let overlap = - query.minhash.count_common(against_mh, false).unwrap() as f64; + query.minhash.count_common(against_mh, true).unwrap() as f64; let query_size = query.minhash.size() as f64; let target_size = against_mh.size() as f64; let containment_query_in_target = overlap / query_size; - let containment_in_target = overlap / target_size; + let containment_target_in_query = overlap / target_size; let max_containment = - containment_query_in_target.max(containment_in_target); + containment_query_in_target.max(containment_target_in_query); let jaccard = overlap / (target_size + query_size - overlap); + let qani = ani_from_containment( + containment_query_in_target, + against_mh.ksize() as f64, + ); + let mani = ani_from_containment( + containment_target_in_query, + against_mh.ksize() as f64, + ); + let query_containment_ani = Some(qani); + let match_containment_ani = Some(mani); + let average_containment_ani = Some((qani + mani) / 2.); + let max_containment_ani = Some(f64::max(qani, mani)); + if containment_query_in_target > threshold { results.push(SearchResult { query_name: query.name.clone(), @@ -89,6 +103,10 @@ pub fn manysearch( match_md5: Some(against_sig.md5sum()), jaccard: Some(jaccard), max_containment: Some(max_containment), + query_containment_ani, + match_containment_ani, + average_containment_ani, + max_containment_ani, }); } } diff --git a/src/mastiff_manysearch.rs b/src/mastiff_manysearch.rs index 4f4be0c5..764bf76d 100644 --- a/src/mastiff_manysearch.rs +++ b/src/mastiff_manysearch.rs @@ -2,11 +2,13 @@ use anyhow::Result; use camino::Utf8PathBuf as PathBuf; use rayon::prelude::*; +use std::sync::atomic; +use std::sync::atomic::AtomicUsize; + +use sourmash::ani_utils::ani_from_containment; use sourmash::index::revindex::{RevIndex, RevIndexOps}; use sourmash::selection::Selection; use sourmash::signature::SigsTrait; -use std::sync::atomic; -use std::sync::atomic::AtomicUsize; use crate::utils::{ csvwriter_thread, is_revindex_database, load_collection, ReportType, SearchResult, @@ -74,6 +76,11 @@ pub fn mastiff_manysearch( for (path, overlap) in matches { let containment = overlap as f64 / query_size as f64; if containment >= minimum_containment { + let query_containment_ani = Some(ani_from_containment( + containment, + query_mh.ksize() as f64, + )); + results.push(SearchResult { query_name: query_sig.name(), query_md5: query_sig.md5sum(), @@ -83,6 +90,10 @@ pub fn mastiff_manysearch( match_md5: None, jaccard: None, max_containment: None, + query_containment_ani, + match_containment_ani: None, + average_containment_ani: None, + max_containment_ani: None, }); } } diff --git a/src/python/tests/test_search.py b/src/python/tests/test_search.py index 11c638e7..c3cd79f4 100644 --- a/src/python/tests/test_search.py +++ b/src/python/tests/test_search.py @@ -73,6 +73,10 @@ def test_simple(runtmp, zip_query, zip_against): assert float(row['containment'] == 1.0) assert float(row['jaccard'] == 1.0) assert float(row['max_containment'] == 1.0) + assert float(row['query_containment_ani'] == 1.0) + assert float(row['match_containment_ani'] == 1.0) + assert float(row['average_containment_ani'] == 1.0) + assert float(row['max_containment_ani'] == 1.0) else: # confirm hand-checked numbers @@ -82,23 +86,39 @@ def test_simple(runtmp, zip_query, zip_against): jaccard = float(row['jaccard']) maxcont = float(row['max_containment']) intersect_hashes = int(row['intersect_hashes']) + query_ani = float(row['query_containment_ani']) + match_ani = float(row['match_containment_ani']) + average_ani = float(row['average_containment_ani']) + max_ani = float(row['max_containment_ani']) jaccard = round(jaccard, 4) cont = round(cont, 4) maxcont = round(maxcont, 4) - print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}") + query_ani = round(query_ani, 4) + match_ani = round(match_ani, 4) + average_ani = round(average_ani, 4) + max_ani = round(max_ani, 4) + print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", f"{query_ani:.04}", f"{match_ani:.04}", f"{average_ani:.04}", f"{max_ani:.04}") if q == 'NC_011665.1' and m == 'NC_009661.1': assert jaccard == 0.3207 assert cont == 0.4828 assert maxcont == 0.4885 assert intersect_hashes == 2529 + assert query_ani == 0.9768 + assert match_ani == 0.9772 + assert average_ani == 0.977 + assert max_ani == 0.9772 if q == 'NC_009661.1' and m == 'NC_011665.1': assert jaccard == 0.3207 assert cont == 0.4885 assert maxcont == 0.4885 assert intersect_hashes == 2529 + assert query_ani == 0.9772 + assert match_ani == 0.9768 + assert average_ani == 0.977 + assert max_ani == 0.9772 @pytest.mark.parametrize("zip_query", [False, True]) @@ -135,23 +155,27 @@ def test_simple_indexed(runtmp, zip_query): # identical? if row['match_name'] == row['query_name']: assert float(row['containment'] == 1.0) + assert float(row['query_containment_ani'] == 1.0) else: # confirm hand-checked numbers q = row['query_name'].split()[0] m = row['match_name'].split()[0] cont = float(row['containment']) intersect_hashes = int(row['intersect_hashes']) - + query_ani = float(row['query_containment_ani']) cont = round(cont, 4) - print(q, m, f"{cont:.04}") + query_ani = round(query_ani, 4) + print(q, m, f"{cont:.04}", f"{query_ani:.04}") if q == 'NC_011665.1' and m == 'NC_009661.1': assert cont == 0.4828 assert intersect_hashes == 2529 + assert query_ani == 0.9768 if q == 'NC_009661.1' and m == 'NC_011665.1': assert cont == 0.4885 assert intersect_hashes == 2529 + assert query_ani == 0.9772 @pytest.mark.parametrize("indexed", [False, True]) @@ -629,6 +653,10 @@ def test_simple_protein(runtmp): assert float(row['containment'] == 1.0) assert float(row['jaccard'] == 1.0) assert float(row['max_containment'] == 1.0) + assert float(row['query_containment_ani']) == 1.0 + assert float(row['match_containment_ani']) == 1.0 + assert float(row['average_containment_ani']) == 1.0 + assert float(row['max_containment_ani']) == 1.0 else: # confirm hand-checked numbers q = row['query_name'].split()[0] @@ -637,23 +665,39 @@ def test_simple_protein(runtmp): jaccard = float(row['jaccard']) maxcont = float(row['max_containment']) intersect_hashes = int(row['intersect_hashes']) + query_ani = float(row['query_containment_ani']) + match_ani = float(row['match_containment_ani']) + average_ani = float(row['average_containment_ani']) + max_ani = float(row['max_containment_ani']) jaccard = round(jaccard, 4) cont = round(cont, 4) maxcont = round(maxcont, 4) - print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes) + query_ani = round(query_ani, 4) + match_ani = round(match_ani, 4) + average_ani = round(average_ani, 4) + max_ani = round(max_ani, 4) + print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes, f"{query_ani:.04}", f"{match_ani:.04}", f"{average_ani:.04}", f"{max_ani:.04}") if q == 'GCA_001593925' and m == 'GCA_001593935': assert jaccard == 0.0434 assert cont == 0.1003 assert maxcont == 0.1003 assert intersect_hashes == 342 + assert query_ani == 0.9605 + assert match_ani == 0.9547 + assert average_ani == 0.9576 + assert max_ani == 0.9605 if q == 'GCA_001593935' and m == 'GCA_001593925': assert jaccard == 0.0434 assert cont == 0.0712 assert maxcont == 0.1003 assert intersect_hashes == 342 + assert query_ani == 0.9547 + assert match_ani == 0.9605 + assert average_ani == 0.9576 + assert max_ani == 0.9605 def test_simple_protein_indexed(runtmp): @@ -681,23 +725,28 @@ def test_simple_protein_indexed(runtmp): # identical? if row['match_name'] == row['query_name']: assert float(row['containment'] == 1.0) + assert float(row['query_containment_ani'] == 1.0) else: # confirm hand-checked numbers q = row['query_name'].split()[0] m = row['match_name'].split()[0] cont = float(row['containment']) + query_ani = float(row['query_containment_ani']) intersect_hashes = int(row['intersect_hashes']) cont = round(cont, 4) - print(q, m, f"{cont:.04}", intersect_hashes) + query_ani = round(query_ani, 4) + print(q, m, f"{cont:.04}", intersect_hashes, f"{query_ani:.04}") if q == 'GCA_001593925' and m == 'GCA_001593935': assert cont == 0.1003 assert intersect_hashes == 342 + assert query_ani == 0.9605 if q == 'GCA_001593935' and m == 'GCA_001593925': assert cont == 0.0712 assert intersect_hashes == 342 + assert query_ani == 0.9547 def test_simple_dayhoff(runtmp): @@ -725,6 +774,11 @@ def test_simple_dayhoff(runtmp): assert float(row['containment'] == 1.0) assert float(row['jaccard'] == 1.0) assert float(row['max_containment'] == 1.0) + assert float(row['query_containment_ani']) == 1.0 + assert float(row['match_containment_ani']) == 1.0 + assert float(row['average_containment_ani']) == 1.0 + assert float(row['max_containment_ani']) == 1.0 + else: # confirm hand-checked numbers q = row['query_name'].split()[0] @@ -733,23 +787,39 @@ def test_simple_dayhoff(runtmp): jaccard = float(row['jaccard']) maxcont = float(row['max_containment']) intersect_hashes = int(row['intersect_hashes']) + query_ani = float(row['query_containment_ani']) + match_ani = float(row['match_containment_ani']) + average_ani = float(row['average_containment_ani']) + max_ani = float(row['max_containment_ani']) jaccard = round(jaccard, 4) cont = round(cont, 4) maxcont = round(maxcont, 4) - print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes) + query_ani = round(query_ani, 4) + match_ani = round(match_ani, 4) + average_ani = round(average_ani, 4) + max_ani = round(max_ani, 4) + print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes, f"{query_ani:.04}", f"{match_ani:.04}", f"{average_ani:.04}", f"{max_ani:.04}") if q == 'GCA_001593925' and m == 'GCA_001593935': assert jaccard == 0.1326 assert cont == 0.2815 assert maxcont == 0.2815 assert intersect_hashes == 930 + assert query_ani == 0.978 + assert match_ani == 0.9722 + assert average_ani == 0.9751 + assert max_ani == 0.978 if q == 'GCA_001593935' and m == 'GCA_001593925': assert jaccard == 0.1326 assert cont == 0.2004 assert maxcont == 0.2815 assert intersect_hashes == 930 + assert query_ani == 0.9722 + assert match_ani == 0.978 + assert average_ani == 0.9751 + assert max_ani == 0.978 def test_simple_dayhoff_indexed(runtmp): @@ -777,23 +847,28 @@ def test_simple_dayhoff_indexed(runtmp): # identical? if row['match_name'] == row['query_name']: assert float(row['containment'] == 1.0) + assert float(row['query_containment_ani'] == 1.0) else: # confirm hand-checked numbers q = row['query_name'].split()[0] m = row['match_name'].split()[0] cont = float(row['containment']) + query_ani = float(row['query_containment_ani']) intersect_hashes = int(row['intersect_hashes']) cont = round(cont, 4) - print(q, m, f"{cont:.04}", intersect_hashes) + query_ani = round(query_ani, 4) + print(q, m, f"{cont:.04}", intersect_hashes, f"{query_ani:.04}") if q == 'GCA_001593925' and m == 'GCA_001593935': assert cont == 0.2815 assert intersect_hashes == 930 + assert query_ani == 0.978 if q == 'GCA_001593935' and m == 'GCA_001593925': assert cont == 0.2004 assert intersect_hashes == 930 + assert query_ani == 0.9722 def test_simple_hp(runtmp): @@ -821,6 +896,10 @@ def test_simple_hp(runtmp): assert float(row['containment'] == 1.0) assert float(row['jaccard'] == 1.0) assert float(row['max_containment'] == 1.0) + assert float(row['query_containment_ani']) == 1.0 + assert float(row['match_containment_ani']) == 1.0 + assert float(row['average_containment_ani']) == 1.0 + assert float(row['max_containment_ani']) == 1.0 else: # confirm hand-checked numbers q = row['query_name'].split()[0] @@ -829,23 +908,39 @@ def test_simple_hp(runtmp): jaccard = float(row['jaccard']) maxcont = float(row['max_containment']) intersect_hashes = int(row['intersect_hashes']) + query_ani = float(row['query_containment_ani']) + match_ani = float(row['match_containment_ani']) + average_ani = float(row['average_containment_ani']) + max_ani = float(row['max_containment_ani']) jaccard = round(jaccard, 4) cont = round(cont, 4) maxcont = round(maxcont, 4) - print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes) + query_ani = round(query_ani, 4) + match_ani = round(match_ani, 4) + average_ani = round(average_ani, 4) + max_ani = round(max_ani, 4) + print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes, f"{query_ani:.04}", f"{match_ani:.04}", f"{average_ani:.04}", f"{max_ani:.04}") if q == 'GCA_001593925' and m == 'GCA_001593935': assert jaccard == 0.4983 assert cont == 0.747 assert maxcont == 0.747 assert intersect_hashes == 1724 + assert query_ani == 0.9949 + assert match_ani == 0.9911 + assert average_ani == 0.993 + assert max_ani == 0.9949 if q == 'GCA_001593935' and m == 'GCA_001593925': assert jaccard == 0.4983 assert cont == 0.5994 assert maxcont == 0.747 assert intersect_hashes == 1724 + assert query_ani == 0.9911 + assert match_ani == 0.9949 + assert average_ani == 0.993 + assert max_ani == 0.9949 def test_simple_hp_indexed(runtmp): @@ -873,20 +968,26 @@ def test_simple_hp_indexed(runtmp): # identical? if row['match_name'] == row['query_name']: assert float(row['containment'] == 1.0) + assert float(row['query_containment_ani']) == 1.0 + else: # confirm hand-checked numbers q = row['query_name'].split()[0] m = row['match_name'].split()[0] cont = float(row['containment']) intersect_hashes = int(row['intersect_hashes']) + query_ani = float(row['query_containment_ani']) cont = round(cont, 4) - print(q, m, f"{cont:.04}", intersect_hashes) + query_ani = round(query_ani, 4) + print(q, m, f"{cont:.04}", intersect_hashes, f"{query_ani:.04}") if q == 'GCA_001593925' and m == 'GCA_001593935': assert cont == 0.747 assert intersect_hashes == 1724 + assert query_ani == 0.9949 if q == 'GCA_001593935' and m == 'GCA_001593925': assert cont == 0.5994 assert intersect_hashes == 1724 + assert query_ani == 0.9911 diff --git a/src/utils.rs b/src/utils.rs index d98b503b..11de1ea7 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -897,6 +897,14 @@ pub struct SearchResult { pub match_md5: Option, pub jaccard: Option, pub max_containment: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub query_containment_ani: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub match_containment_ani: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub average_containment_ani: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_containment_ani: Option, } #[derive(Serialize)]