Skip to content

Commit

Permalink
[MRG] refactor & clean up database loading around MultiIndex class (#1406)
Browse files Browse the repository at this point in the history

* add an IndexOfIndexes class

* rename to MultiIndex

* switch to using MultiIndex for loading from a directory

* some more MultiIndex tests

* add test of MultiIndex.signatures

* add docstring for MultiIndex

* stop special-casing SIGLISTs

* fix test to match more informative error message

* switch to using LinearIndex.load for stdin, too

* add __len__ to MultiIndex

* add check_csv to check for appropriate filename loading info

* add comment

* fix databases load

* more tests needed

* add tests for incompatible signatures

* add filter to LinearIndex and MultiIndex

* clean up sourmash_args some more

* shift loading over to Index classes

* refactor, fix tests

* switch to a list of loader functions

* comments, docstrings, and tests passing

* update to use f strings throughout sourmash_args.py

* add docstrings

* update comments

* remove unnecessary changes

* revert to original test

* remove unneeded comment

* clean up a bit

* debugging update

* better exception raising and capture for signature parsing

* more specific error message

* revert change in favor of creating new issue

* add commentary => TODO

* add tests for MultiIndex.load_from_directory; fix traverse code

* switch lca summarize over to using MultiIndex

* switch to using MultiIndex in categorize

* remove LoadSingleSignatures

* test errors in lca database loading

* remove unneeded categorize code

* add testme info

* verified that this was tested

* remove testme comments

* add tests for MultiIndex.load_from_file_list

* Expand signature selection and compatibility checking in database loading code (#1420)

* refactor select, add scaled/num/abund
* fix scaled check for LCA database
* add debug_literal
* fix scaled check for SBT
* fix LCA database ksize message & test
* add 'containment' to 'select'
* added 'is_database' flag for nicer UX
* remove overly broad exception catching
* document downsampling foo

* fix file_list -> pathlist

* fix typo
  • Loading branch information
ctb authored Apr 2, 2021
1 parent ed3c809 commit 688fdfd
Show file tree
Hide file tree
Showing 12 changed files with 616 additions and 352 deletions.
32 changes: 16 additions & 16 deletions src/sourmash/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,8 @@ def search(args):

def categorize(args):
"Use a database to find the best match to many signatures."
from .index import MultiIndex

set_quiet(args.quiet)
moltype = sourmash_args.calculate_moltype(args)

Expand All @@ -533,24 +535,27 @@ def categorize(args):
# load search database
tree = load_sbt_index(args.sbt_name)

# load query filenames
inp_files = set(sourmash_args.traverse_find_sigs(args.queries))
inp_files = inp_files - already_names

notify('found {} files to query', len(inp_files))

loader = sourmash_args.LoadSingleSignatures(inp_files,
args.ksize, moltype)
# utility function to load & select relevant signatures.
def _yield_all_sigs(queries, ksize, moltype):
for filename in queries:
mi = MultiIndex.load_from_path(filename, False)
mi = mi.select(ksize=ksize, moltype=moltype)
for ss, loc in mi.signatures_with_location():
yield ss, loc

csv_w = None
csv_fp = None
if args.csv:
csv_fp = open(args.csv, 'w', newline='')
csv_w = csv.writer(csv_fp)

for queryfile, query, query_moltype, query_ksize in loader:
for query, loc in _yield_all_sigs(args.queries, args.ksize, moltype):
# skip if we've already done signatures from this file.
if loc in already_names:
continue

notify('loaded query: {}... (k={}, {})', str(query)[:30],
query_ksize, query_moltype)
query.minhash.ksize, query.minhash.moltype)

results = []
search_fn = SearchMinHashesFindBest().search
Expand All @@ -575,14 +580,9 @@ def categorize(args):
notify('for {}, no match found', query)

if csv_w:
csv_w.writerow([queryfile, query, best_hit_query_name,
csv_w.writerow([loc, query, best_hit_query_name,
best_hit_sim])

if loader.skipped_ignore:
notify('skipped/ignore: {}', loader.skipped_ignore)
if loader.skipped_nosig:
notify('skipped/nosig: {}', loader.skipped_nosig)

if csv_fp:
csv_fp.close()

Expand Down
127 changes: 115 additions & 12 deletions src/sourmash/index.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"An Abstract Base Class for collections of signatures."

import sourmash
from abc import abstractmethod, ABC
from collections import namedtuple
import os


class Index(ABC):
is_database = False

@abstractmethod
def signatures(self):
"Return an iterator over all signatures in the Index object."
Expand Down Expand Up @@ -122,8 +126,55 @@ def gather(self, query, *args, **kwargs):
return results

@abstractmethod
def select(self, ksize=None, moltype=None):
""
def select(self, ksize=None, moltype=None, scaled=None, num=None,
abund=None, containment=None):
"""Return Index containing only signatures that match requirements.
Current arguments can be any or all of:
* ksize
* moltype
* scaled
* num
* containment
'select' will raise ValueError if the requirements are incompatible
with the Index subclass.
'select' may return an empty object or None if no matches can be
found.
"""


def select_signature(ss, ksize=None, moltype=None, scaled=0, num=0,
                     containment=False):
    """Check that the given signature matches the specified requirements.

    Returns True if 'ss' satisfies every requested constraint, False if it
    does not. Raises ValueError if 'containment' is requested without
    'scaled' (containment search requires scaled MinHashes).
    """
    # ksize match?
    if ksize and ksize != ss.minhash.ksize:
        return False

    # moltype match?
    if moltype and moltype != ss.minhash.moltype:
        return False

    # containment requires scaled; similarity does not.
    if containment:
        if not scaled:
            raise ValueError("'containment' requires 'scaled' in Index.select")
        if not ss.minhash.scaled:
            return False

    # 'scaled' and 'num' are incompatible: a signature cannot satisfy both.
    if scaled:
        if ss.minhash.num:
            return False
    if num:
        # note, here we check if 'num' is identical; this can be
        # changed later.
        if ss.minhash.scaled or num != ss.minhash.num:
            return False

    return True


class LinearIndex(Index):
"An Index for a collection of signatures. Can load from a .sig file."
Expand Down Expand Up @@ -155,18 +206,17 @@ def load(cls, location):
lidx = LinearIndex(si, filename=location)
return lidx

def select(self, ksize=None, moltype=None):
def select_sigs(ss, ksize=ksize, moltype=moltype):
if (ksize is None or ss.minhash.ksize == ksize) and \
(moltype is None or ss.minhash.moltype == moltype):
return True
def select(self, **kwargs):
"""Return new LinearIndex containing only signatures that match req's.
return self.filter(select_sigs)
Does not raise ValueError, but may return an empty Index.
"""
# eliminate things from kwargs with None or zero value
kw = { k : v for (k, v) in kwargs.items() if v }

def filter(self, filter_fn):
siglist = []
for ss in self._signatures:
if filter_fn(ss):
if select_signature(ss, **kwargs):
siglist.append(ss)

return LinearIndex(siglist, self.filename)
Expand All @@ -193,6 +243,11 @@ def signatures(self):
for ss in idx.signatures():
yield ss

def signatures_with_location(self):
    "Yield (signature, source_location) pairs across all wrapped indices."
    for subindex, location in zip(self.index_list, self.source_list):
        for sig in subindex.signatures():
            yield sig, location

def __len__(self):
    "Total number of signatures across all wrapped indices."
    total = 0
    for subindex in self.index_list:
        total += len(subindex)
    return total

Expand All @@ -203,14 +258,62 @@ def insert(self, *args):
def load(self, *args):
    # generic Index.load is not supported for MultiIndex; use the
    # load_from_path / load_from_pathlist classmethods instead.
    raise NotImplementedError

@classmethod
def load_from_path(cls, pathname, force=False):
    """Create a MultiIndex from a path (filename or directory).

    Each signature file found under 'pathname' is loaded into its own
    LinearIndex, and the MultiIndex records the source filename for each.

    Raises ValueError if 'pathname' does not exist or yields no loadable
    signatures. With force=True, files that fail to load are skipped
    instead of aborting the whole load.
    """
    from .sourmash_args import traverse_find_sigs
    if not os.path.exists(pathname):
        # NOTE(review): 'pathname' may also name a single file, so this
        # message is misleading for nonexistent files — consider rewording.
        raise ValueError(f"'{pathname}' must be a directory")

    index_list = []
    source_list = []
    # with force=True, presumably all files (not just *.sig) are tried —
    # TODO confirm traverse_find_sigs(yield_all_files=...) semantics.
    for thisfile in traverse_find_sigs([pathname], yield_all_files=force):
        try:
            idx = LinearIndex.load(thisfile)
            index_list.append(idx)
            source_list.append(thisfile)
        except (IOError, sourmash.exceptions.SourmashError):
            if force:
                continue  # ignore the unloadable file and keep going
            else:
                raise  # without force, propagate the load error

    db = None
    if index_list:
        db = cls(index_list, source_list)
    else:
        raise ValueError(f"no signatures to load under directory '{pathname}'")

    return db

@classmethod
def load_from_pathlist(cls, filename):
    "Create a MultiIndex from all files listed in a text file."
    from .sourmash_args import (load_pathlist_from_file,
                                load_file_as_index)

    index_list = []
    source_list = []
    for path in load_pathlist_from_file(filename):
        index_list.append(load_file_as_index(path))
        source_list.append(path)

    return MultiIndex(index_list, source_list)

def save(self, *args):
    # MultiIndex is a read-only view over other indices; saving is
    # intentionally unsupported.
    raise NotImplementedError

def select(self, ksize=None, moltype=None):
def select(self, **kwargs):
"Run 'select' on all indices within this MultiIndex."
new_idx_list = []
new_src_list = []
for idx, src in zip(self.index_list, self.source_list):
idx = idx.select(ksize=ksize, moltype=moltype)
idx = idx.select(**kwargs)
new_idx_list.append(idx)
new_src_list.append(src)

Expand Down
20 changes: 8 additions & 12 deletions src/sourmash/lca/command_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ..logging import notify, error, print_results, set_quiet, debug
from . import lca_utils
from .lca_utils import check_files_exist
from sourmash.index import MultiIndex


DEFAULT_THRESHOLD=5
Expand Down Expand Up @@ -61,20 +62,15 @@ def load_singletons_and_count(filenames, ksize, scaled, ignore_abundance):
total_count = 0
n = 0

# in order to get the right reporting out of this function, we need
# to do our own traversal to expand the list of filenames, as opposed
# to using load_file_as_signatures(...)
filenames = sourmash_args.traverse_find_sigs(filenames)
filenames = list(filenames)

total_n = len(filenames)

for query_filename in filenames:
for filename in filenames:
n += 1
for query_sig in sourmash_args.load_file_as_signatures(query_filename,
ksize=ksize):
mi = MultiIndex.load_from_path(filename)
mi = mi.select(ksize=ksize)

for query_sig, query_filename in mi.signatures_with_location():
notify(u'\r\033[K', end=u'')
notify('... loading {} (file {} of {})', query_sig, n,
notify(f'... loading {query_sig} (file {n} of {total_n})',
total_n, end='\r')
total_count += 1

Expand All @@ -87,7 +83,7 @@ def load_singletons_and_count(filenames, ksize, scaled, ignore_abundance):
yield query_filename, query_sig, hashvals

notify(u'\r\033[K', end=u'')
notify('loaded {} signatures from {} files total.', total_count, n)
notify(f'loaded {total_count} signatures from {n} files total.')


def count_signature(sig, scaled, hashvals):
Expand Down
42 changes: 31 additions & 11 deletions src/sourmash/lca/lca_db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"LCA database class and utilities."

import os
import json
import gzip
from collections import OrderedDict, defaultdict, Counter
Expand Down Expand Up @@ -55,6 +55,8 @@ class LCA_Database(Index):
`hashval_to_idx` is a dictionary from individual hash values to sets of
`idx`.
"""
is_database = True

def __init__(self, ksize, scaled, moltype='DNA'):
self.ksize = int(ksize)
self.scaled = int(scaled)
Expand Down Expand Up @@ -169,24 +171,38 @@ def signatures(self):
for v in self._signatures.values():
yield v

def select(self, ksize=None, moltype=None):
"Selector interface - make sure this database matches requirements."
ok = True
def select(self, ksize=None, moltype=None, num=0, scaled=0,
containment=False):
"""Make sure this database matches the requested requirements.
As with SBTs, queries with higher scaled values than the database
can still be used for containment search, but not for similarity
search. See SBT.select(...) for details, and _find_signatures for
implementation.
Will always raise ValueError if a requirement cannot be met.
"""
if num:
raise ValueError("cannot use 'num' MinHashes to search LCA database")

if scaled > self.scaled and not containment:
raise ValueError(f"cannot use scaled={scaled} on this database (scaled={self.scaled})")

if ksize is not None and self.ksize != ksize:
ok = False
raise ValueError(f"ksize on this database is {self.ksize}; this is different from requested ksize of {ksize}")
if moltype is not None and moltype != self.moltype:
ok = False

if ok:
return self
raise ValueError(f"moltype on this database is {self.moltype}; this is different from requested moltype of {moltype}")

raise ValueError("cannot select LCA on ksize {} / moltype {}".format(ksize, moltype))
return self

@classmethod
def load(cls, db_name):
"Load LCA_Database from a JSON file."
from .lca_utils import taxlist, LineagePair

if not os.path.isfile(db_name):
raise ValueError(f"'{db_name}' is not a file and cannot be loaded as an LCA database")

xopen = open
if db_name.endswith('.gz'):
xopen = gzip.open
Expand Down Expand Up @@ -464,12 +480,16 @@ def _find_signatures(self, minhash, threshold, containment=False,
This is essentially a fast implementation of find that collects all
the signatures with overlapping hash values. Note that similarity
searches (containment=False) will not be returned in sorted order.
As with SBTs, queries with higher scaled values than the database
can still be used for containment search, but not for similarity
search. See SBT.select(...) for details.
"""
# make sure we're looking at the same scaled value as database
if self.scaled > minhash.scaled:
minhash = minhash.downsample(scaled=self.scaled)
elif self.scaled < minhash.scaled and not ignore_scaled:
# note that containment cannot be calculated w/o matching scaled.
# note that similarity cannot be calculated w/o matching scaled.
raise ValueError("lca db scaled is {} vs query {}; must downsample".format(self.scaled, minhash.scaled))

query_mins = set(minhash.hashes)
Expand Down
11 changes: 11 additions & 0 deletions src/sourmash/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ def debug(s, *args, **kwargs):
sys.stderr.flush()


def debug_literal(s, *args, **kwargs):
    """Write *s* verbatim to stderr when debug output is enabled.

    Unlike debug(), no format-argument interpolation is applied; the
    message is printed exactly as given. Honors 'end' and 'flush' kwargs.
    """
    if _debug and not _quiet:
        # erase the current line first (carriage return + ANSI clear-line)
        print(u'\r\033[K', end=u'', file=sys.stderr)
        print(s, file=sys.stderr, end=kwargs.get('end', u'\n'))
        if kwargs.get('flush'):
            sys.stderr.flush()


def error(s, *args, **kwargs):
"A simple error logging function => stderr."
print(u'\r\033[K', end=u'', file=sys.stderr)
Expand Down
Loading

0 comments on commit 688fdfd

Please sign in to comment.