From e7d9eeb1bed06c8351a4b1fa6ab8b98a78048a46 Mon Sep 17 00:00:00 2001 From: "N. Tessa Pierce-Ward" Date: Sat, 24 Feb 2024 12:43:43 -0800 Subject: [PATCH] make ani optional --- src/lib.rs | 9 ++- src/multisearch.rs | 19 +++-- src/pairwise.rs | 19 +++-- .../sourmash_plugin_branchwater/__init__.py | 6 ++ src/python/tests/test_multisearch.py | 79 +++++++++++++++++-- src/python/tests/test_pairwise.py | 66 ++++++++++++++-- src/utils.rs | 13 ++- 7 files changed, 183 insertions(+), 28 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 06930ed4..93f22ee7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -207,6 +207,7 @@ fn do_multisearch( ksize: u8, scaled: usize, moltype: String, + estimate_ani: bool, output_path: Option, ) -> anyhow::Result { let selection = build_selection(ksize, scaled, &moltype); @@ -217,8 +218,9 @@ fn do_multisearch( siglist_path, threshold, &selection, - output_path, allow_failed_sigpaths, + estimate_ani, + output_path, ) { Ok(_) => Ok(0), Err(e) => { @@ -235,16 +237,19 @@ fn do_pairwise( ksize: u8, scaled: usize, moltype: String, + estimate_ani: bool, output_path: Option, ) -> anyhow::Result { let selection = build_selection(ksize, scaled, &moltype); let allow_failed_sigpaths = true; + eprintln!("{}", estimate_ani); match pairwise::pairwise( siglist_path, threshold, &selection, - output_path, allow_failed_sigpaths, + estimate_ani, + output_path, ) { Ok(_) => Ok(0), Err(e) => { diff --git a/src/multisearch.rs b/src/multisearch.rs index 23211f48..df39759e 100644 --- a/src/multisearch.rs +++ b/src/multisearch.rs @@ -21,8 +21,9 @@ pub fn multisearch( against_filepath: String, threshold: f64, selection: &Selection, - output: Option, allow_failed_sigpaths: bool, + estimate_ani: bool, + output: Option, ) -> Result<(), Box> { // Load all queries into memory at once. @@ -81,10 +82,18 @@ pub fn multisearch( let jaccard = overlap / (target_size + query_size - overlap); // estimate ANI values - let query_ani = ani_from_containment(containment_query_in_target, ksize) * 100.0; - let match_ani = ani_from_containment(containment_target_in_query, ksize) * 100.0; - let average_containment_ani = (query_ani + match_ani) / 2.; - let max_containment_ani = f64::max(query_ani, match_ani); + let mut query_ani = None; + let mut match_ani = None; + let mut average_containment_ani = None; + let mut max_containment_ani = None; + if estimate_ani { + let qani = ani_from_containment(containment_query_in_target, ksize) * 100.0; + let mani = ani_from_containment(containment_target_in_query, ksize) * 100.0; + query_ani = Some(qani); + match_ani = Some(mani); + average_containment_ani = Some((qani + mani) / 2.); + max_containment_ani = Some(f64::max(qani, mani)); + } if containment_query_in_target > threshold { results.push(MultiSearchResult { diff --git a/src/pairwise.rs b/src/pairwise.rs index 5dd64b4b..15db72e0 100644 --- a/src/pairwise.rs +++ b/src/pairwise.rs @@ -19,8 +19,9 @@ pub fn pairwise( siglist: String, threshold: f64, selection: &Selection, - output: Option, allow_failed_sigpaths: bool, + estimate_ani: bool, + output: Option, ) -> Result<(), Box> { // Load all sigs into memory at once. let collection = load_collection( @@ -64,10 +65,18 @@ pub fn pairwise( let jaccard = overlap / (query1_size + query2_size - overlap); // estimate ANI values - let query_ani = ani_from_containment(containment_q1_in_q2, ksize) * 100.0; - let match_ani = ani_from_containment(containment_q2_in_q1, ksize) * 100.0; - let average_containment_ani = (query_ani + match_ani) / 2.; - let max_containment_ani = f64::max(query_ani, match_ani); + let mut query_ani = None; + let mut match_ani = None; + let mut average_containment_ani = None; + let mut max_containment_ani = None; + if estimate_ani { + let qani = ani_from_containment(containment_q1_in_q2, ksize) * 100.0; + let mani = ani_from_containment(containment_q2_in_q1, ksize) * 100.0; + query_ani = Some(qani); + match_ani = Some(mani); + average_containment_ani = Some((qani + mani) / 2.); + max_containment_ani = Some(f64::max(qani, mani)); + } if containment_q1_in_q2 > threshold || containment_q2_in_q1 > threshold { send.send(MultiSearchResult { diff --git a/src/python/sourmash_plugin_branchwater/__init__.py b/src/python/sourmash_plugin_branchwater/__init__.py index def6fec7..e1912817 100755 --- a/src/python/sourmash_plugin_branchwater/__init__.py +++ b/src/python/sourmash_plugin_branchwater/__init__.py @@ -252,6 +252,8 @@ def __init__(self, p): help = 'molecule type (DNA, protein, dayhoff, or hp; default DNA)') p.add_argument('-c', '--cores', default=0, type=int, help='number of cores to use (default is all available)') + p.add_argument('-a', '--ani', action='store_true', + help='estimate ANI from containment') def main(self, args): print_version() @@ -269,6 +271,7 @@ def main(self, args): args.ksize, args.scaled, args.moltype, + args.ani, args.output) if status == 0: notify(f"...multisearch is done! results in '{args.output}'") @@ -294,6 +297,8 @@ def __init__(self, p): help = 'molecule type (DNA, protein, dayhoff, or hp; default DNA)') p.add_argument('-c', '--cores', default=0, type=int, help='number of cores to use (default is all available)') + p.add_argument('-a', '--ani', action='store_true', + help='estimate ANI from containment') def main(self, args): print_version() @@ -310,6 +315,7 @@ def main(self, args): args.ksize, args.scaled, args.moltype, + args.ani, args.output) if status == 0: notify(f"...pairwise is done! results in '{args.output}'") diff --git a/src/python/tests/test_multisearch.py b/src/python/tests/test_multisearch.py index 694c2335..e4667fd9 100644 --- a/src/python/tests/test_multisearch.py +++ b/src/python/tests/test_multisearch.py @@ -30,7 +30,7 @@ def zip_siglist(runtmp, siglist, db): @pytest.mark.parametrize("zip_query", [False, True]) @pytest.mark.parametrize("zip_db", [False, True]) -def test_simple(runtmp, zip_query, zip_db): +def test_simple_no_ani(runtmp, zip_query, zip_db): # test basic execution! query_list = runtmp.output('query.txt') against_list = runtmp.output('against.txt') @@ -59,6 +59,76 @@ def test_simple(runtmp, zip_query, zip_db): dd = df.to_dict(orient='index') print(dd) + for idx, row in dd.items(): + # identical? + if row['match_name'] == row['query_name']: + assert row['query_md5'] == row['match_md5'], row + assert float(row['containment'] == 1.0) + assert float(row['jaccard'] == 1.0) + assert float(row['max_containment'] == 1.0) + assert 'query_ani' not in row + assert 'match_ani' not in row + assert 'average_containment_ani' not in row + assert 'max_containment_ani' not in row + + else: + # confirm hand-checked numbers + q = row['query_name'].split()[0] + m = row['match_name'].split()[0] + cont = float(row['containment']) + jaccard = float(row['jaccard']) + maxcont = float(row['max_containment']) + intersect_hashes = int(row['intersect_hashes']) + + jaccard = round(jaccard, 4) + cont = round(cont, 4) + maxcont = round(maxcont, 4) + print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}") + + if q == 'NC_011665.1' and m == 'NC_009661.1': + assert jaccard == 0.3207 + assert cont == 0.4828 + assert maxcont == 0.4885 + assert intersect_hashes == 2529 + + if q == 'NC_009661.1' and m == 'NC_011665.1': + assert jaccard == 0.3207 + assert cont == 0.4885 + assert maxcont == 0.4885 + assert intersect_hashes == 2529 + + +@pytest.mark.parametrize("zip_query", [False, True]) +@pytest.mark.parametrize("zip_db", [False, True]) +def test_simple_ani(runtmp, zip_query, zip_db): + # test basic execution! + query_list = runtmp.output('query.txt') + against_list = runtmp.output('against.txt') + + sig2 = get_test_data('2.fa.sig.gz') + sig47 = get_test_data('47.fa.sig.gz') + sig63 = get_test_data('63.fa.sig.gz') + + make_file_list(query_list, [sig2, sig47, sig63]) + make_file_list(against_list, [sig2, sig47, sig63]) + + output = runtmp.output('out.csv') + + if zip_db: + against_list = zip_siglist(runtmp, against_list, runtmp.output('db.zip')) + if zip_query: + query_list = zip_siglist(runtmp, query_list, runtmp.output('query.zip')) + + runtmp.sourmash('scripts', 'multisearch', query_list, against_list, + '-o', output, '--ani') + assert os.path.exists(output) + + df = pandas.read_csv(output) + assert len(df) == 5 + + dd = df.to_dict(orient='index') + print(dd) + for idx, row in dd.items(): # identical? if row['match_name'] == row['query_name']: @@ -114,7 +184,6 @@ def test_simple(runtmp, zip_query, zip_db): assert avg_ani == 97.7 assert max_ani == 97.72 - @pytest.mark.parametrize("zip_query", [False, True]) @pytest.mark.parametrize("zip_db", [False, True]) def test_simple_threshold(runtmp, zip_query, zip_db): @@ -517,7 +586,7 @@ def test_simple_prot(runtmp): runtmp.sourmash('scripts', 'multisearch', sigs, sigs, '-o', output, '--moltype', 'protein', - '-k', '19', '--scaled', '100') + '-k', '19', '--scaled', '100', '--ani') assert os.path.exists(output) df = pandas.read_csv(output) @@ -589,7 +658,7 @@ def test_simple_dayhoff(runtmp): runtmp.sourmash('scripts', 'multisearch', sigs, sigs, '-o', output, '--moltype', 'dayhoff', - '-k', '19', '--scaled', '100') + '-k', '19', '--scaled', '100', '--ani') assert os.path.exists(output) df = pandas.read_csv(output) @@ -661,7 +730,7 @@ def test_simple_hp(runtmp): runtmp.sourmash('scripts', 'multisearch', sigs, sigs, '-o', output, '--moltype', 'hp', - '-k', '19', '--scaled', '100') + '-k', '19', '--scaled', '100', '--ani') assert os.path.exists(output) df = pandas.read_csv(output) diff --git a/src/python/tests/test_pairwise.py b/src/python/tests/test_pairwise.py index 5fd4ec88..cb76da71 100644 --- a/src/python/tests/test_pairwise.py +++ b/src/python/tests/test_pairwise.py @@ -28,8 +28,9 @@ def zip_siglist(runtmp, siglist, db): '-o', db) return db + @pytest.mark.parametrize("zip_query", [False, True]) -def test_simple(runtmp, zip_query): +def test_simple_no_ani(runtmp, zip_query): # test basic execution! query_list = runtmp.output('query.txt') @@ -54,6 +55,57 @@ def test_simple(runtmp, zip_query): dd = df.to_dict(orient='index') print(dd) + for idx, row in dd.items(): + # confirm hand-checked numbers + q = row['query_name'].split()[0] + m = row['match_name'].split()[0] + cont = float(row['containment']) + jaccard = float(row['jaccard']) + maxcont = float(row['max_containment']) + intersect_hashes = int(row['intersect_hashes']) + assert 'query_ani' not in row + assert 'match_ani' not in row + assert 'average_containment_ani' not in row + assert 'max_containment_ani' not in row + + jaccard = round(jaccard, 4) + cont = round(cont, 4) + maxcont = round(maxcont, 4) + print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}") + + if q == 'NC_011665.1' and m == 'NC_009661.1': + assert jaccard == 0.3207 + assert cont == 0.4828 + assert maxcont == 0.4885 + assert intersect_hashes == 2529 + + +@pytest.mark.parametrize("zip_query", [False, True]) +def test_simple_ani(runtmp, zip_query): + # test basic execution! + query_list = runtmp.output('query.txt') + + sig2 = get_test_data('2.fa.sig.gz') + sig47 = get_test_data('47.fa.sig.gz') + sig63 = get_test_data('63.fa.sig.gz') + + make_file_list(query_list, [sig2, sig47, sig63]) + + output = runtmp.output('out.csv') + + if zip_query: + query_list = zip_siglist(runtmp, query_list, runtmp.output('query.zip')) + + runtmp.sourmash('scripts', 'pairwise', query_list, + '-o', output, '-t', '-1', '--ani') + assert os.path.exists(output) + + df = pandas.read_csv(output) + assert len(df) == 3 + + dd = df.to_dict(orient='index') + print(dd) + for idx, row in dd.items(): # confirm hand-checked numbers q = row['query_name'].split()[0] @@ -329,7 +381,7 @@ def test_md5(runtmp, zip_query): print(md5s) -def test_simple_prot(runtmp): +def test_simple_prot_ani(runtmp): # test basic execution with protein sigs sigs = get_test_data('protein.zip') @@ -337,7 +389,7 @@ def test_simple_prot(runtmp): runtmp.sourmash('scripts', 'pairwise', sigs, '-o', output, '--moltype', 'protein', - '-k', '19', '--scaled', '100') + '-k', '19', '--scaled', '100', '--ani') assert os.path.exists(output) df = pandas.read_csv(output) @@ -379,7 +431,7 @@ def test_simple_prot(runtmp): assert max_ani == 88.6 -def test_simple_dayhoff(runtmp): +def test_simple_dayhoff_ani(runtmp): # test basic execution with dayhoff sigs sigs = get_test_data('dayhoff.zip') @@ -387,7 +439,7 @@ def test_simple_dayhoff(runtmp): runtmp.sourmash('scripts', 'pairwise', sigs, '-o', output, '--moltype', 'dayhoff', - '-k', '19', '--scaled', '100') + '-k', '19', '--scaled', '100', '--ani') assert os.path.exists(output) df = pandas.read_csv(output) @@ -429,7 +481,7 @@ def test_simple_dayhoff(runtmp): assert max_ani == 93.55 -def test_simple_hp(runtmp): +def test_simple_hp_ani(runtmp): # test basic execution with hp sigs sigs = get_test_data('hp.zip') @@ -437,7 +489,7 @@ def test_simple_hp(runtmp): runtmp.sourmash('scripts', 'pairwise', sigs, '-o', output, '--moltype', 'hp', - '-k', '19', '--scaled', '100') + '-k', '19', '--scaled', '100', '--ani') assert os.path.exists(output) df = pandas.read_csv(output) diff --git a/src/utils.rs b/src/utils.rs index 7b10377a..d9fa1b41 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -726,10 +726,15 @@ pub struct MultiSearchResult { pub max_containment: f64, pub jaccard: f64, pub intersect_hashes: f64, - pub query_ani: f64, - pub match_ani: f64, - pub average_containment_ani: f64, - pub max_containment_ani: f64, + + #[serde(skip_serializing_if = "Option::is_none")] + pub query_ani: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub match_ani: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub average_containment_ani: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_containment_ani: Option, } #[derive(Serialize)]