Skip to content

Commit

Permalink
make ani optional
Browse files Browse the repository at this point in the history
  • Loading branch information
bluegenes committed Feb 24, 2024
1 parent eb53f7c commit e7d9eeb
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 28 deletions.
9 changes: 7 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ fn do_multisearch(
ksize: u8,
scaled: usize,
moltype: String,
estimate_ani: bool,
output_path: Option<String>,
) -> anyhow::Result<u8> {
let selection = build_selection(ksize, scaled, &moltype);
Expand All @@ -217,8 +218,9 @@ fn do_multisearch(
siglist_path,
threshold,
&selection,
output_path,
allow_failed_sigpaths,
estimate_ani,
output_path,
) {
Ok(_) => Ok(0),
Err(e) => {
Expand All @@ -235,16 +237,19 @@ fn do_pairwise(
ksize: u8,
scaled: usize,
moltype: String,
estimate_ani: bool,
output_path: Option<String>,
) -> anyhow::Result<u8> {
let selection = build_selection(ksize, scaled, &moltype);
let allow_failed_sigpaths = true;
eprintln!("{}", estimate_ani);
match pairwise::pairwise(
siglist_path,
threshold,
&selection,
output_path,
allow_failed_sigpaths,
estimate_ani,
output_path,
) {
Ok(_) => Ok(0),
Err(e) => {
Expand Down
19 changes: 14 additions & 5 deletions src/multisearch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ pub fn multisearch(
against_filepath: String,
threshold: f64,
selection: &Selection,
output: Option<String>,
allow_failed_sigpaths: bool,
estimate_ani: bool,
output: Option<String>,
) -> Result<(), Box<dyn std::error::Error>> {
// Load all queries into memory at once.

Expand Down Expand Up @@ -81,10 +82,18 @@ pub fn multisearch(
let jaccard = overlap / (target_size + query_size - overlap);

// estimate ANI values
let query_ani = ani_from_containment(containment_query_in_target, ksize) * 100.0;
let match_ani = ani_from_containment(containment_target_in_query, ksize) * 100.0;
let average_containment_ani = (query_ani + match_ani) / 2.;
let max_containment_ani = f64::max(query_ani, match_ani);
let mut query_ani = None;
let mut match_ani = None;
let mut average_containment_ani = None;
let mut max_containment_ani = None;
if estimate_ani {
let qani = ani_from_containment(containment_query_in_target, ksize) * 100.0;
let mani = ani_from_containment(containment_target_in_query, ksize) * 100.0;
query_ani = Some(qani);
match_ani = Some(mani);
average_containment_ani = Some((qani + mani) / 2.);
max_containment_ani = Some(f64::max(qani, mani));
}

if containment_query_in_target > threshold {
results.push(MultiSearchResult {
Expand Down
19 changes: 14 additions & 5 deletions src/pairwise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ pub fn pairwise(
siglist: String,
threshold: f64,
selection: &Selection,
output: Option<String>,
allow_failed_sigpaths: bool,
estimate_ani: bool,
output: Option<String>,
) -> Result<(), Box<dyn std::error::Error>> {
// Load all sigs into memory at once.
let collection = load_collection(
Expand Down Expand Up @@ -64,10 +65,18 @@ pub fn pairwise(
let jaccard = overlap / (query1_size + query2_size - overlap);

// estimate ANI values
let query_ani = ani_from_containment(containment_q1_in_q2, ksize) * 100.0;
let match_ani = ani_from_containment(containment_q2_in_q1, ksize) * 100.0;
let average_containment_ani = (query_ani + match_ani) / 2.;
let max_containment_ani = f64::max(query_ani, match_ani);
let mut query_ani = None;
let mut match_ani = None;
let mut average_containment_ani = None;
let mut max_containment_ani = None;
if estimate_ani {
let qani = ani_from_containment(containment_q1_in_q2, ksize) * 100.0;
let mani = ani_from_containment(containment_q2_in_q1, ksize) * 100.0;
query_ani = Some(qani);
match_ani = Some(mani);
average_containment_ani = Some((qani + mani) / 2.);
max_containment_ani = Some(f64::max(qani, mani));
}

if containment_q1_in_q2 > threshold || containment_q2_in_q1 > threshold {
send.send(MultiSearchResult {
Expand Down
6 changes: 6 additions & 0 deletions src/python/sourmash_plugin_branchwater/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@ def __init__(self, p):
help = 'molecule type (DNA, protein, dayhoff, or hp; default DNA)')
p.add_argument('-c', '--cores', default=0, type=int,
help='number of cores to use (default is all available)')
p.add_argument('-a', '--ani', action='store_true',
help='estimate ANI from containment')

def main(self, args):
print_version()
Expand All @@ -269,6 +271,7 @@ def main(self, args):
args.ksize,
args.scaled,
args.moltype,
args.ani,
args.output)
if status == 0:
notify(f"...multisearch is done! results in '{args.output}'")
Expand All @@ -294,6 +297,8 @@ def __init__(self, p):
help = 'molecule type (DNA, protein, dayhoff, or hp; default DNA)')
p.add_argument('-c', '--cores', default=0, type=int,
help='number of cores to use (default is all available)')
p.add_argument('-a', '--ani', action='store_true',
help='estimate ANI from containment')

def main(self, args):
print_version()
Expand All @@ -310,6 +315,7 @@ def main(self, args):
args.ksize,
args.scaled,
args.moltype,
args.ani,
args.output)
if status == 0:
notify(f"...pairwise is done! results in '{args.output}'")
Expand Down
79 changes: 74 additions & 5 deletions src/python/tests/test_multisearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def zip_siglist(runtmp, siglist, db):

@pytest.mark.parametrize("zip_query", [False, True])
@pytest.mark.parametrize("zip_db", [False, True])
def test_simple(runtmp, zip_query, zip_db):
def test_simple_no_ani(runtmp, zip_query, zip_db):
# test basic execution!
query_list = runtmp.output('query.txt')
against_list = runtmp.output('against.txt')
Expand Down Expand Up @@ -59,6 +59,76 @@ def test_simple(runtmp, zip_query, zip_db):
dd = df.to_dict(orient='index')
print(dd)

for idx, row in dd.items():
# identical?
if row['match_name'] == row['query_name']:
assert row['query_md5'] == row['match_md5'], row
assert float(row['containment'] == 1.0)
assert float(row['jaccard'] == 1.0)
assert float(row['max_containment'] == 1.0)
assert 'query_ani' not in row
assert 'match_ani' not in row
assert 'average_containment_ani' not in row
assert 'max_containment_ani' not in row

else:
# confirm hand-checked numbers
q = row['query_name'].split()[0]
m = row['match_name'].split()[0]
cont = float(row['containment'])
jaccard = float(row['jaccard'])
maxcont = float(row['max_containment'])
intersect_hashes = int(row['intersect_hashes'])

jaccard = round(jaccard, 4)
cont = round(cont, 4)
maxcont = round(maxcont, 4)
print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}")

if q == 'NC_011665.1' and m == 'NC_009661.1':
assert jaccard == 0.3207
assert cont == 0.4828
assert maxcont == 0.4885
assert intersect_hashes == 2529

if q == 'NC_009661.1' and m == 'NC_011665.1':
assert jaccard == 0.3207
assert cont == 0.4885
assert maxcont == 0.4885
assert intersect_hashes == 2529


@pytest.mark.parametrize("zip_query", [False, True])
@pytest.mark.parametrize("zip_db", [False, True])
def test_simple_ani(runtmp, zip_query, zip_db):
# test basic execution!
query_list = runtmp.output('query.txt')
against_list = runtmp.output('against.txt')

sig2 = get_test_data('2.fa.sig.gz')
sig47 = get_test_data('47.fa.sig.gz')
sig63 = get_test_data('63.fa.sig.gz')

make_file_list(query_list, [sig2, sig47, sig63])
make_file_list(against_list, [sig2, sig47, sig63])

output = runtmp.output('out.csv')

if zip_db:
against_list = zip_siglist(runtmp, against_list, runtmp.output('db.zip'))
if zip_query:
query_list = zip_siglist(runtmp, query_list, runtmp.output('query.zip'))

runtmp.sourmash('scripts', 'multisearch', query_list, against_list,
'-o', output, '--ani')
assert os.path.exists(output)

df = pandas.read_csv(output)
assert len(df) == 5

dd = df.to_dict(orient='index')
print(dd)

for idx, row in dd.items():
# identical?
if row['match_name'] == row['query_name']:
Expand Down Expand Up @@ -114,7 +184,6 @@ def test_simple(runtmp, zip_query, zip_db):
assert avg_ani == 97.7
assert max_ani == 97.72


@pytest.mark.parametrize("zip_query", [False, True])
@pytest.mark.parametrize("zip_db", [False, True])
def test_simple_threshold(runtmp, zip_query, zip_db):
Expand Down Expand Up @@ -517,7 +586,7 @@ def test_simple_prot(runtmp):

runtmp.sourmash('scripts', 'multisearch', sigs, sigs,
'-o', output, '--moltype', 'protein',
'-k', '19', '--scaled', '100')
'-k', '19', '--scaled', '100', '--ani')
assert os.path.exists(output)

df = pandas.read_csv(output)
Expand Down Expand Up @@ -589,7 +658,7 @@ def test_simple_dayhoff(runtmp):

runtmp.sourmash('scripts', 'multisearch', sigs, sigs,
'-o', output, '--moltype', 'dayhoff',
'-k', '19', '--scaled', '100')
'-k', '19', '--scaled', '100', '--ani')
assert os.path.exists(output)

df = pandas.read_csv(output)
Expand Down Expand Up @@ -661,7 +730,7 @@ def test_simple_hp(runtmp):

runtmp.sourmash('scripts', 'multisearch', sigs, sigs,
'-o', output, '--moltype', 'hp',
'-k', '19', '--scaled', '100')
'-k', '19', '--scaled', '100', '--ani')
assert os.path.exists(output)

df = pandas.read_csv(output)
Expand Down
66 changes: 59 additions & 7 deletions src/python/tests/test_pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ def zip_siglist(runtmp, siglist, db):
'-o', db)
return db


@pytest.mark.parametrize("zip_query", [False, True])
def test_simple(runtmp, zip_query):
def test_simple_no_ani(runtmp, zip_query):
# test basic execution!
query_list = runtmp.output('query.txt')

Expand All @@ -54,6 +55,57 @@ def test_simple(runtmp, zip_query):
dd = df.to_dict(orient='index')
print(dd)

for idx, row in dd.items():
# confirm hand-checked numbers
q = row['query_name'].split()[0]
m = row['match_name'].split()[0]
cont = float(row['containment'])
jaccard = float(row['jaccard'])
maxcont = float(row['max_containment'])
intersect_hashes = int(row['intersect_hashes'])
assert 'query_ani' not in row
assert 'match_ani' not in row
assert 'average_containment_ani' not in row
assert 'max_containment_ani' not in row

jaccard = round(jaccard, 4)
cont = round(cont, 4)
maxcont = round(maxcont, 4)
print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}")

if q == 'NC_011665.1' and m == 'NC_009661.1':
assert jaccard == 0.3207
assert cont == 0.4828
assert maxcont == 0.4885
assert intersect_hashes == 2529


@pytest.mark.parametrize("zip_query", [False, True])
def test_simple_ani(runtmp, zip_query):
# test basic execution!
query_list = runtmp.output('query.txt')

sig2 = get_test_data('2.fa.sig.gz')
sig47 = get_test_data('47.fa.sig.gz')
sig63 = get_test_data('63.fa.sig.gz')

make_file_list(query_list, [sig2, sig47, sig63])

output = runtmp.output('out.csv')

if zip_query:
query_list = zip_siglist(runtmp, query_list, runtmp.output('query.zip'))

runtmp.sourmash('scripts', 'pairwise', query_list,
'-o', output, '-t', '-1', '--ani')
assert os.path.exists(output)

df = pandas.read_csv(output)
assert len(df) == 3

dd = df.to_dict(orient='index')
print(dd)

for idx, row in dd.items():
# confirm hand-checked numbers
q = row['query_name'].split()[0]
Expand Down Expand Up @@ -329,15 +381,15 @@ def test_md5(runtmp, zip_query):
print(md5s)


def test_simple_prot(runtmp):
def test_simple_prot_ani(runtmp):
# test basic execution with protein sigs
sigs = get_test_data('protein.zip')

output = runtmp.output('out.csv')

runtmp.sourmash('scripts', 'pairwise', sigs,
'-o', output, '--moltype', 'protein',
'-k', '19', '--scaled', '100')
'-k', '19', '--scaled', '100', '--ani')
assert os.path.exists(output)

df = pandas.read_csv(output)
Expand Down Expand Up @@ -379,15 +431,15 @@ def test_simple_prot(runtmp):
assert max_ani == 88.6


def test_simple_dayhoff(runtmp):
def test_simple_dayhoff_ani(runtmp):
# test basic execution with dayhoff sigs
sigs = get_test_data('dayhoff.zip')

output = runtmp.output('out.csv')

runtmp.sourmash('scripts', 'pairwise', sigs,
'-o', output, '--moltype', 'dayhoff',
'-k', '19', '--scaled', '100')
'-k', '19', '--scaled', '100', '--ani')
assert os.path.exists(output)

df = pandas.read_csv(output)
Expand Down Expand Up @@ -429,15 +481,15 @@ def test_simple_dayhoff(runtmp):
assert max_ani == 93.55


def test_simple_hp(runtmp):
def test_simple_hp_ani(runtmp):
# test basic execution with hp sigs
sigs = get_test_data('hp.zip')

output = runtmp.output('out.csv')

runtmp.sourmash('scripts', 'pairwise', sigs,
'-o', output, '--moltype', 'hp',
'-k', '19', '--scaled', '100')
'-k', '19', '--scaled', '100', '--ani')
assert os.path.exists(output)

df = pandas.read_csv(output)
Expand Down
Loading

0 comments on commit e7d9eeb

Please sign in to comment.