diff --git a/src/sourmash/commands.py b/src/sourmash/commands.py
index c5ea6cb4aa..34eefaf8b8 100644
--- a/src/sourmash/commands.py
+++ b/src/sourmash/commands.py
@@ -680,12 +680,11 @@ def gather(args):
     weighted_missed = 1
     is_abundance = query.minhash.track_abundance and not args.ignore_abundance
     orig_query_mh = query.minhash
-    new_max_hash = query.minhash._max_hash
     next_query = query
 
     gather_iter = gather_databases(query, counters, args.threshold_bp,
                                    args.ignore_abundance)
-    for result, weighted_missed, new_max_hash, next_query in gather_iter:
+    for result, weighted_missed, next_query in gather_iter:
         if not len(found):                # first result? print header.
             if is_abundance:
                 print_results("")
@@ -842,7 +841,7 @@ def multigather(args):
         found = []
         weighted_missed = 1
         is_abundance = query.minhash.track_abundance and not args.ignore_abundance
-        for result, weighted_missed, new_max_hash, next_query in gather_databases(query, counters, args.threshold_bp, args.ignore_abundance):
+        for result, weighted_missed, next_query in gather_databases(query, counters, args.threshold_bp, args.ignore_abundance):
             if not len(found):                # first result? print header.
                 if is_abundance:
                     print_results("")
@@ -919,7 +918,8 @@ def multigather(args):
         else:
             notify('saving unassigned hashes to "{}"', output_unassigned)
 
-            e = MinHash(ksize=query.minhash.ksize, n=0, max_hash=new_max_hash)
+            e = MinHash(ksize=query.minhash.ksize, n=0,
+                        scaled=next_query.minhash.scaled)
             e.add_many(next_query.minhash.hashes)
             # CTB: note, multigather does not save abundances
             sig.save_signatures([ sig.SourmashSignature(e) ], fp)
diff --git a/src/sourmash/search.py b/src/sourmash/search.py
index 8b1093719c..ad8e5a2a54 100644
--- a/src/sourmash/search.py
+++ b/src/sourmash/search.py
@@ -245,15 +245,6 @@ def search_databases_with_abund_query(query, databases, **kwargs):
                           'intersect_bp, f_orig_query, f_match, f_unique_to_query, f_unique_weighted, average_abund, median_abund, std_abund, filename, name, md5, match, f_match_orig, unique_intersect_bp, gather_result_rank, remaining_bp')
 
 
-# build a new query object, subtracting found mins and downsampling
-def _subtract_and_downsample(to_remove, old_query, scaled=None):
-    mh = old_query.minhash
-    mh = mh.downsample(scaled=scaled)
-    mh.remove_many(to_remove)
-
-    return SourmashSignature(mh)
-
-
 def _find_best(counters, query, threshold_bp):
     """
     Search for the best containment, return precisely one match.
@@ -279,16 +270,8 @@ def _find_best(counters, query, threshold_bp):
             counter.consume(best_intersect_mh)
 
         # and done!
-        return best_result
-    return None
-
-
-def _filter_max_hash(values, max_hash):
-    results = set()
-    for v in values:
-        if v < max_hash:
-            results.add(v)
-    return results
+        return best_result, best_intersect_mh
+    return None, None
 
 
 def gather_databases(query, counters, threshold_bp, ignore_abundance):
@@ -299,21 +282,21 @@ def gather_databases(query, counters, threshold_bp, ignore_abundance):
     # track original query information for later usage.
     track_abundance = query.minhash.track_abundance and not ignore_abundance
     orig_query_mh = query.minhash
-    orig_query_hashes = set(orig_query_mh.hashes)
 
     # do we pay attention to abundances?
-    orig_query_abunds = { k: 1 for k in orig_query_hashes }
+    orig_query_abunds = { k: 1 for k in orig_query_mh.hashes }
     if track_abundance:
         import numpy as np
         orig_query_abunds = orig_query_mh.hashes
 
+    orig_query_mh = orig_query_mh.flatten()
     query.minhash = query.minhash.flatten()
 
     cmp_scaled = query.minhash.scaled    # initialize with resolution of query
 
     result_n = 0
     while query.minhash:
         # find the best match!
-        best_result = _find_best(counters, query, threshold_bp)
+        best_result, intersect_mh = _find_best(counters, query, threshold_bp)
         if not best_result:          # no matches at all for this cutoff!
             notify(f'found less than {format_bp(threshold_bp)} in common. => exiting')
@@ -322,10 +305,6 @@ def gather_databases(query, counters, threshold_bp, ignore_abundance):
         best_match = best_result.signature
         filename = best_result.location
 
-        # subtract found hashes from search hashes, construct new search
-        query_hashes = set(query.minhash.hashes)
-        found_hashes = set(best_match.minhash.hashes)
-
         # Is the best match computed with scaled? Die if not.
         match_scaled = best_match.minhash.scaled
         assert match_scaled
@@ -336,39 +315,37 @@ def gather_databases(query, counters, threshold_bp, ignore_abundance):
         # eliminate hashes under this new resolution.
         # (CTB note: this means that if a high scaled/low res signature is
         # found early on, resolution will be low from then on.)
-        new_max_hash = _get_max_hash_for_scaled(cmp_scaled)
-        query_hashes = _filter_max_hash(query_hashes, new_max_hash)
-        found_hashes = _filter_max_hash(found_hashes, new_max_hash)
-        orig_query_hashes = _filter_max_hash(orig_query_hashes, new_max_hash)
-        sum_abunds = sum(( orig_query_abunds[k] for k in orig_query_hashes))
+        query_mh = query.minhash.downsample(scaled=cmp_scaled)
+        found_mh = best_match.minhash.downsample(scaled=cmp_scaled)
+        orig_query_mh = orig_query_mh.downsample(scaled=cmp_scaled)
+        sum_abunds = sum(( orig_query_abunds[k] for k in orig_query_mh.hashes ))
 
         # calculate intersection with query hashes:
-        intersect_hashes = query_hashes.intersection(found_hashes)
-        unique_intersect_bp = cmp_scaled * len(intersect_hashes)
-        intersect_orig_hashes = orig_query_hashes.intersection(found_hashes)
-        intersect_bp = cmp_scaled * len(intersect_orig_hashes)
+        unique_intersect_bp = cmp_scaled * len(intersect_mh)
+        intersect_orig_mh = orig_query_mh.intersection(found_mh)
+        intersect_bp = cmp_scaled * len(intersect_orig_mh)
 
         # calculate fractions wrt first denominator - genome size
-        assert intersect_hashes.issubset(found_hashes)
-        f_match = len(intersect_hashes) / len(found_hashes)
-        f_orig_query = len(intersect_orig_hashes) / len(orig_query_hashes)
+        assert intersect_mh.contained_by(found_mh) == 1.0
+        f_match = len(intersect_mh) / len(found_mh)
+        f_orig_query = len(intersect_orig_mh) / len(orig_query_mh)
 
         # calculate fractions wrt second denominator - metagenome size
-        assert intersect_hashes.issubset(orig_query_hashes)
-        f_unique_to_query = len(intersect_hashes) / len(orig_query_hashes)
+        assert intersect_mh.contained_by(orig_query_mh) == 1.0
+        f_unique_to_query = len(intersect_mh) / len(orig_query_mh)
 
         # calculate fraction of subject match with orig query
         f_match_orig = best_match.minhash.contained_by(orig_query_mh,
                                                        downsample=True)
 
         # calculate scores weighted by abundances
-        f_unique_weighted = sum((orig_query_abunds[k] for k in intersect_hashes))
+        f_unique_weighted = sum((orig_query_abunds[k] for k in intersect_mh.hashes ))
         f_unique_weighted /= sum_abunds
 
         # calculate stats on abundances, if desired.
         average_abund, median_abund, std_abund = None, None, None
         if track_abundance:
-            intersect_abunds = (orig_query_abunds[k] for k in intersect_hashes)
+            intersect_abunds = (orig_query_abunds[k] for k in intersect_mh.hashes )
             intersect_abunds = list(intersect_abunds)
 
             average_abund = np.mean(intersect_abunds)
@@ -376,11 +353,14 @@ def gather_databases(query, counters, threshold_bp, ignore_abundance):
             std_abund = np.std(intersect_abunds)
 
         # construct a new query, subtracting hashes found in previous one.
-        query = _subtract_and_downsample(found_hashes, query, cmp_scaled)
-        remaining_bp = cmp_scaled * len(query.minhash.hashes)
+        new_query_mh = query.minhash.downsample(scaled=cmp_scaled)
+        new_query_mh.remove_many(set(found_mh.hashes))
+        new_query = SourmashSignature(new_query_mh)
+
+        remaining_bp = cmp_scaled * len(new_query_mh)
 
         # compute weighted_missed:
-        query_hashes -= set(found_hashes)
+        query_hashes = set(query_mh.hashes) - set(found_mh.hashes)
         weighted_missed = sum((orig_query_abunds[k] for k in query_hashes)) \
              / sum_abunds
 
@@ -403,7 +383,9 @@ def gather_databases(query, counters, threshold_bp, ignore_abundance):
                               remaining_bp=remaining_bp)
         result_n += 1
 
-        yield result, weighted_missed, new_max_hash, query
+        yield result, weighted_missed, new_query
+
+        query = new_query
 
 
 ###
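
Note for reviewers: the headline change is that `gather_databases()` now yields 3-tuples `(result, weighted_missed, new_query)` instead of 4-tuples carrying `new_max_hash`; callers recover the resolution from `next_query.minhash.scaled` instead. A minimal sketch of consuming the new signature — the input filename is hypothetical and the `counters` setup is a placeholder, since building the per-database counters happens elsewhere in sourmash:

```python
from sourmash import load_one_signature
from sourmash.search import gather_databases

query = load_one_signature('metagenome.sig')  # hypothetical input file
counters = []   # placeholder: one counter object per loaded database

next_query = query
for result, weighted_missed, next_query in gather_databases(
        query, counters, threshold_bp=50000, ignore_abundance=False):
    print(result.name, result.f_unique_to_query)

# next_query carries the remaining (unassigned) hashes; its .minhash.scaled
# replaces the old new_max_hash bookkeeping, as in the multigather hunk above.
print(len(next_query.minhash), 'hashes remain unassigned')
```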
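The search.py refactor swaps raw set arithmetic on `.hashes` (plus `_filter_max_hash()`) for `MinHash` operations. A standalone sketch of the semantics the patch relies on, using toy integer hash values rather than real k-mer hashes:

```python
from sourmash import MinHash

a = MinHash(n=0, ksize=31, scaled=100)
b = MinHash(n=0, ksize=31, scaled=100)
a.add_many(range(0, 10**15, 10**12))       # toy hash values, all under max_hash
b.add_many(range(0, 10**15, 2 * 10**12))   # every other one; a subset of a

isect = a.intersection(b)              # a MinHash object, not a Python set
assert isect.contained_by(a) == 1.0    # replaces intersect.issubset(...)
f_match = len(isect) / len(b)          # fraction of b covered by a

# downsample() replaces _get_max_hash_for_scaled()/_filter_max_hash():
# it re-thresholds the sketch, dropping hashes above the coarser max_hash.
a_coarse = a.downsample(scaled=200)
```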