[MRG] re-implement the actual gather protocol with a cleaner interface. #1489

Merged May 2, 2021 · 19 commits

Changes from all commits
23 changes: 15 additions & 8 deletions src/sourmash/commands.py
@@ -655,17 +655,14 @@ def gather(args):
     # @CTB experimental! w00t fun!
     if args.prefetch or 1:
         notify(f"Using EXPERIMENTAL feature: prefetch enabled!")
-        from .index import LinearIndex, CounterGatherIndex
-        prefetch_idx = CounterGatherIndex(query)

         prefetch_query = copy.copy(query)
         prefetch_query.minhash = prefetch_query.minhash.flatten()

+        counters = []
         for db in databases:
-            for match in db.prefetch(prefetch_query, args.threshold_bp):
-                prefetch_idx.insert(match.signature, location=match.location)
-
-        databases = [ prefetch_idx ]
+            counter = db.counter_gather(prefetch_query, args.threshold_bp)
+            counters.append(counter)

     found = []
     weighted_missed = 1
@@ -674,7 +671,7 @@
     new_max_hash = query.minhash._max_hash
     next_query = query

-    gather_iter = gather_databases(query, databases, args.threshold_bp,
+    gather_iter = gather_databases(query, counters, args.threshold_bp,
                                    args.ignore_abundance)
     for result, weighted_missed, new_max_hash, next_query in gather_iter:
         if not len(found): # first result? print header.
@@ -821,10 +818,20 @@ def multigather(args):
             error('no query hashes!? skipping to next..')
             continue

+        notify(f"Using EXPERIMENTAL feature: prefetch enabled!")
+        counters = []
+        prefetch_query = copy.copy(query)
+        prefetch_query.minhash = prefetch_query.minhash.flatten()
+
+        counters = []
+        for db in databases:
+            counter = db.counter_gather(prefetch_query, args.threshold_bp)
+            counters.append(counter)
+
         found = []
         weighted_missed = 1
         is_abundance = query.minhash.track_abundance and not args.ignore_abundance
-        for result, weighted_missed, new_max_hash, next_query in gather_databases(query, databases, args.threshold_bp, args.ignore_abundance):
+        for result, weighted_missed, new_max_hash, next_query in gather_databases(query, counters, args.threshold_bp, args.ignore_abundance):
             if not len(found): # first result? print header.
                 if is_abundance:
                     print_results("")
224 changes: 131 additions & 93 deletions src/sourmash/index.py
@@ -5,6 +5,7 @@
 from abc import abstractmethod, ABC
 from collections import namedtuple, Counter
 import zipfile
+import copy

 from .search import make_jaccard_search_query, make_gather_query

@@ -214,6 +215,27 @@ def gather(self, query, threshold_bp=None, **kwargs):

         return results[:1]

+    def counter_gather(self, query, threshold_bp, **kwargs):
+        """Returns an object that permits 'gather' on top of the
+        current contents of this Index.
+
+        The default implementation uses `prefetch` underneath, and returns
+        the results in a `CounterGather` object. However, alternate
+        implementations need only return an object that meets the
+        public `CounterGather` interface, of course.
+        """
+        # build a flat query
+        prefetch_query = copy.copy(query)
+        prefetch_query.minhash = prefetch_query.minhash.flatten()
+
+        # find all matches and construct a CounterGather object.
+        counter = CounterGather(prefetch_query.minhash)
+        for result in self.prefetch(prefetch_query, threshold_bp, **kwargs):
+            counter.add(result.signature, result.location)
+
+        # tada!
+        return counter
+
     @abstractmethod
     def select(self, ksize=None, moltype=None, scaled=None, num=None,
                abund=None, containment=None):
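A hedged usage sketch for the default implementation above, assuming an in-memory LinearIndex and two already-loaded, scaled SourmashSignature objects (query_ss and match_ss are placeholder names):

from sourmash.index import LinearIndex

# query_ss and match_ss are assumed to be scaled signatures
idx = LinearIndex()
idx.insert(match_ss)

# builds a CounterGather from all prefetch() matches; a subclass may
# instead return any object exposing the same peek/consume interface
counter = idx.counter_gather(query_ss, threshold_bp=0)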
@@ -431,135 +453,151 @@ def select(self, **kwargs):
                            traverse_yield_all=self.traverse_yield_all)


-class CounterGatherIndex(Index):
-    def __init__(self, query):
-        self.query = query
-        self.scaled = query.minhash.scaled
+class CounterGather:
+    """
+    Track and summarize matches for efficient 'gather' protocol. This
+    could be used downstream of prefetch (for example).
+
+    The public interface is `peek(...)` and `consume(...)` only.
+    """
+    def __init__(self, query_mh):
+        if not query_mh.scaled:
+            raise ValueError('gather requires scaled signatures')
+
+        # track query
+        self.orig_query_mh = copy.copy(query_mh).flatten()
+        self.scaled = query_mh.scaled
+
+        # track matching signatures & their locations
         self.siglist = []
         self.locations = []
+
+        # ...and overlaps with query
         self.counter = Counter()
+
+        # cannot add matches once query has started.
+        self.query_started = 0

-    def insert(self, ss, location=None):
-        i = len(self.siglist)
-        self.siglist.append(ss)
-        self.locations.append(location)
+    def add(self, ss, location=None, require_overlap=True):
+        "Add this signature in as a potential match."
+        if self.query_started:
+            raise ValueError("cannot add more signatures to counter after peek/consume")

         # upon insertion, count & track overlap with the specific query.
-        self.scaled = max(self.scaled, ss.minhash.scaled)
-        self.counter[i] = self.query.minhash.count_common(ss.minhash, True)
+        overlap = self.orig_query_mh.count_common(ss.minhash, True)
+        if overlap:
+            i = len(self.siglist)
+
+            self.counter[i] = overlap
+            self.siglist.append(ss)
+            self.locations.append(location)
+
+            # note: scaled will be max of all matches.
+            self.downsample(ss.minhash.scaled)
+        elif require_overlap:
+            raise ValueError("no overlap between query and signature!?")
+
+    def downsample(self, scaled):
+        "Track highest scaled across all possible matches."
+        if scaled > self.scaled:
+            self.scaled = scaled
+
+    def calc_threshold(self, threshold_bp, scaled, query_size):
+        # CTB: this code doesn't need to be in this class.
+        threshold = 0.0
+        n_threshold_hashes = 0
+
+        if threshold_bp:
+            # if we have a threshold_bp of N, then that amounts to N/scaled
+            # hashes:
+            n_threshold_hashes = float(threshold_bp) / scaled
+
+            # that then requires the following containment:
+            threshold = n_threshold_hashes / query_size
+
+        return threshold, n_threshold_hashes

-    def gather(self, query, threshold_bp=0, **kwargs):
-        "Perform compositional analysis of the query using the gather algorithm"
-        # CTB: switch over to JaccardSearch objects?
-
-        if not query.minhash: # empty query? quit.
-            return []
-
-        # bad query?
-        scaled = query.minhash.scaled
-        if not scaled:
-            raise ValueError('gather requires scaled signatures')
-
-        if scaled == self.scaled:
-            query_mh = query.minhash
-        elif scaled < self.scaled:
-            query_mh = query.minhash.downsample(scaled=self.scaled)
-            scaled = self.scaled
-        else: # query scaled > self.scaled, should never happen
-            assert 0
+    def peek(self, cur_query_mh, threshold_bp=0):
+        "Get next 'gather' result for this database, w/o changing counters."
+        self.query_started = 1
+        scaled = cur_query_mh.scaled

         # empty? nothing to search.
         counter = self.counter
-        siglist = self.siglist
-        if not (counter and siglist):
+        if not counter:
             return []

-        threshold = 0.0
-        n_threshold_hashes = 0
-
-        # are we setting a threshold?
-        if threshold_bp:
-            # if we have a threshold_bp of N, then that amounts to N/scaled
-            # hashes:
-            n_threshold_hashes = float(threshold_bp) / scaled
-
-            # that then requires the following containment:
-            threshold = n_threshold_hashes / len(query_mh)
-
-            # is it too high to ever match? if so, exit.
-            if threshold > 1.0:
-                return []
-
-        # Decompose query into matching signatures using a greedy approach
-        # (gather)
-        match_size = n_threshold_hashes
+        siglist = self.siglist
+        assert siglist
+
+        self.downsample(scaled)
+        scaled = self.scaled
+        cur_query_mh = cur_query_mh.downsample(scaled=scaled)
+
+        if not cur_query_mh: # empty query? quit.
+            return []
+
+        if cur_query_mh.contained_by(self.orig_query_mh, downsample=True) < 1:
+            raise ValueError("current query not a subset of original query")
+
+        # are we setting a threshold?
+        threshold, n_threshold_hashes = self.calc_threshold(threshold_bp,
+                                                            scaled,
+                                                            len(cur_query_mh))
+        # is it too high to ever match? if so, exit.
+        if threshold > 1.0:
+            return []

         # Find the best match -
         most_common = counter.most_common()
-        dataset_id, size = most_common.pop(0)
+        dataset_id, match_size = most_common[0]

-        # fail threshold!
-        if size < n_threshold_hashes:
+        # below threshold? no match!
+        if match_size < n_threshold_hashes:
             return []

-        match_size = size
+        ## at this point, we must have a legitimate match above threshold!

         # pull match and location.
         match = siglist[dataset_id]
-        location = self.locations[dataset_id]
-
-        # remove from counter for next round of gather
-        del counter[dataset_id]

-        # pull containment
-        cont = query_mh.contained_by(match.minhash, downsample=True)
-        result = None
-        if cont and cont >= threshold:
-            result = IndexSearchResult(cont, match, location)
+        # calculate containment
+        cont = cur_query_mh.contained_by(match.minhash, downsample=True)
+        assert cont
+        assert cont >= threshold

-        # calculate intersection of this "best match" with query, for removal.
-        # @CTB note flatten
+        # calculate intersection of this "best match" with query.
         match_mh = match.minhash.downsample(scaled=scaled).flatten()
-        intersect_mh = query_mh.intersection(match_mh)
-
-        # Prepare counter for finding the next match by decrementing
-        # all hashes found in the current match in other datasets;
-        # remove empty datasets from counter, too.
-        for (dataset_id, _) in most_common:
-            remaining_sig = siglist[dataset_id]
-            intersect_count = remaining_sig.minhash.count_common(intersect_mh,
-                                                                 downsample=True)
-            counter[dataset_id] -= intersect_count
-            if counter[dataset_id] == 0:
-                del counter[dataset_id]
-
-        if result:
-            return [result]
-        return []
+        intersect_mh = cur_query_mh.intersection(match_mh)
+        location = self.locations[dataset_id]

-    def signatures(self):
-        raise NotImplementedError
+        # build result & return intersection
+        return (IndexSearchResult(cont, match, location), intersect_mh)

-    def signatures_with_location(self):
-        raise NotImplementedError
+    def consume(self, intersect_mh):
+        "Remove the given hashes from this counter."
+        self.query_started = 1

-    def prefetch(self, *args, **kwargs):
-        raise NotImplementedError
+        if not intersect_mh:
+            return

-    @classmethod
-    def load(self, *args):
-        raise NotImplementedError
-
-    def save(self, *args):
-        raise NotImplementedError
-
-    def find(self, search_fn, *args, **kwargs):
-        raise NotImplementedError
+        siglist = self.siglist
+        counter = self.counter

-    def search(self, query, *args, **kwargs):
-        raise NotImplementedError
+        most_common = counter.most_common()

-    def select(self, *args, **kwargs):
-        raise NotImplementedError
+        # Prepare counter for finding the next match by decrementing
+        # all hashes found in the current match in other datasets;
+        # remove empty datasets from counter, too.
+        for (dataset_id, _) in most_common:
+            # CTB: note, remaining_mh may not be at correct scaled here.
+            remaining_mh = siglist[dataset_id].minhash
+            intersect_count = intersect_mh.count_common(remaining_mh,
+                                                        downsample=True)
+            if intersect_count:
+                counter[dataset_id] -= intersect_count
+                if counter[dataset_id] == 0:
+                    del counter[dataset_id]


 class MultiIndex(Index):
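For reference, calc_threshold() above converts a base-pair threshold into a minimum containment fraction. A worked example with assumed numbers, following the formulas in the diff:

# illustrative numbers only
threshold_bp = 50000
scaled = 1000
query_size = 500 # len(cur_query_mh)

n_threshold_hashes = float(threshold_bp) / scaled # 50.0 hashes
threshold = n_threshold_hashes / query_size       # 0.1 containment

So with these numbers, peek() reports a match only if the best remaining candidate shares at least 50 hashes with the current query, i.e. covers at least 10% of it.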
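Finally, the peek/consume split leaves each CounterGather passive: peek() proposes the best remaining match plus its hash intersection without mutating state, and consume() commits the removal. A hedged sketch of the greedy loop a caller such as gather_databases() might run over several counters; it assumes counters built as above, plus a mutable query MinHash supporting remove_many():

# query_mh: current (flattened) query MinHash; counters: list of CounterGather
while query_mh:
    best = None
    for counter in counters:
        res = counter.peek(query_mh, threshold_bp)
        # res is either [] or a (IndexSearchResult, intersect_mh) pair
        if res and (best is None or res[0].score > best[0].score):
            best = res

    if best is None: # no counter has a match above threshold
        break

    result, intersect_mh = best

    # subtract the matched hashes from every counter, then from the query
    for counter in counters:
        counter.consume(intersect_mh)
    query_mh.remove_many(intersect_mh.hashes)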