diff --git a/Cargo.lock b/Cargo.lock index 0c74367f..55e1c4d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -206,7 +206,7 @@ dependencies = [ "bitflags", "cexpr", "clang-sys", - "itertools 0.12.1", + "itertools 0.11.0", "lazy_static", "lazycell", "proc-macro2", diff --git a/src/pairwise.rs b/src/pairwise.rs index 110076a6..f5d2d362 100644 --- a/src/pairwise.rs +++ b/src/pairwise.rs @@ -137,7 +137,7 @@ pub fn pairwise( eprintln!("Processed {} comparisons", i); } } - if write_all { + if write_all || output_all_comparisons { let mut query_containment_ani = None; let mut match_containment_ani = None; let mut average_containment_ani = None; diff --git a/src/python/tests/test_manysearch.py b/src/python/tests/test_manysearch.py index 6275b0cf..c39b9831 100644 --- a/src/python/tests/test_manysearch.py +++ b/src/python/tests/test_manysearch.py @@ -113,6 +113,108 @@ def test_simple(runtmp, zip_query, zip_against): assert max_ani == 0.9772 +def test_simple_output_all(runtmp, zip_query, zip_against): + # test basic execution! + query_list = runtmp.output("query.txt") + against_list = runtmp.output("against.txt") + + sig2 = get_test_data("2.fa.sig.gz") + sig47 = get_test_data("47.fa.sig.gz") + sig63 = get_test_data("63.fa.sig.gz") + + make_file_list(query_list, [sig2, sig47, sig63]) + make_file_list(against_list, [sig2, sig47, sig63]) + + output = runtmp.output("out.csv") + + if zip_query: + query_list = zip_siglist(runtmp, query_list, runtmp.output("query.zip")) + if zip_against: + against_list = zip_siglist(runtmp, against_list, runtmp.output("against.zip")) + + runtmp.sourmash( + "scripts", "manysearch", query_list, against_list, "-o", output, "-t", "0.01", "-A" + ) + assert os.path.exists(output) + + df = pandas.read_csv(output) + assert len(df) == 9 + + dd = df.to_dict(orient="index") + print(dd) + + for idx, row in dd.items(): + # identical? + if row["match_name"] == row["query_name"]: + assert row["query_md5"] == row["match_md5"], row + assert float(row["containment"] == 1.0) + assert float(row["jaccard"] == 1.0) + assert float(row["max_containment"] == 1.0) + assert float(row["query_containment_ani"] == 1.0) + assert float(row["match_containment_ani"] == 1.0) + assert float(row["average_containment_ani"] == 1.0) + assert float(row["max_containment_ani"] == 1.0) + + else: + # confirm hand-checked numbers + q = row["query_name"].split()[0] + m = row["match_name"].split()[0] + cont = float(row["containment"]) + jaccard = float(row["jaccard"]) + maxcont = float(row["max_containment"]) + intersect_hashes = int(row["intersect_hashes"]) + query_ani = float(row["query_containment_ani"]) + match_ani = float(row["match_containment_ani"]) + average_ani = float(row["average_containment_ani"]) + max_ani = float(row["max_containment_ani"]) + jaccard = round(jaccard, 4) + cont = round(cont, 4) + maxcont = round(maxcont, 4) + query_ani = round(query_ani, 4) + match_ani = round(match_ani, 4) + average_ani = round(average_ani, 4) + max_ani = round(max_ani, 4) + print( + q, + m, + f"{jaccard:.04}", + f"{cont:.04}", + f"{maxcont:.04}", + f"{query_ani:.04}", + f"{match_ani:.04}", + f"{average_ani:.04}", + f"{max_ani:.04}", + ) + + if q == "NC_011665.1" and m == "NC_009661.1": + assert jaccard == 0.3207 + assert cont == 0.4828 + assert maxcont == 0.4885 + assert intersect_hashes == 2529 + assert query_ani == 0.9768 + assert match_ani == 0.9772 + assert average_ani == 0.977 + assert max_ani == 0.9772 + elif q == "NC_009661.1" and m == "NC_011665.1": + assert jaccard == 0.3207 + assert cont == 0.4885 + assert maxcont == 0.4885 + assert intersect_hashes == 2529 + assert query_ani == 0.9772 + assert match_ani == 0.9768 + assert average_ani == 0.977 + assert max_ani == 0.9772 + else: + assert jaccard == 0 + assert cont == 0 + assert maxcont == 0 + assert intersect_hashes == 0 + assert query_ani == 0 + assert match_ani == 0 + assert average_ani == 0 + assert max_ani == 0 + + def test_simple_abund(runtmp): # test with abund sig sig2 = get_test_data("2.fa.sig.gz") diff --git a/src/python/tests/test_multisearch.py b/src/python/tests/test_multisearch.py index dfc65ee2..eb144f71 100644 --- a/src/python/tests/test_multisearch.py +++ b/src/python/tests/test_multisearch.py @@ -91,6 +91,78 @@ def test_simple_no_ani(runtmp, zip_query, zip_db): assert intersect_hashes == 2529 +def test_simple_no_ani_output_all(runtmp, zip_query, zip_db): + # test basic execution! + query_list = runtmp.output("query.txt") + against_list = runtmp.output("against.txt") + + sig2 = get_test_data("2.fa.sig.gz") + sig47 = get_test_data("47.fa.sig.gz") + sig63 = get_test_data("63.fa.sig.gz") + + make_file_list(query_list, [sig2, sig47, sig63]) + make_file_list(against_list, [sig2, sig47, sig63]) + + output = runtmp.output("out.csv") + + if zip_db: + against_list = zip_siglist(runtmp, against_list, runtmp.output("db.zip")) + if zip_query: + query_list = zip_siglist(runtmp, query_list, runtmp.output("query.zip")) + + runtmp.sourmash("scripts", "multisearch", query_list, against_list, "-o", output, "-A") + assert os.path.exists(output) + + df = pandas.read_csv(output) + assert len(df) == 9 + + dd = df.to_dict(orient="index") + print(dd) + + for idx, row in dd.items(): + assert not ("prob_overlap" in row) + + # identical? + if row["match_name"] == row["query_name"]: + assert row["query_md5"] == row["match_md5"], row + assert float(row["containment"] == 1.0) + assert float(row["jaccard"] == 1.0) + assert float(row["max_containment"] == 1.0) + assert "query_containment_ani" not in row + assert "match_containment_ani" not in row + assert "average_containment_ani" not in row + assert "max_containment_ani" not in row + + else: + # confirm hand-checked numbers + q = row["query_name"].split()[0] + m = row["match_name"].split()[0] + cont = float_round(row["containment"], 4) + jaccard = float_round(row["jaccard"], 4) + maxcont = float_round(row["max_containment"], 4) + + intersect_hashes = int(row["intersect_hashes"]) + + print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}") + + if q == "NC_011665.1" and m == "NC_009661.1": + assert jaccard == 0.3207 + assert cont == 0.4828 + assert maxcont == 0.4885 + assert intersect_hashes == 2529 + + elif q == "NC_009661.1" and m == "NC_011665.1": + assert jaccard == 0.3207 + assert cont == 0.4885 + assert maxcont == 0.4885 + assert intersect_hashes == 2529 + else: + assert jaccard == 0 + assert cont == 0 + assert maxcont == 0 + assert intersect_hashes == 0 + + def test_simple_prob_overlap(runtmp, zip_query, zip_db, indexed_query, indexed_against): # test basic execution! query_list = runtmp.output("query.txt") diff --git a/src/python/tests/test_pairwise.py b/src/python/tests/test_pairwise.py index dab5fd51..9a9bcd94 100644 --- a/src/python/tests/test_pairwise.py +++ b/src/python/tests/test_pairwise.py @@ -91,6 +91,31 @@ def test_simple_no_ani(runtmp, capfd, zip_query, indexed): ) +def test_simple_no_ani_output_all(runtmp, capfd, zip_query, indexed): + # test basic execution! + query_list = runtmp.output("query.txt") + + sig2 = get_test_data("2.fa.sig.gz") + sig47 = get_test_data("47.fa.sig.gz") + sig63 = get_test_data("63.fa.sig.gz") + + make_file_list(query_list, [sig2, sig47, sig63]) + + output = runtmp.output("out.csv") + + if zip_query: + query_list = zip_siglist(runtmp, query_list, runtmp.output("query.zip")) + + if indexed: + query_list = index_siglist(runtmp, query_list, runtmp.output("db")) + + runtmp.sourmash("scripts", "pairwise", query_list, "-o", output, "-t", "-1", "-A") + assert os.path.exists(output) + + df = pandas.read_csv(output) + assert len(df) == 6 + + def test_simple_ani(runtmp, zip_query): # test basic execution! query_list = runtmp.output("query.txt") @@ -158,6 +183,17 @@ def test_simple_ani(runtmp, zip_query): assert q2_ani == 0.9772 assert avg_ani == 0.977 assert max_ani == 0.9772 + elif q == m: + assert jaccard == 1 + else: + assert jaccard == 0 + assert cont == 0 + assert maxcont == 0 + assert intersect_hashes == 0 + assert q1_ani == 0 + assert q2_ani == 0 + assert avg_ani == 0 + assert max_ani == 0 def test_simple_threshold(runtmp, zip_query):