From 688fdfd84e3aba2339b27d9095e38ee7da8320f9 Mon Sep 17 00:00:00 2001
From: "C. Titus Brown"
Date: Fri, 2 Apr 2021 07:08:57 -0700
Subject: [PATCH] [MRG] refactor & clean up database loading around MultiIndex class (#1406)

* add an IndexOfIndexes class
* rename to MultiIndex
* switch to using MultiIndex for loading from a directory
* some more MultiIndex tests
* add test of MultiIndex.signatures
* add docstring for MultiIndex
* stop special-casing SIGLISTs
* fix test to match more informative error message
* switch to using LinearIndex.load for stdin, too
* add __len__ to MultiIndex
* add check_csv to check for appropriate filename loading info
* add comment
* fix databases load
* more tests needed
* add tests for incompatible signatures
* add filter to LinearIndex and MultiIndex
* clean up sourmash_args some more
* shift loading over to Index classes
* refactor, fix tests
* switch to a list of loader functions
* comments, docstrings, and tests passing
* update to use f strings throughout sourmash_args.py
* add docstrings
* update comments
* remove unnecessary changes
* revert to original test
* remove unneeded comment
* clean up a bit
* debugging update
* better exception raising and capture for signature parsing
* more specific error message
* revert change in favor of creating new issue
* add commentary => TODO
* add tests for MultiIndex.load_from_directory; fix traverse code
* switch lca summarize over to using MultiIndex
* switch to using MultiIndex in categorize
* remove LoadSingleSignatures
* test errors in lca database loading
* remove unneeded categorize code
* add testme info
* verified that this was tested
* remove testme comments
* add tests for MultiIndex.load_from_file_list
* Expand signature selection and compatibility checking in database loading code (#1420)
* refactor select, add scaled/num/abund
* fix scaled check for LCA database
* add debug_literal
* fix scaled check for SBT
* fix LCA database ksize message & test
* add 'containment' to 'select'
* added 'is_database' flag for nicer UX
* remove overly broad exception catching
* document downsampling foo
* fix file_list -> pathlist
* fix typo
---
 src/sourmash/commands.py              |  32 +-
 src/sourmash/index.py                 | 127 +++++++-
 src/sourmash/lca/command_summarize.py |  20 +-
 src/sourmash/lca/lca_db.py            |  42 ++-
 src/sourmash/logging.py               |  11 +
 src/sourmash/sbt.py                   |  62 +++-
 src/sourmash/signature.py             |   2 +-
 src/sourmash/sourmash_args.py         | 413 ++++++++++----------------
 tests/test_api.py                     |   8 +-
 tests/test_index.py                   | 174 +++++++++++
 tests/test_lca.py                     |  16 +-
 tests/test_sourmash.py                |  61 ++--
 12 files changed, 616 insertions(+), 352 deletions(-)

diff --git a/src/sourmash/commands.py b/src/sourmash/commands.py
index 69b1432fc1..634324cac5 100644
--- a/src/sourmash/commands.py
+++ b/src/sourmash/commands.py
@@ -519,6 +519,8 @@ def search(args):
 
 def categorize(args):
     "Use a database to find the best match to many signatures."
+    from .index import MultiIndex
+
     set_quiet(args.quiet)
     moltype = sourmash_args.calculate_moltype(args)
 
@@ -533,14 +535,13 @@ def categorize(args):
     # load search database
     tree = load_sbt_index(args.sbt_name)
 
-    # load query filenames
-    inp_files = set(sourmash_args.traverse_find_sigs(args.queries))
-    inp_files = inp_files - already_names
-
-    notify('found {} files to query', len(inp_files))
-
-    loader = sourmash_args.LoadSingleSignatures(inp_files,
-                                                args.ksize, moltype)
+    # utility function to load & select relevant signatures.
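+    # (each query in 'queries' may be a .sig file or a directory of
+    # .sig files; MultiIndex.load_from_path handles both cases.)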
+    def _yield_all_sigs(queries, ksize, moltype):
+        for filename in queries:
+            mi = MultiIndex.load_from_path(filename, False)
+            mi = mi.select(ksize=ksize, moltype=moltype)
+            for ss, loc in mi.signatures_with_location():
+                yield ss, loc
 
     csv_w = None
     csv_fp = None
@@ -548,9 +549,13 @@ def categorize(args):
         csv_fp = open(args.csv, 'w', newline='')
         csv_w = csv.writer(csv_fp)
 
-    for queryfile, query, query_moltype, query_ksize in loader:
+    for query, loc in _yield_all_sigs(args.queries, args.ksize, moltype):
+        # skip if we've already done signatures from this file.
+        if loc in already_names:
+            continue
+
         notify('loaded query: {}... (k={}, {})', str(query)[:30],
-               query_ksize, query_moltype)
+               query.minhash.ksize, query.minhash.moltype)
 
         results = []
         search_fn = SearchMinHashesFindBest().search
@@ -575,14 +580,9 @@ def categorize(args):
             notify('for {}, no match found', query)
 
         if csv_w:
-            csv_w.writerow([queryfile, query, best_hit_query_name,
+            csv_w.writerow([loc, query, best_hit_query_name,
                             best_hit_sim])
 
-    if loader.skipped_ignore:
-        notify('skipped/ignore: {}', loader.skipped_ignore)
-    if loader.skipped_nosig:
-        notify('skipped/nosig: {}', loader.skipped_nosig)
-
     if csv_fp:
         csv_fp.close()
 
diff --git a/src/sourmash/index.py b/src/sourmash/index.py
index 07e9c21b6a..30f93723e8 100644
--- a/src/sourmash/index.py
+++ b/src/sourmash/index.py
@@ -1,10 +1,14 @@
 "An Abstract Base Class for collections of signatures."
 
+import sourmash
 from abc import abstractmethod, ABC
 from collections import namedtuple
+import os
 
 
 class Index(ABC):
+    is_database = False
+
     @abstractmethod
     def signatures(self):
         "Return an iterator over all signatures in the Index object."
@@ -122,8 +126,55 @@ def gather(self, query, *args, **kwargs):
         return results
 
     @abstractmethod
-    def select(self, ksize=None, moltype=None):
-        ""
+    def select(self, ksize=None, moltype=None, scaled=None, num=None,
+               abund=None, containment=None):
+        """Return Index containing only signatures that match requirements.
+
+        Current arguments can be any or all of:
+        * ksize
+        * moltype
+        * scaled
+        * num
+        * containment
+
+        'select' will raise ValueError if the requirements are incompatible
+        with the Index subclass.
+
+        'select' may return an empty object or None if no matches can be
+        found.
+        """
+
+
+def select_signature(ss, ksize=None, moltype=None, scaled=0, num=0,
+                     containment=False):
+    "Check that the given signature matches the specified requirements."
+    # ksize match?
+    if ksize and ksize != ss.minhash.ksize:
+        return False
+
+    # moltype match?
+    if moltype and moltype != ss.minhash.moltype:
+        return False
+
+    # containment requires scaled; similarity does not.
+    if containment:
+        if not scaled:
+            raise ValueError("'containment' requires 'scaled' in Index.select")
+        if not ss.minhash.scaled:
+            return False
+
+    # 'scaled' and 'num' are incompatible
+    if scaled:
+        if ss.minhash.num:
+            return False
+    if num:
+        # note, here we check if 'num' is identical; this can be
+        # changed later.
+        if ss.minhash.scaled or num != ss.minhash.num:
+            return False
+
+    return True
+
 
 class LinearIndex(Index):
     "An Index for a collection of signatures. Can load from a .sig file."
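
(Aside: the new 'select_signature' helper above is a pure function, so its
semantics are easy to exercise in isolation. A minimal sketch, assuming 'ss'
is a scaled, DNA, k=31 signature and the path is hypothetical:

    from sourmash import load_one_signature
    from sourmash.index import select_signature

    ss = load_one_signature('some.sig')            # hypothetical file

    select_signature(ss, ksize=31, moltype='DNA')  # True: both match
    select_signature(ss, ksize=21)                 # False: ksize mismatch
    select_signature(ss, scaled=1000)              # True: scaled-style sketch
    select_signature(ss, num=500)                  # False: not a 'num' sketch
    select_signature(ss, containment=True)         # ValueError: needs scaled

Note the division of labor: LinearIndex.select silently filters signatures
out, while whole-database Index types raise ValueError on incompatibility.)
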
@@ -155,18 +206,17 @@ def load(cls, location):
         lidx = LinearIndex(si, filename=location)
         return lidx
 
-    def select(self, ksize=None, moltype=None):
-        def select_sigs(ss, ksize=ksize, moltype=moltype):
-            if (ksize is None or ss.minhash.ksize == ksize) and \
-               (moltype is None or ss.minhash.moltype == moltype):
-                return True
+    def select(self, **kwargs):
+        """Return new LinearIndex containing only signatures that match req's.
 
-        return self.filter(select_sigs)
+        Does not raise ValueError, but may return an empty Index.
+        """
+        # eliminate things from kwargs with None or zero value
+        kw = { k : v for (k, v) in kwargs.items() if v }
 
-    def filter(self, filter_fn):
         siglist = []
         for ss in self._signatures:
-            if filter_fn(ss):
+            if select_signature(ss, **kw):
                 siglist.append(ss)
 
         return LinearIndex(siglist, self.filename)
@@ -193,6 +243,11 @@ def signatures(self):
             for ss in idx.signatures():
                 yield ss
 
+    def signatures_with_location(self):
+        for idx, loc in zip(self.index_list, self.source_list):
+            for ss in idx.signatures():
+                yield ss, loc
+
     def __len__(self):
         return sum([ len(idx) for idx in self.index_list ])
 
@@ -203,14 +258,62 @@ def insert(self, *args):
     def load(self, *args):
         raise NotImplementedError
 
+    @classmethod
+    def load_from_path(cls, pathname, force=False):
+        "Create a MultiIndex from a path (filename or directory)."
+        from .sourmash_args import traverse_find_sigs
+        if not os.path.exists(pathname):
+            raise ValueError(f"'{pathname}' does not exist")
+
+        index_list = []
+        source_list = []
+        for thisfile in traverse_find_sigs([pathname], yield_all_files=force):
+            try:
+                idx = LinearIndex.load(thisfile)
+                index_list.append(idx)
+                source_list.append(thisfile)
+            except (IOError, sourmash.exceptions.SourmashError):
+                if force:
+                    continue    # ignore error
+                else:
+                    raise       # stop loading!
+
+        db = None
+        if index_list:
+            db = cls(index_list, source_list)
+        else:
+            raise ValueError(f"no signatures to load under directory '{pathname}'")
+
+        return db
+
+    @classmethod
+    def load_from_pathlist(cls, filename):
+        "Create a MultiIndex from all files listed in a text file."
+        from .sourmash_args import (load_pathlist_from_file,
+                                    load_file_as_index)
+        idx_list = []
+        src_list = []
+
+        file_list = load_pathlist_from_file(filename)
+        for fname in file_list:
+            idx = load_file_as_index(fname)
+            src = fname
+
+            idx_list.append(idx)
+            src_list.append(src)
+
+        db = MultiIndex(idx_list, src_list)
+        return db
+
     def save(self, *args):
         raise NotImplementedError
 
-    def select(self, ksize=None, moltype=None):
+    def select(self, **kwargs):
+        "Run 'select' on all indices within this MultiIndex."
         new_idx_list = []
         new_src_list = []
         for idx, src in zip(self.index_list, self.source_list):
-            idx = idx.select(ksize=ksize, moltype=moltype)
+            idx = idx.select(**kwargs)
             new_idx_list.append(idx)
             new_src_list.append(src)
 
diff --git a/src/sourmash/lca/command_summarize.py b/src/sourmash/lca/command_summarize.py
index 91e61df241..9823aa616e 100644
--- a/src/sourmash/lca/command_summarize.py
+++ b/src/sourmash/lca/command_summarize.py
@@ -10,6 +10,7 @@
 from ..logging import notify, error, print_results, set_quiet, debug
 from . import lca_utils
 from .lca_utils import check_files_exist
+from sourmash.index import MultiIndex
 
 
 DEFAULT_THRESHOLD=5
@@ -61,20 +62,16 @@ def load_singletons_and_count(filenames, ksize, scaled, ignore_abundance):
     total_count = 0
     n = 0
 
-    # in order to get the right reporting out of this function, we need
-    # to do our own traversal to expand the list of filenames, as opposed
-    # to using load_file_as_signatures(...)
-    filenames = sourmash_args.traverse_find_sigs(filenames)
-    filenames = list(filenames)
     total_n = len(filenames)
-
-    for query_filename in filenames:
+    for filename in filenames:
         n += 1
-        for query_sig in sourmash_args.load_file_as_signatures(query_filename,
-                                                               ksize=ksize):
+        mi = MultiIndex.load_from_path(filename)
+        mi = mi.select(ksize=ksize)
+
+        for query_sig, query_filename in mi.signatures_with_location():
             notify(u'\r\033[K', end=u'')
-            notify('... loading {} (file {} of {})', query_sig, n,
-                   total_n, end='\r')
+            notify(f'... loading {query_sig} (file {n} of {total_n})',
+                   end='\r')
             total_count += 1
@@ -87,7 +83,7 @@ def load_singletons_and_count(filenames, ksize, scaled, ignore_abundance):
         yield query_filename, query_sig, hashvals
 
     notify(u'\r\033[K', end=u'')
-    notify('loaded {} signatures from {} files total.', total_count, n)
+    notify(f'loaded {total_count} signatures from {n} files total.')
 
 
 def count_signature(sig, scaled, hashvals):
diff --git a/src/sourmash/lca/lca_db.py b/src/sourmash/lca/lca_db.py
index 9c305c80b4..03e22dc8a2 100644
--- a/src/sourmash/lca/lca_db.py
+++ b/src/sourmash/lca/lca_db.py
@@ -1,5 +1,5 @@
 "LCA database class and utilities."
-
+import os
 import json
 import gzip
 from collections import OrderedDict, defaultdict, Counter
@@ -55,6 +55,8 @@ class LCA_Database(Index):
     `hashval_to_idx` is a dictionary from individual hash values to sets of
     `idx`.
     """
+    is_database = True
+
     def __init__(self, ksize, scaled, moltype='DNA'):
         self.ksize = int(ksize)
         self.scaled = int(scaled)
@@ -169,24 +171,38 @@ def signatures(self):
         for v in self._signatures.values():
             yield v
 
-    def select(self, ksize=None, moltype=None):
-        "Selector interface - make sure this database matches requirements."
-        ok = True
+    def select(self, ksize=None, moltype=None, num=0, scaled=0,
+               containment=False):
+        """Make sure this database matches the requested requirements.
+
+        As with SBTs, queries with higher scaled values than the database
+        can still be used for containment search, but not for similarity
+        search. See SBT.select(...) for details, and _find_signatures for
+        implementation.
+
+        Will always raise ValueError if a requirement cannot be met.
+        """
+        if num:
+            raise ValueError("cannot use 'num' MinHashes to search LCA database")
+
+        if scaled > self.scaled and not containment:
+            raise ValueError(f"cannot use scaled={scaled} on this database (scaled={self.scaled})")
+
         if ksize is not None and self.ksize != ksize:
-            ok = False
+            raise ValueError(f"ksize on this database is {self.ksize}; this is different from requested ksize of {ksize}")
         if moltype is not None and moltype != self.moltype:
-            ok = False
-
-        if ok:
-            return self
+            raise ValueError(f"moltype on this database is {self.moltype}; this is different from requested moltype of {moltype}")
 
-        raise ValueError("cannot select LCA on ksize {} / moltype {}".format(ksize, moltype))
+        return self
 
     @classmethod
     def load(cls, db_name):
         "Load LCA_Database from a JSON file."
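+        # (local import: lca_utils imports from lca_db, so a top-level
+        # import here would be circular.)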
         from .lca_utils import taxlist, LineagePair
 
+        if not os.path.isfile(db_name):
+            raise ValueError(f"'{db_name}' is not a file and cannot be loaded as an LCA database")
+
         xopen = open
         if db_name.endswith('.gz'):
             xopen = gzip.open
@@ -464,12 +480,16 @@ def _find_signatures(self, minhash, threshold, containment=False,
         This is essentially a fast implementation of find that collects all
         the signatures with overlapping hash values. Note that similarity
         searches (containment=False) will not be returned in sorted order.
+
+        As with SBTs, queries with higher scaled values than the database
+        can still be used for containment search, but not for similarity
+        search. See SBT.select(...) for details.
         """
         # make sure we're looking at the same scaled value as database
         if self.scaled > minhash.scaled:
             minhash = minhash.downsample(scaled=self.scaled)
         elif self.scaled < minhash.scaled and not ignore_scaled:
-            # note that containment cannot be calculated w/o matching scaled.
+            # note that similarity cannot be calculated w/o matching scaled.
             raise ValueError("lca db scaled is {} vs query {}; must downsample".format(self.scaled, minhash.scaled))
 
         query_mins = set(minhash.hashes)
diff --git a/src/sourmash/logging.py b/src/sourmash/logging.py
index 49c3dc26b3..2915c43f78 100644
--- a/src/sourmash/logging.py
+++ b/src/sourmash/logging.py
@@ -41,6 +41,17 @@ def debug(s, *args, **kwargs):
     sys.stderr.flush()
 
 
+def debug_literal(s, *args, **kwargs):
+    "A debug logging function => stderr."
+    if _quiet or not _debug:
+        return
+
+    print(u'\r\033[K', end=u'', file=sys.stderr)
+    print(s, file=sys.stderr, end=kwargs.get('end', u'\n'))
+    if kwargs.get('flush'):
+        sys.stderr.flush()
+
+
 def error(s, *args, **kwargs):
     "A simple error logging function => stderr."
     print(u'\r\033[K', end=u'', file=sys.stderr)
diff --git a/src/sourmash/sbt.py b/src/sourmash/sbt.py
index c499f50e2c..fa22507961 100644
--- a/src/sourmash/sbt.py
+++ b/src/sourmash/sbt.py
@@ -171,6 +171,7 @@ class SBT(Index):
     We use two dicts to store the tree structure: One for the internal nodes,
     and another for the leaves (datasets).
     """
+    is_database = True
 
     def __init__(self, factory, *, d=2, storage=None, cache_size=None):
         self.factory = factory
@@ -189,19 +190,60 @@ def signatures(self):
         for k in self.leaves():
             yield k.data
 
-    def select(self, ksize=None, moltype=None):
-        first_sig = next(iter(self.signatures()))
+    def select(self, ksize=None, moltype=None, num=0, scaled=0,
+               containment=False):
+        """Make sure this database matches the requested requirements.
 
-        ok = True
-        if ksize is not None and first_sig.minhash.ksize != ksize:
-            ok = False
-        if moltype is not None and first_sig.minhash.moltype != moltype:
-            ok = False
+        Will always raise ValueError if a requirement cannot be met.
 
-        if ok:
-            return self
+        The only tricky bit here is around downsampling: if the scaled
+        value being requested is higher than the signatures in the
+        SBT, we can use the SBT for containment but not for
+        similarity. This is because:
 
-        raise ValueError("cannot select SBT on ksize {} / moltype {}".format(ksize, moltype))
+        * if we are doing containment searches, the intermediate nodes
+          can still be used for calculating containment of signatures
+          with higher scaled values. This is because only hashes that match
+          in the higher range are used for containment scores.
+        * however, for similarity, _all_ hashes are used, and we cannot
+          implicitly downsample or necessarily estimate similarity if
+          the scaled values differ.
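+
+        For example, an SBT built from scaled=1000 signatures can serve
+        a scaled=10000 query when containment=True, but select(scaled=10000)
+        alone -- a similarity query -- will raise ValueError.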
+        """
+        # pull out a signature from this collection -
+        first_sig = next(iter(self.signatures()))
+        db_mh = first_sig.minhash
+
+        # check ksize.
+        if ksize is not None and db_mh.ksize != ksize:
+            raise ValueError(f"search ksize {ksize} is different from database ksize {db_mh.ksize}")
+
+        # check moltype.
+        if moltype is not None and db_mh.moltype != moltype:
+            raise ValueError(f"search moltype {moltype} is different from database moltype {db_mh.moltype}")
+
+        # containment requires 'scaled'.
+        if containment:
+            if not scaled:
+                raise ValueError("'containment' requires 'scaled' in SBT.select")
+            if not db_mh.scaled:
+                raise ValueError("cannot search this SBT for containment; signatures are not calculated with scaled")
+
+        # 'num' and 'scaled' do not mix.
+        if num:
+            if not db_mh.num:
+                raise ValueError(f"this database was created with 'scaled' MinHash sketches, not 'num'")
+            if num != db_mh.num:
+                raise ValueError(f"num mismatch for SBT: num={num}, {db_mh.num}")
+
+        if scaled:
+            if not db_mh.scaled:
+                raise ValueError(f"this database was created with 'num' MinHash sketches, not 'scaled'")
+
+            # we can downsample SBTs for containment operations.
+            if scaled > db_mh.scaled and not containment:
+                raise ValueError(f"search scaled value {scaled} is less than database scaled value of {db_mh.scaled}")
+
+        return self
 
     def new_node_pos(self, node):
         if not self._nodes:
diff --git a/src/sourmash/signature.py b/src/sourmash/signature.py
index 60bb2d5c90..e382e58311 100644
--- a/src/sourmash/signature.py
+++ b/src/sourmash/signature.py
@@ -252,7 +252,7 @@ def load_signatures(
         input_type = _detect_input_type(data)
         if input_type == SigInput.UNKNOWN:
             if do_raise:
-                raise Exception("Error in parsing signature; quitting. Cannot open file or invalid signature")
+                raise ValueError("Error in parsing signature; quitting. Cannot open file or invalid signature")
             return
 
         size = ffi.new("uintptr_t *")
diff --git a/src/sourmash/sourmash_args.py b/src/sourmash/sourmash_args.py
index b1c24aae12..919f1b9dd0 100644
--- a/src/sourmash/sourmash_args.py
+++ b/src/sourmash/sourmash_args.py
@@ -6,6 +6,7 @@
 import argparse
 import itertools
 from enum import Enum
+import traceback
 
 import screed
 
@@ -14,7 +15,7 @@
 import sourmash.exceptions
 
 from . import signature
-from .logging import notify, error
+from .logging import notify, error, debug_literal
 from .index import LinearIndex, MultiIndex
 from . import signature as sig
 
@@ -24,7 +25,6 @@
 import sourmash
 
 DEFAULT_LOAD_K = 31
-DEFAULT_N = 500
 
 
 def get_moltype(sig, require=False):
@@ -72,7 +72,7 @@ def load_query_signature(filename, ksize, select_moltype, select_md5=None):
                                         select_moltype=select_moltype)
         sl = list(sl)
     except (OSError, ValueError):
-        error("Cannot open file '{}'", filename)
+        error(f"Cannot open query file '{filename}'")
         sys.exit(-1)
 
     if len(sl) and select_md5:
@@ -82,8 +82,7 @@ def load_query_signature(filename, ksize, select_moltype, select_md5=None):
             if sig_md5.startswith(select_md5.lower()):
                 # make sure we pick only one --
                 if found_sig is not None:
-                    error("Error! Multiple signatures start with md5 '{}'",
-                          select_md5)
+                    error(f"Error! Multiple signatures start with md5 '{select_md5}'")
                     error("Please use a longer --md5 selector.")
                     sys.exit(-1)
                 else:
@@ -96,150 +95,50 @@ def load_query_signature(filename, ksize, select_moltype, select_md5=None):
     if len(ksizes) == 1:
         ksize = ksizes.pop()
         sl = [ ss for ss in sl if ss.minhash.ksize == ksize ]
-        notify('select query k={} automatically.', ksize)
+        notify(f'select query k={ksize} automatically.')
     elif DEFAULT_LOAD_K in ksizes:
         sl = [ ss for ss in sl if ss.minhash.ksize == DEFAULT_LOAD_K ]
-        notify('selecting default query k={}.', DEFAULT_LOAD_K)
+        notify(f'selecting default query k={DEFAULT_LOAD_K}.')
     elif ksize:
-        notify('selecting specified query k={}', ksize)
+        notify(f'selecting specified query k={ksize}')
 
     if len(sl) != 1:
-        error('When loading query from "{}"', filename)
-        error('{} signatures matching ksize and molecule type;', len(sl))
+        error(f"When loading query from '{filename}'")
+        error(f'{len(sl)} signatures matching ksize and molecule type;')
         error('need exactly one. Specify --ksize or --dna, --rna, or --protein.')
         sys.exit(-1)
 
     return sl[0]
 
 
-class LoadSingleSignatures(object):
-    def __init__(self, filelist,  ksize=None, select_moltype=None,
-                 ignore_files=set()):
-        self.filelist = filelist
-        self.ksize = ksize
-        self.select_moltype = select_moltype
-        self.ignore_files = ignore_files
-
-        self.skipped_ignore = 0
-        self.skipped_nosig = 0
-        self.ksizes = set()
-        self.moltypes = set()
-
-    def __iter__(self):
-        for filename in self.filelist:
-            if filename in self.ignore_files:
-                self.skipped_ignore += 1
-                continue
-
-            sl = signature.load_signatures(filename,
-                                           ksize=self.ksize,
-                                           select_moltype=self.select_moltype)
-            sl = list(sl)
-            if len(sl) == 0:
-                self.skipped_nosig += 1
-                continue
-
-            for query in sl:
-                query_moltype = get_moltype(query)
-                query_ksize = query.minhash.ksize
-
-                self.ksizes.add(query_ksize)
-                self.moltypes.add(query_moltype)
-
-            if len(self.ksizes) > 1 or len(self.moltypes) > 1:
-                raise ValueError('multiple k-mer sizes/molecule types present')
-
-            for query in sl:
-                yield filename, query, query_moltype, query_ksize
+def _check_suffix(filename, endings):
+    for ending in endings:
+        if filename.endswith(ending):
+            return True
+    return False
 
 
 def traverse_find_sigs(filenames, yield_all_files=False):
+    """Find all .sig and .sig.gz files in & beneath 'filenames'.
+
+    By default, this function returns files with .sig and .sig.gz extensions.
+    If 'yield_all_files' is True, this will return _all_ files
+    (but not directories).
+    """
     endings = ('.sig', '.sig.gz')
     for filename in filenames:
+        # check for files in filenames:
         if os.path.isfile(filename):
-            yield_me = False
-            if yield_all_files:
-                yield_me = True
-                continue
-            else:
-                for ending in endings:
-                    if filename.endswith(ending):
-                        yield_me = True
-                        break
-
-            if yield_me:
+            if yield_all_files or _check_suffix(filename, endings):
                 yield filename
-                continue
 
-        # filename is a directory --
-        dirname = filename
-
-        for root, dirs, files in os.walk(dirname):
-            for name in files:
-                if name.endswith('.sig') or yield_all_files:
+        # filename is a directory -- traverse beneath!
+        elif os.path.isdir(filename):
+            for root, dirs, files in os.walk(filename):
+                for name in files:
                     fullname = os.path.join(root, name)
-                    yield fullname
-
-
-def _check_signatures_are_compatible(query, subject):
-    # is one scaled, and the other not? cannot do search
-    if query.minhash.scaled and not subject.minhash.scaled or \
-       not query.minhash.scaled and subject.minhash.scaled:
-        error("signature {} and {} are incompatible - cannot compare.",
-              query, subject)
-        if query.minhash.scaled:
-            error("{} was calculated with --scaled, {} was not.",
-                  query, subject)
-        if subject.minhash.scaled:
-            error("{} was calculated with --scaled, {} was not.",
-                  subject, query)
-        return 0
-
-    return 1
-
-
-def check_tree_is_compatible(treename, tree, query, is_similarity_query):
-    # get a minhash from the tree
-    leaf = next(iter(tree.leaves()))
-    tree_mh = leaf.data.minhash
-
-    query_mh = query.minhash
-
-    if tree_mh.ksize != query_mh.ksize:
-        error("ksize on tree '{}' is {};", treename, tree_mh.ksize)
-        error('this is different from query ksize of {}.', query_mh.ksize)
-        return 0
-
-    # is one scaled, and the other not? cannot do search.
-    if (tree_mh.scaled and not query_mh.scaled) or \
-       (query_mh.scaled and not tree_mh.scaled):
-        error("for tree '{}', tree and query are incompatible for search.",
-              treename)
-        if tree_mh.scaled:
-            error("tree was calculated with scaled, query was not.")
-        else:
-            error("query was calculated with scaled, tree was not.")
-        return 0
-
-    # are the scaled values incompatible? cannot downsample tree for similarity
-    if tree_mh.scaled and tree_mh.scaled < query_mh.scaled and \
-       is_similarity_query:
-        error("for tree '{}', scaled value is smaller than query.", treename)
-        error("tree scaled: {}; query scaled: {}. Cannot do similarity search.",
-              tree_mh.scaled, query_mh.scaled)
-        return 0
-
-    return 1
-
-
-def check_lca_db_is_compatible(filename, db, query):
-    query_mh = query.minhash
-    if db.ksize != query_mh.ksize:
-        error("ksize on db '{}' is {};", filename, db.ksize)
-        error('this is different from query ksize of {}.', query_mh.ksize)
-        return 0
-
-    return 1
+                    if yield_all_files or _check_suffix(fullname, endings):
+                        yield fullname
 
 
 def load_dbs_and_sigs(filenames, query, is_similarity_query, *, cache_size=None):
     """Load one or more SBTs, LCAs, and/or collections of signatures.
 
     Check for compatibility with query.
 
     This is basically a user-focused wrapping of _load_databases.
     """
-    query_ksize = query.minhash.ksize
-    query_moltype = get_moltype(query)
+    query_mh = query.minhash
+
+    containment = True
+    if is_similarity_query:
+        containment = False
 
-    n_signatures = 0
-    n_databases = 0
     databases = []
     for filename in filenames:
-        notify('loading from {}...', filename, end='\r')
+        notify(f'loading from {filename}...', end='\r')
 
         try:
-            db, dbtype = _load_database(filename, False, cache_size=cache_size)
-        except IOError as e:
+            db = _load_database(filename, False, cache_size=cache_size)
+        except ValueError as e:
+            # cannot load database!
             notify(str(e))
             sys.exit(-1)
 
-        # are we collecting signatures from a directory/path?
-        if dbtype == DatabaseType.SBT:
-            if not check_tree_is_compatible(filename, db, query,
-                                            is_similarity_query):
-                sys.exit(-1)
+        try:
+            db = db.select(moltype=query_mh.moltype,
+                           ksize=query_mh.ksize,
+                           num=query_mh.num,
+                           scaled=query_mh.scaled,
+                           containment=containment)
+        except ValueError as exc:
+            # incompatible collection specified!
+            notify(f"ERROR: cannot use '{filename}' for this query.")
+            notify(str(exc))
+            sys.exit(-1)
 
-            databases.append(db)
-            notify('loaded SBT {}', filename, end='\r')
-            n_databases += 1
+        # 'select' returns nothing => all signatures filtered out. fail!
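+        # (a fully-filtered MultiIndex has len() == 0, so it is falsy here.)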
+        if not db:
+            notify(f"no compatible signatures found in '{filename}'")
+            sys.exit(-1)
 
-        # LCA
-        elif dbtype == DatabaseType.LCA:
-            if not check_lca_db_is_compatible(filename, db, query):
-                sys.exit(-1)
+        databases.append(db)
 
-            notify('loaded LCA {}', filename, end='\r')
+    # calc num loaded info.
+    n_signatures = 0
+    n_databases = 0
+    for db in databases:
+        if db.is_database:
             n_databases += 1
+        else:
+            n_signatures += len(db)
 
-            databases.append(db)
+    notify(' '*79, end='\r')
+    if n_signatures and n_databases:
+        notify(f'loaded {n_signatures} signatures and {n_databases} databases total.')
+    elif n_signatures and not n_databases:
+        notify(f'loaded {n_signatures} signatures.')
+    elif n_databases and not n_signatures:
+        notify(f'loaded {n_databases} databases.')
 
-        # signature file
-        elif dbtype == DatabaseType.SIGLIST:
-            db = db.select(moltype=query_moltype, ksize=query_ksize)
-            siglist = db.signatures()
-            filter_fn = lambda s: _check_signatures_are_compatible(query, s)
-            db = db.filter(filter_fn)
+    if databases:
+        print('')
+    else:
+        notify('** ERROR: no signatures or databases loaded?')
+        sys.exit(-1)
 
-            if not db:
-                notify(f"no compatible signatures found in '{filename}'")
-                sys.exit(-1)
+    return databases
 
-            databases.append(db)
-            notify(f'loaded {len(db)} signatures from {filename}', end='\r')
-            n_signatures += len(db)
 
+def _load_stdin(filename, **kwargs):
+    "Load collection from .sig file streamed in via stdin"
+    db = None
+    if filename == '-':
+        db = LinearIndex.load(sys.stdin)
 
-        # unknown!?
-        else:
-            raise Exception("unknown dbtype {}".format(dbtype))
+    return db
 
-    # END for loop
 
+def _multiindex_load_from_pathlist(filename, **kwargs):
+    "Load collection from a list of signature/database files"
+    db = MultiIndex.load_from_pathlist(filename)
 
-    notify(' '*79, end='\r')
-    if n_signatures and n_databases:
-        notify('loaded {} signatures and {} databases total.', n_signatures,
-                                                               n_databases)
-    elif n_signatures:
-        notify('loaded {} signatures.', n_signatures)
-    elif n_databases:
-        notify('loaded {} databases.', n_databases)
-    else:
-        notify('** ERROR: no signatures or databases loaded?')
-        sys.exit(-1)
+    return db
 
-    if databases:
-        print('')
 
-    return databases
+def _multiindex_load_from_path(filename, **kwargs):
+    "Load collection from a directory."
+    traverse_yield_all = kwargs['traverse_yield_all']
+    db = MultiIndex.load_from_path(filename, traverse_yield_all)
+    return db
 
 
-class DatabaseType(Enum):
-    SIGLIST = 1
-    SBT = 2
-    LCA = 3
+def _load_sigfile(filename, **kwargs):
+    "Load collection from a signature JSON file"
+    try:
+        db = LinearIndex.load(filename)
+    except sourmash.exceptions.SourmashError as exc:
+        raise ValueError(exc)
 
+    return db
 
-def _load_database(filename, traverse_yield_all, *, cache_size=None):
-    """Load file as a database - list of signatures, LCA, SBT, etc.
 
-    Return (db, dbtype), where dbtype is a DatabaseType enum.
-    This is an internal function used by other functions in sourmash_args.
-    """
-    loaded = False
-    dbtype = None
+def _load_sbt(filename, **kwargs):
+    "Load collection from an SBT."
+    cache_size = kwargs.get('cache_size')
 
-    # special case stdin
-    if not loaded and filename == '-':
-        db = LinearIndex.load(sys.stdin)
-        dbtype = DatabaseType.SIGLIST
-        loaded = True
-
-    # load signatures from directory, using MultiIndex to preserve source.
-    if not loaded and os.path.isdir(filename):
-        index_list = []
-        source_list = []
-        for thisfile in traverse_find_sigs([filename], traverse_yield_all):
-            try:
-                idx = LinearIndex.load(thisfile)
-                index_list.append(idx)
-                source_list.append(thisfile)
-            except (IOError, sourmash.exceptions.SourmashError):
-                if traverse_yield_all:
-                    continue
-                else:
-                    raise
+    try:
+        db = load_sbt_index(filename, cache_size=cache_size)
+    except FileNotFoundError as exc:
+        raise ValueError(exc)
 
-        if index_list:
-            loaded=True
-            db = MultiIndex(index_list, source_list)
-            dbtype = DatabaseType.SIGLIST
+    return db
 
-    # load signatures from single signature file
-    if not loaded:
-        try:
-            with open(filename, 'rt') as fp:
-                db = LinearIndex.load(filename)
-            dbtype = DatabaseType.SIGLIST
-            loaded = True
-        except Exception as exc:
-            pass
 
-    # try load signatures from single file (list of signature paths)
-    # use MultiIndex to preserve source filenames.
-    if not loaded:
-        try:
-            idx_list = []
-            src_list = []
+def _load_revindex(filename, **kwargs):
+    "Load collection from an LCA database/reverse index."
+    db, _, _ = load_single_database(filename)
+    return db
 
-            file_list = load_pathlist_from_file(filename)
-            for fname in file_list:
-                idx = load_file_as_index(fname)
-                src = fname
 
-                idx_list.append(idx)
-                src_list.append(src)
+# all loader functions, in order.
+_loader_functions = [
+    ("load from stdin", _load_stdin),
+    ("load from directory", _multiindex_load_from_path),
+    ("load from sig file", _load_sigfile),
+    ("load from file list", _multiindex_load_from_pathlist),
+    ("load SBT", _load_sbt),
+    ("load revindex", _load_revindex),
+    ]
 
-            db = MultiIndex(idx_list, src_list)
-            dbtype = DatabaseType.SIGLIST
-            loaded = True
-        except Exception as exc:
-            pass
 
-    if not loaded:                    # try load as SBT
-        try:
-            db = load_sbt_index(filename, cache_size=cache_size)
-            loaded = True
-            dbtype = DatabaseType.SBT
-        except:
-            pass
+def _load_database(filename, traverse_yield_all, *, cache_size=None):
+    """Load file as a database - list of signatures, LCA, SBT, etc.
+
+    Return Index object.
+
+    This is an internal function used by other functions in sourmash_args.
+    """
+    loaded = False
 
-    if not loaded:                    # try load as LCA
+    # iterate through loader functions, trying them all. Catch ValueError
+    # but nothing else.
+    for (desc, load_fn) in _loader_functions:
         try:
-            db, _, _ = load_single_database(filename)
+            debug_literal(f"_load_databases: trying loader fn {desc}")
+            db = load_fn(filename,
+                         traverse_yield_all=traverse_yield_all,
+                         cache_size=cache_size)
+        except ValueError as exc:
+            debug_literal(f"_load_databases: FAIL on fn {desc}.")
+            debug_literal(traceback.format_exc())
+
+        if db:
             loaded = True
-            dbtype = DatabaseType.LCA
-        except:
-            pass
+            break
 
     # check to see if it's a FASTA/FASTQ record (i.e. screed loadable)
     # so we can provide a better error message to users.
@@ -421,7 +305,7 @@ def _load_database(filename, traverse_yield_all, *, cache_size=None):
     successful_screed_load = False
     it = None
     try:
-        # CTB: could be kind of time consuming for big record, but at the
+        # CTB: could be kind of time consuming for a big record, but at the
         # moment screed doesn't expose format detection cleanly.
        with screed.open(filename) as it:
            record = next(iter(it))
@@ -430,12 +314,15 @@ def _load_database(filename, traverse_yield_all, *, cache_size=None):
         pass
 
     if successful_screed_load:
-        raise OSError("Error while reading signatures from '{}' - got sequences instead! Is this a FASTA/FASTQ file?".format(filename))
+        raise ValueError(f"Error while reading signatures from '{filename}' - got sequences instead! Is this a FASTA/FASTQ file?")
 
     if not loaded:
-        raise OSError(f"Error while reading signatures from '{filename}'.")
+        raise ValueError(f"Error while reading signatures from '{filename}'.")
+
+    if loaded:                  # this is a bit redundant, but better safe than sorry
+        assert db
 
-    return db, dbtype
+    return db
 
 
 def load_file_as_index(filename, yield_all_files=False):
@@ -451,8 +338,7 @@
     this directory into an Index object. If yield_all_files=True, will
     attempt to load all files.
     """
-    db, dbtype = _load_database(filename, yield_all_files)
-    return db
+    return _load_database(filename, yield_all_files)
 
 
 def load_file_as_signatures(filename, select_moltype=None, ksize=None,
@@ -475,7 +361,7 @@
     if progress:
         progress.notify(filename)
 
-    db, dbtype = _load_database(filename, yield_all_files)
+    db = _load_database(filename, yield_all_files)
     db = db.select(moltype=select_moltype, ksize=ksize)
     loader = db.signatures()
 
@@ -490,10 +376,13 @@ def load_pathlist_from_file(filename):
     try:
         with open(filename, 'rt') as fp:
             file_list = [ x.rstrip('\r\n') for x in fp ]
+
+        if not os.path.exists(file_list[0]):
+            raise ValueError("first element of list-of-files does not exist")
     except OSError:
-        raise ValueError("cannot open file '{}'".format(filename))
+        raise ValueError(f"cannot open file '{filename}'")
     except UnicodeDecodeError:
-        raise ValueError("cannot parse file '{}' as list of filenames".format(filename))
+        raise ValueError(f"cannot parse file '{filename}' as list of filenames")
 
     return file_list
diff --git a/tests/test_api.py b/tests/test_api.py
index a6c298d3c3..3291c62165 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -20,7 +20,7 @@ def test_sourmash_signature_api(c):
 
 @utils.in_tempdir
 def test_load_index_0_no_file(c):
-    with pytest.raises(OSError) as exc:
+    with pytest.raises(ValueError) as exc:
         idx = sourmash.load_file_as_index(c.output('does-not-exist'))
     assert 'Error while reading signatures from ' in str(exc.value)
 
@@ -53,7 +53,9 @@ def test_load_fasta_as_signature():
     # try loading a fasta file - should fail with informative exception
     testfile = utils.get_test_data('short.fa')
 
-    with pytest.raises(OSError) as e:
+    with pytest.raises(ValueError) as exc:
         idx = sourmash.load_file_as_index(testfile)
 
-    assert "Error while reading signatures from '{}' - got sequences instead! Is this a FASTA/FASTQ file?".format(testfile) in str(e)
+    print(exc.value)
+
+    assert f"Error while reading signatures from '{testfile}' - got sequences instead! Is this a FASTA/FASTQ file?" in str(exc.value)
diff --git a/tests/test_index.py b/tests/test_index.py
index 1b9ae93402..9ddcfa8672 100644
--- a/tests/test_index.py
+++ b/tests/test_index.py
@@ -1,11 +1,17 @@
+"""
+Tests for Index classes and subclasses.
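+
+These exercise LinearIndex and MultiIndex loading, selection, and
+directory/pathlist traversal.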
+""" +import pytest import glob import os import zipfile +import shutil import sourmash from sourmash import load_one_signature, SourmashSignature from sourmash.index import LinearIndex, MultiIndex from sourmash.sbt import SBT, GraphFactory, Leaf +from sourmash import sourmash_args import sourmash_tst_utils as utils @@ -502,3 +508,171 @@ def test_multi_index_signatures(): assert ss2 in siglist assert ss47 in siglist assert ss63 in siglist + + +def test_multi_index_load_from_path(): + dirname = utils.get_test_data('prot/protein') + mi = MultiIndex.load_from_path(dirname, force=False) + + sigs = list(mi.signatures()) + assert len(sigs) == 2 + + +def test_multi_index_load_from_path_2(): + # only load .sig files, currently; not the databases under that directory. + dirname = utils.get_test_data('prot') + mi = MultiIndex.load_from_path(dirname, force=False) + + print(mi.index_list) + print(mi.source_list) + + sigs = list(mi.signatures()) + assert len(sigs) == 6 + + +@utils.in_tempdir +def test_multi_index_load_from_path_3(c): + # check that force works ok on a directory + dirname = utils.get_test_data('prot') + + count = 0 + for root, dirs, files in os.walk(dirname): + for name in files: + print(f"at {name}") + fullname = os.path.join(root, name) + copyto = c.output(f"file{count}.sig") + shutil.copyfile(fullname, copyto) + count += 1 + + with pytest.raises(sourmash.exceptions.SourmashError): + mi = MultiIndex.load_from_path(c.location, force=False) + + +@utils.in_tempdir +def test_multi_index_load_from_path_3_yield_all_true(c): + # check that force works ok on a directory w/force=True + dirname = utils.get_test_data('prot') + + count = 0 + for root, dirs, files in os.walk(dirname): + for name in files: + print(f"at {name}") + fullname = os.path.join(root, name) + copyto = c.output(f"file{count}.something") + shutil.copyfile(fullname, copyto) + count += 1 + + mi = MultiIndex.load_from_path(c.location, force=True) + + print(mi.index_list) + print(mi.source_list) + + sigs = list(mi.signatures()) + assert len(sigs) == 6 + + +@utils.in_tempdir +def test_multi_index_load_from_path_3_yield_all_true_subdir(c): + # check that force works ok on subdirectories + dirname = utils.get_test_data('prot') + + target_dir = c.output("some_subdir") + os.mkdir(target_dir) + + count = 0 + for root, dirs, files in os.walk(dirname): + for name in files: + print(f"at {name}") + fullname = os.path.join(root, name) + copyto = os.path.join(target_dir, f"file{count}.something") + shutil.copyfile(fullname, copyto) + count += 1 + + mi = MultiIndex.load_from_path(c.location, force=True) + + print(mi.index_list) + print(mi.source_list) + + sigs = list(mi.signatures()) + assert len(sigs) == 6 + + +@utils.in_tempdir +def test_multi_index_load_from_path_3_sig_gz(c): + # check that we find .sig.gz files, too + dirname = utils.get_test_data('prot') + + count = 0 + for root, dirs, files in os.walk(dirname): + for name in files: + if not name.endswith('.sig'): # skip non .sig things + continue + print(f"at {name}") + fullname = os.path.join(root, name) + copyto = c.output(f"file{count}.sig.gz") + shutil.copyfile(fullname, copyto) + count += 1 + + mi = MultiIndex.load_from_path(c.location, force=False) + + print(mi.index_list) + print(mi.source_list) + + sigs = list(mi.signatures()) + assert len(sigs) == 6 + + +@utils.in_tempdir +def test_multi_index_load_from_path_3_check_traverse_fn(c): + # test the actual traverse function... 
eventually this test can be + # removed, probably, as we consolidate functionality and test MultiIndex + # better. + dirname = utils.get_test_data('prot') + files = list(sourmash_args.traverse_find_sigs([dirname])) + assert len(files) == 6, files + + files = list(sourmash_args.traverse_find_sigs([dirname], True)) + assert len(files) == 14, files + + +def test_multi_index_load_from_path_no_exist(): + dirname = utils.get_test_data('does-not-exist') + with pytest.raises(ValueError): + mi = MultiIndex.load_from_path(dirname, force=True) + + +def test_multi_index_load_from_pathlist_no_exist(): + dirname = utils.get_test_data('does-not-exist') + with pytest.raises(ValueError): + mi = MultiIndex.load_from_pathlist(dirname) + + +@utils.in_tempdir +def test_multi_index_load_from_pathlist_1(c): + dirname = utils.get_test_data('prot') + files = list(sourmash_args.traverse_find_sigs([dirname])) + assert len(files) == 6, files + + file_list = c.output('filelist.txt') + + with open(file_list, 'wt') as fp: + print("\n".join(files), file=fp) + mi = MultiIndex.load_from_pathlist(file_list) + + sigs = list(mi.signatures()) + assert len(sigs) == 6 + + +@utils.in_tempdir +def test_multi_index_load_from_pathlist_2(c): + dirname = utils.get_test_data('prot') + files = list(sourmash_args.traverse_find_sigs([dirname], True)) + assert len(files) == 14, files + + file_list = c.output('filelist.txt') + + with open(file_list, 'wt') as fp: + print("\n".join(files), file=fp) + + with pytest.raises(ValueError): + mi = MultiIndex.load_from_pathlist(file_list) diff --git a/tests/test_lca.py b/tests/test_lca.py index b4c0c8e990..65fec92350 100644 --- a/tests/test_lca.py +++ b/tests/test_lca.py @@ -394,6 +394,18 @@ def test_databases(): assert scaled == 10000 +def test_databases_load_fail_on_dir(): + filename1 = utils.get_test_data('lca') + with pytest.raises(ValueError) as exc: + dblist, ksize, scaled = lca_utils.load_databases([filename1]) + + +def test_databases_load_fail_on_not_exist(): + filename1 = utils.get_test_data('does-not-exist') + with pytest.raises(ValueError) as exc: + dblist, ksize, scaled = lca_utils.load_databases([filename1]) + + def test_db_repr(): filename = utils.get_test_data('lca/delmont-1.lca.json') db, ksize, scaled = lca_utils.load_single_database(filename) @@ -1896,8 +1908,8 @@ def test_incompat_lca_db_ksize_2(c): err = c.last_result.err print(err) - assert "ksize on db 'test.lca.json' is 25;" in err - assert 'this is different from query ksize of 31.' in err + assert "ERROR: cannot use 'test.lca.json' for this query." in err + assert "ksize on this database is 25; this is different from requested ksize of 31" @utils.in_tempdir diff --git a/tests/test_sourmash.py b/tests/test_sourmash.py index a1f2b55f6e..deb1b667cb 100644 --- a/tests/test_sourmash.py +++ b/tests/test_sourmash.py @@ -716,7 +716,7 @@ def test_search_query_sig_does_not_exist(c): print(c.last_result.status, c.last_result.out, c.last_result.err) assert c.last_result.status == -1 - assert "Cannot open file 'short2.fa.sig'" in c.last_result.err + assert "Cannot open query file 'short2.fa.sig'" in c.last_result.err assert len(c.last_result.err.split('\n\r')) < 5 @@ -1851,12 +1851,14 @@ def test_search_incompatible(c): assert c.last_result.status != 0 print(c.last_result.out) print(c.last_result.err) - assert 'incompatible - cannot compare.' 
in c.last_result.err - assert 'was calculated with --scaled,' in c.last_result.err + + assert "no compatible signatures found in " in c.last_result.err @utils.in_tempdir def test_search_traverse_incompatible(c): + # build a directory with some signatures in it, search for compatible + # signatures. searchdir = c.output('searchme') os.mkdir(searchdir) @@ -1866,10 +1868,7 @@ def test_search_traverse_incompatible(c): shutil.copyfile(scaled_sig, c.output('searchme/scaled.sig')) c.run_sourmash("search", scaled_sig, c.output('searchme')) - print(c.last_result.out) - print(c.last_result.err) - assert 'incompatible - cannot compare.' in c.last_result.err - assert 'was calculated with --scaled,' in c.last_result.err + assert '100.0% NC_009665.1 Shewanella baltica OS185, complete genome' in c.last_result.out # explanation: you cannot downsample a scaled SBT to match a scaled @@ -1896,8 +1895,11 @@ def test_search_metagenome_downsample(): in_directory=location, fail_ok=True) assert status == -1 - assert "for tree 'gcf_all', scaled value is smaller than query." in err - assert 'tree scaled: 10000; query scaled: 100000. Cannot do similarity search.' in err + print(out) + print(err) + + assert "ERROR: cannot use 'gcf_all' for this query." in err + assert "search scaled value 100000 is less than database scaled value of 10000" in err def test_search_metagenome_downsample_containment(): @@ -2055,7 +2057,11 @@ def test_do_sourmash_sbt_search_wrong_ksize(): fail_ok=True) assert status == -1 - assert 'this is different from' in err + print(out) + print(err) + + assert "ERROR: cannot use 'zzz' for this query." in err + assert "search ksize 51 is different from database ksize 31" in err def test_do_sourmash_sbt_search_multiple(): @@ -2170,7 +2176,10 @@ def test_do_sourmash_sbt_search_downsample_2(): '--threshold=0.01'], in_directory=location, fail_ok=True) assert status == -1 - assert 'Cannot do similarity search.' in err + print(out) + print(err) + assert "ERROR: cannot use 'foo' for this query." 
in err + assert "search scaled value 100000 is less than database scaled value of 2000" in err def test_do_sourmash_index_single(): @@ -2465,7 +2474,10 @@ def test_do_sourmash_sbt_search_scaled_vs_num_1(): fail_ok=True) assert status == -1 - assert 'tree and query are incompatible for search' in err + print(out) + print(err) + assert "ERROR: cannot use '" in err + assert "this database was created with 'num' MinHash sketches, not 'scaled'" in err def test_do_sourmash_sbt_search_scaled_vs_num_2(): @@ -2497,7 +2509,10 @@ def test_do_sourmash_sbt_search_scaled_vs_num_2(): fail_ok=True) assert status == -1 - assert 'tree and query are incompatible for search' in err + print(out) + print(err) + assert "ERROR: cannot use '" in err + assert "this database was created with 'scaled' MinHash sketches, not 'num'" in err def test_do_sourmash_sbt_search_scaled_vs_num_3(): @@ -2522,7 +2537,9 @@ def test_do_sourmash_sbt_search_scaled_vs_num_3(): fail_ok=True) assert status == -1 - assert 'incompatible - cannot compare' in err + print(out) + print(err) + assert "no compatible signatures found in " in err def test_do_sourmash_sbt_search_scaled_vs_num_4(): @@ -2547,7 +2564,9 @@ def test_do_sourmash_sbt_search_scaled_vs_num_4(): ['search', sig_loc2, sig_loc], fail_ok=True) assert status == -1 - assert 'incompatible - cannot compare' in err + print(out) + print(err) + assert "no compatible signatures found in " in err def test_do_sourmash_check_search_vs_actual_similarity(): @@ -3604,8 +3623,7 @@ def test_gather_traverse_incompatible(c): c.run_sourmash("gather", scaled_sig, c.output('searchme')) print(c.last_result.out) print(c.last_result.err) - assert 'incompatible - cannot compare.' in c.last_result.err - assert 'was calculated with --scaled,' in c.last_result.err + assert "5.2 Mbp 100.0% 100.0% NC_009665.1 Shewanella baltica OS185,..." in c.last_result.out def test_gather_metagenome_output_unassigned(): @@ -3750,6 +3768,7 @@ def test_gather_query_downsample(): with utils.TempDirectory() as location: testdata_glob = utils.get_test_data('gather/GCF*.sig') testdata_sigs = glob.glob(testdata_glob) + print(testdata_sigs) query_sig = utils.get_test_data('GCF_000006945.2-s500.sig') @@ -4238,8 +4257,7 @@ def test_sbt_categorize_already_done_traverse(): def test_sbt_categorize_multiple_ksizes_moltypes(): - # 'categorize' should fail when there are multiple ksizes or moltypes - # present + # 'categorize' works fine with multiple moltypes/ksizes with utils.TempDirectory() as location: testdata1 = utils.get_test_data('genome-s10.fa.gz.sig') testdata2 = utils.get_test_data('genome-s11.fa.gz.sig') @@ -4255,10 +4273,7 @@ def test_sbt_categorize_multiple_ksizes_moltypes(): args = ['categorize', 'zzz', '.'] status, out, err = utils.runscript('sourmash', args, - in_directory=location, fail_ok=True) - - assert status != 0 - assert 'multiple k-mer sizes/molecule types present' in err + in_directory=location) @utils.in_tempdir
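
(For a quick sense of the loading API this patch converges on, here is a
minimal usage sketch -- file paths below are hypothetical, but the functions
and methods are the ones added or reworked above:

    import sourmash
    from sourmash.index import MultiIndex

    # one generic entry point: handles .sig files, directories, pathlists,
    # SBT and LCA databases; raises ValueError on failure.
    idx = sourmash.load_file_as_index('collection.sig')

    # narrow to compatible signatures; whole-database types raise
    # ValueError if the requirements cannot be met.
    idx = idx.select(ksize=31, moltype='DNA')

    # MultiIndex additionally tracks where each signature came from:
    mi = MultiIndex.load_from_path('some/directory/')
    for ss, location in mi.signatures_with_location():
        print(location, ss)

This is the design choice behind removing LoadSingleSignatures and the
DatabaseType enum: callers now deal with Index objects uniformly, and
compatibility checking lives in each class's 'select' method.)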