Skip to content

Commit

Permalink
split ANI addition off from add-cluster
Browse files Browse the repository at this point in the history
  • Loading branch information
bluegenes committed Feb 24, 2024
1 parent 43caeba commit eb53f7c
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 136 deletions.
16 changes: 14 additions & 2 deletions src/multisearch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::sync::atomic::AtomicUsize;
use crate::utils::{
csvwriter_thread, load_collection, load_sketches, MultiSearchResult, ReportType,
};
use sourmash::ani_utils::ani_from_containment;

/// Search many queries against a list of signatures.
///
Expand Down Expand Up @@ -56,6 +57,7 @@ pub fn multisearch(
//

let processed_cmp = AtomicUsize::new(0);
let ksize = selection.ksize().unwrap() as f64;

let send = against
.par_iter()
Expand All @@ -74,10 +76,16 @@ pub fn multisearch(
let target_size = against.minhash.size() as f64;

let containment_query_in_target = overlap / query_size;
let containment_in_target = overlap / target_size;
let max_containment = containment_query_in_target.max(containment_in_target);
let containment_target_in_query = overlap / target_size;
let max_containment = containment_query_in_target.max(containment_target_in_query);
let jaccard = overlap / (target_size + query_size - overlap);

// estimate ANI values
let query_ani = ani_from_containment(containment_query_in_target, ksize) * 100.0;
let match_ani = ani_from_containment(containment_target_in_query, ksize) * 100.0;
let average_containment_ani = (query_ani + match_ani) / 2.;
let max_containment_ani = f64::max(query_ani, match_ani);

if containment_query_in_target > threshold {
results.push(MultiSearchResult {
query_name: query.name.clone(),
Expand All @@ -88,6 +96,10 @@ pub fn multisearch(
max_containment,
jaccard,
intersect_hashes: overlap,
query_ani,
match_ani,
average_containment_ani,
max_containment_ani,
})
}
}
Expand Down
12 changes: 12 additions & 0 deletions src/pairwise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::sync::atomic::AtomicUsize;
use crate::utils::{
csvwriter_thread, load_collection, load_sketches, MultiSearchResult, ReportType,
};
use sourmash::ani_utils::ani_from_containment;
use sourmash::selection::Selection;
use sourmash::signature::SigsTrait;

Expand Down Expand Up @@ -49,6 +50,7 @@ pub fn pairwise(
// Results written to the writer thread above.

let processed_cmp = AtomicUsize::new(0);
let ksize = selection.ksize().unwrap() as f64;

sketches.par_iter().enumerate().for_each(|(idx, query)| {
for against in sketches.iter().skip(idx + 1) {
Expand All @@ -61,6 +63,12 @@ pub fn pairwise(
let max_containment = containment_q1_in_q2.max(containment_q2_in_q1);
let jaccard = overlap / (query1_size + query2_size - overlap);

// estimate ANI values
let query_ani = ani_from_containment(containment_q1_in_q2, ksize) * 100.0;
let match_ani = ani_from_containment(containment_q2_in_q1, ksize) * 100.0;
let average_containment_ani = (query_ani + match_ani) / 2.;
let max_containment_ani = f64::max(query_ani, match_ani);

if containment_q1_in_q2 > threshold || containment_q2_in_q1 > threshold {
send.send(MultiSearchResult {
query_name: query.name.clone(),
Expand All @@ -71,6 +79,10 @@ pub fn pairwise(
max_containment,
jaccard,
intersect_hashes: overlap,
query_ani,
match_ani,
average_containment_ani,
max_containment_ani,
})
.unwrap();
}
Expand Down
89 changes: 85 additions & 4 deletions src/python/tests/test_multisearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def test_simple(runtmp, zip_query, zip_db):
assert float(row['containment'] == 1.0)
assert float(row['jaccard'] == 1.0)
assert float(row['max_containment'] == 1.0)
assert float(row['query_ani'] == 100.0)
assert float(row['match_ani'] == 100.0)
assert float(row['average_containment_ani'] == 100.0)
assert float(row['max_containment_ani'] == 100.0)

else:
# confirm hand-checked numbers
Expand All @@ -75,23 +79,40 @@ def test_simple(runtmp, zip_query, zip_db):
jaccard = float(row['jaccard'])
maxcont = float(row['max_containment'])
intersect_hashes = int(row['intersect_hashes'])
q1_ani = float(row['query_ani'])
q2_ani = float(row['match_ani'])
avg_ani = float(row['average_containment_ani'])
max_ani = float(row['max_containment_ani'])


jaccard = round(jaccard, 4)
cont = round(cont, 4)
maxcont = round(maxcont, 4)
print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}")
q1_ani = round(q1_ani, 2)
q2_ani = round(q2_ani, 2)
avg_ani = round(avg_ani, 2)
max_ani = round(max_ani, 2)
print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", f"{q1_ani:.04}", f"{q2_ani:.04}", f"{avg_ani:.04}", f"{max_ani:.04}")

if q == 'NC_011665.1' and m == 'NC_009661.1':
assert jaccard == 0.3207
assert cont == 0.4828
assert maxcont == 0.4885
assert intersect_hashes == 2529
assert q1_ani == 97.68
assert q2_ani == 97.72
assert avg_ani == 97.7
assert max_ani == 97.72

if q == 'NC_009661.1' and m == 'NC_011665.1':
assert jaccard == 0.3207
assert cont == 0.4885
assert maxcont == 0.4885
assert intersect_hashes == 2529
assert q1_ani == 97.72
assert q2_ani == 97.68
assert avg_ani == 97.7
assert max_ani == 97.72


@pytest.mark.parametrize("zip_query", [False, True])
Expand Down Expand Up @@ -512,6 +533,10 @@ def test_simple_prot(runtmp):
assert float(row['containment'] == 1.0)
assert float(row['jaccard'] == 1.0)
assert float(row['max_containment'] == 1.0)
assert float(row['query_ani'] == 100.0)
assert float(row['match_ani'] == 100.0)
assert float(row['average_containment_ani'] == 100.0)
assert float(row['max_containment_ani'] == 100.0)

else:
# confirm hand-checked numbers
Expand All @@ -521,23 +546,39 @@ def test_simple_prot(runtmp):
jaccard = float(row['jaccard'])
maxcont = float(row['max_containment'])
intersect_hashes = int(row['intersect_hashes'])
q1_ani = float(row['query_ani'])
q2_ani = float(row['match_ani'])
avg_ani = float(row['average_containment_ani'])
max_ani = float(row['max_containment_ani'])

jaccard = round(jaccard, 4)
cont = round(cont, 4)
maxcont = round(maxcont, 4)
print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes)
q1_ani = round(q1_ani, 2)
q2_ani = round(q2_ani, 2)
avg_ani = round(avg_ani, 2)
max_ani = round(max_ani, 2)
print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes, f"{q1_ani:.04}", f"{q2_ani:.04}", f"{avg_ani:.04}", f"{max_ani:.04}")

if q == 'GCA_001593925' and m == 'GCA_001593935':
assert jaccard == 0.0434
assert cont == 0.1003
assert maxcont == 0.1003
assert intersect_hashes == 342
assert q1_ani == 88.6
assert q2_ani == 87.02
assert avg_ani == 87.81
assert max_ani == 88.6

if q == 'GCA_001593935' and m == 'GCA_001593925':
assert jaccard == 0.0434
assert cont == 0.0712
assert maxcont == 0.1003
assert intersect_hashes == 342
assert q1_ani == 87.02
assert q2_ani == 88.6
assert avg_ani == 87.81
assert max_ani == 88.6


def test_simple_dayhoff(runtmp):
Expand All @@ -564,6 +605,10 @@ def test_simple_dayhoff(runtmp):
assert float(row['containment'] == 1.0)
assert float(row['jaccard'] == 1.0)
assert float(row['max_containment'] == 1.0)
assert float(row['query_ani'] == 100.0)
assert float(row['match_ani'] == 100.0)
assert float(row['average_containment_ani'] == 100.0)
assert float(row['max_containment_ani'] == 100.0)

else:
# confirm hand-checked numbers
Expand All @@ -573,23 +618,39 @@ def test_simple_dayhoff(runtmp):
jaccard = float(row['jaccard'])
maxcont = float(row['max_containment'])
intersect_hashes = int(row['intersect_hashes'])
q1_ani = float(row['query_ani'])
q2_ani = float(row['match_ani'])
avg_ani = float(row['average_containment_ani'])
max_ani = float(row['max_containment_ani'])

jaccard = round(jaccard, 4)
cont = round(cont, 4)
maxcont = round(maxcont, 4)
print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes)
q1_ani = round(q1_ani, 2)
q2_ani = round(q2_ani, 2)
avg_ani = round(avg_ani, 2)
max_ani = round(max_ani, 2)
print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes, f"{q1_ani:.04}", f"{q2_ani:.04}", f"{avg_ani:.04}", f"{max_ani:.04}")

if q == 'GCA_001593925' and m == 'GCA_001593935':
assert jaccard == 0.1326
assert cont == 0.2815
assert maxcont == 0.2815
assert intersect_hashes == 930
assert q1_ani == 93.55
assert q2_ani == 91.89
assert avg_ani == 92.72
assert max_ani == 93.55

if q == 'GCA_001593935' and m == 'GCA_001593925':
assert jaccard == 0.1326
assert cont == 0.2004
assert maxcont == 0.2815
assert intersect_hashes == 930
assert q1_ani == 91.89
assert q2_ani == 93.55
assert avg_ani == 92.72
assert max_ani == 93.55


def test_simple_hp(runtmp):
Expand All @@ -616,6 +677,10 @@ def test_simple_hp(runtmp):
assert float(row['containment'] == 1.0)
assert float(row['jaccard'] == 1.0)
assert float(row['max_containment'] == 1.0)
assert float(row['query_ani'] == 100.0)
assert float(row['match_ani'] == 100.0)
assert float(row['average_containment_ani'] == 100.0)
assert float(row['max_containment_ani'] == 100.0)

else:
# confirm hand-checked numbers
Expand All @@ -625,20 +690,36 @@ def test_simple_hp(runtmp):
jaccard = float(row['jaccard'])
maxcont = float(row['max_containment'])
intersect_hashes = int(row['intersect_hashes'])
q1_ani = float(row['query_ani'])
q2_ani = float(row['match_ani'])
avg_ani = float(row['average_containment_ani'])
max_ani = float(row['max_containment_ani'])

jaccard = round(jaccard, 4)
cont = round(cont, 4)
maxcont = round(maxcont, 4)
print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes)
q1_ani = round(q1_ani, 2)
q2_ani = round(q2_ani, 2)
avg_ani = round(avg_ani, 2)
max_ani = round(max_ani, 2)
print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", intersect_hashes, f"{q1_ani:.04}", f"{q2_ani:.04}", f"{avg_ani:.04}", f"{max_ani:.04}")

if q == 'GCA_001593925' and m == 'GCA_001593935':
assert jaccard == 0.4983
assert cont == 0.747
assert maxcont == 0.747
assert intersect_hashes == 1724
assert q1_ani == 98.48
assert q2_ani == 97.34
assert avg_ani == 97.91
assert max_ani == 98.48

if q == 'GCA_001593935' and m == 'GCA_001593925':
assert jaccard == 0.4983
assert cont == 0.5994
assert maxcont == 0.747
assert intersect_hashes == 1724
assert q1_ani == 97.34
assert q2_ani == 98.48
assert avg_ani == 97.91
assert max_ani == 98.48
Loading

0 comments on commit eb53f7c

Please sign in to comment.