diff --git a/doc/dev_plugins.md b/doc/dev_plugins.md new file mode 100644 index 0000000000..f8eddfee63 --- /dev/null +++ b/doc/dev_plugins.md @@ -0,0 +1,75 @@ +# sourmash plugins via Python entry points + +As of version 4.7.0, sourmash has experimental support for Python +plugins to load and save signatures in different ways (e.g. file +formats, RPC servers, databases, etc.). This support is provided via +the "entry points" mechanism supplied by +[`importlib.metadata`](https://docs.python.org/3/library/importlib.metadata.html) +and documented +[here](https://setuptools.pypa.io/en/latest/userguide/entry_point.html). + +```{note} +Note: The plugin API is _not_ finalized or subject to semantic +versioning just yet! Please subscribe to +[sourmash#1353](https://github.com/sourmash-bio/sourmash/issues/1353) +if you want to keep up to date on plugin support. +``` + +You can define entry points in the `pyproject.toml` file +like so: + +``` +[project.entry-points."sourmash.load_from"] +a_reader = "module_name:load_sketches" + +[project.entry-points."sourmash.save_to"] +a_writer = "module_name:SaveSignatures_WriteFile" +``` + +Here, `module_name` should be the name of the module to import. +`load_sketches` should be a function that takes a location along with +arbitrary keyword arguments and returns an `Index` object +(e.g. `LinearIndex` for a collection of in-memory +signatures). `SaveSignatures_WriteFile` should be a class that +subclasses `BaseSave_SignaturesToLocation` and implements its own +mechanisms of saving signatures. See the `sourmash.save_load` module +for saving and loading code already used in sourmash. + +Note that if the function or class has a `priority` attribute, this will +be used to determine the order in which the plugins are called. + +The `name` attribute of the plugin (`a_reader` and `a_writer` in +`pyproject.toml`, above) is only used in debugging. + +## Templates and examples + +If you want to create your own plug-in, you can start with the +[sourmash_plugin_template](https://github.com/sourmash-bio/sourmash_plugin_template) repo. + +Some (early stage) plugins are also available as examples: + +* [sourmash-bio/sourmash_plugin_load_urls](https://github.com/sourmash-bio/sourmash_plugin_load_urls) - load signatures and CSV manifests via [fsspec](https://filesystem-spec.readthedocs.io/). +* [sourmash-bio/sourmash_plugin_avro](https://github.com/sourmash-bio/sourmash_plugin_avro) - use [Apache Avro](https://avro.apache.org/) as a serialization format. + +## Debugging plugins + +`sourmash sig cat -o ` is a simple way to +invoke a `save_to` plugin. Use `-d` to turn on debugging output. + +`sourmash sig describe ` is a simple way to invoke +a `load_from` plugin. Use `-d` to turn on debugging output. + +## Semantic versioning and listing sourmash as a dependency + +Plugins should probably list sourmash as a dependency for installation. + +Once plugins are officially supported by sourmash, the plugin API will +be under [semantic versioning constraints](https://semver.org/). That +means that you should constrain plugins to depend on sourmash only up +to the next major version, e.g. sourmash v5. + +Specifically, we suggest placing something like: +``` +dependencies = ['sourmash>=4.8.0,<5'] +``` +in your `pyproject.toml` file. diff --git a/doc/developer.md b/doc/developer.md index 1e4d4dbbeb..b2a968de31 100644 --- a/doc/developer.md +++ b/doc/developer.md @@ -1,3 +1,7 @@ +```{contents} Contents +:depth: 3 +``` + # Developer information ## Development environment @@ -280,7 +284,7 @@ Some installation issues can be solved by simply removing the intermediate build make clean ``` -## Contents +## Additional developer-focused documents ```{toctree} :maxdepth: 2 @@ -289,4 +293,6 @@ release requirements storage release-notes/releases +dev_plugins ``` + diff --git a/src/sourmash/cli/sig/cat.py b/src/sourmash/cli/sig/cat.py index d251db0b4e..ed85932f5f 100644 --- a/src/sourmash/cli/sig/cat.py +++ b/src/sourmash/cli/sig/cat.py @@ -31,6 +31,10 @@ def subparser(subparsers): '-q', '--quiet', action='store_true', help='suppress non-error output' ) + subparser.add_argument( + '-d', '--debug', action='store_true', + help='provide debugging output' + ) subparser.add_argument( '-o', '--output', metavar='FILE', default='-', help='output signature to this file (default stdout)' diff --git a/src/sourmash/exceptions.py b/src/sourmash/exceptions.py index 4895c50947..b2f18c12d2 100644 --- a/src/sourmash/exceptions.py +++ b/src/sourmash/exceptions.py @@ -25,6 +25,11 @@ def __init__(self): SourmashError.__init__(self, "This index format is not supported in this version of sourmash") +class IndexNotLoaded(SourmashError): + def __init__(self, msg): + SourmashError.__init__(self, f"Cannot load sourmash index: {str(msg)}") + + def _make_error(error_name, base=SourmashError, code=None): class Exc(base): pass diff --git a/src/sourmash/plugins.py b/src/sourmash/plugins.py new file mode 100644 index 0000000000..2a786d6d24 --- /dev/null +++ b/src/sourmash/plugins.py @@ -0,0 +1,66 @@ +""" +Support for plugins to sourmash via importlib.metadata entrypoints. + +Plugin entry point names: +* 'sourmash.load_from' - Index class loading. +* 'sourmash.save_to' - Signature saving. +* 'sourmash.picklist_filters' - extended Picklist functionality. + +CTB TODO: + +* consider using something other than 'name' for loader fn name. Maybe __doc__? +* try implement picklist plugin? +""" + +DEFAULT_LOAD_FROM_PRIORITY = 99 +DEFAULT_SAVE_TO_PRIORITY = 99 + +from .logging import debug_literal + +# cover for older versions of Python that don't support selection on load +# (the 'group=' below). +from importlib.metadata import entry_points + +# load 'load_from' entry points. NOTE: this executes on import of this module. +try: + _plugin_load_from = entry_points(group='sourmash.load_from') +except TypeError: + from importlib_metadata import entry_points + _plugin_load_from = entry_points(group='sourmash.load_from') + +# load 'save_to' entry points as well. +_plugin_save_to = entry_points(group='sourmash.save_to') + + +def get_load_from_functions(): + "Load the 'load_from' plugins and yield tuples (priority, name, fn)." + debug_literal(f"load_from plugins: {_plugin_load_from}") + + # Load each plugin, + for plugin in _plugin_load_from: + loader_fn = plugin.load() + + # get 'priority' if it is available + priority = getattr(loader_fn, 'priority', DEFAULT_LOAD_FROM_PRIORITY) + + # retrieve name (which is specified by plugin?) + name = plugin.name + debug_literal(f"plugins.load_from_functions: got '{name}', priority={priority}") + yield priority, name, loader_fn + + +def get_save_to_functions(): + "Load the 'save_to' plugins and yield tuples (priority, fn)." + debug_literal(f"save_to plugins: {_plugin_save_to}") + + # Load each plugin, + for plugin in _plugin_save_to: + save_cls = plugin.load() + + # get 'priority' if it is available + priority = getattr(save_cls, 'priority', DEFAULT_SAVE_TO_PRIORITY) + + # retrieve name (which is specified by plugin?) + name = plugin.name + debug_literal(f"plugins.save_to_functions: got '{name}', priority={priority}") + yield priority, save_cls diff --git a/src/sourmash/save_load.py b/src/sourmash/save_load.py new file mode 100644 index 0000000000..bb842bd02e --- /dev/null +++ b/src/sourmash/save_load.py @@ -0,0 +1,530 @@ +""" +Index object/sigfile loading and signature saving code. + +This is the middleware code responsible for loading and saving signatures +in a variety of ways. + +--- + +Command-line functionality goes in sourmash_args.py. + +Low-level JSON reading/writing is in signature.py. + +Index objects are implemented in the index submodule. + +Public API: + +* load_file_as_index(filename, ...) -- load a sourmash.Index class +* SaveSignaturesToLocation(filename) - bulk signature output + +APIs for plugins to use: + +* class Base_SaveSignaturesToLocation - to implement a new output method. + +CTB TODO: +* consider replacing ValueError with IndexNotLoaded in the future. +""" +import sys +import os +import gzip +from io import StringIO +import zipfile +import itertools +import traceback + +import screed +import sourmash + +from . import plugins as sourmash_plugins +from .logging import notify, debug_literal +from .exceptions import IndexNotLoaded + +from .index.sqlite_index import load_sqlite_index, SqliteIndex +from .sbtmh import load_sbt_index +from .lca.lca_db import load_single_database +from . import signature as sigmod +from .index import (LinearIndex, ZipFileLinearIndex, MultiIndex) +from .manifest import CollectionManifest + + +def load_file_as_index(filename, *, yield_all_files=False): + """Load 'filename' as a database; generic database loader. + + If 'filename' contains an SBT or LCA indexed database, or a regular + Zip file, will return the appropriate objects. If a Zip file and + yield_all_files=True, will try to load all files within zip, not just + .sig files. + + If 'filename' is a JSON file containing one or more signatures, will + return an Index object containing those signatures. + + If 'filename' is a directory, will load *.sig underneath + this directory into an Index object. If yield_all_files=True, will + attempt to load all files. + """ + return _load_database(filename, yield_all_files) + + +def SaveSignaturesToLocation(location): + """ + Provides a context manager that saves signatures in various output formats. + + Usage: + + with SaveSignaturesToLocation(filename_or_location) as save_sigs: + save_sigs.add(sig_obj) + """ + save_list = itertools.chain(_save_classes, + sourmash_plugins.get_save_to_functions()) + for priority, cls in sorted(save_list, key=lambda x:x[0]): + debug_literal(f"trying to match save function {cls}, priority={priority}") + + if cls.matches(location): + debug_literal(f"{cls} is a match!") + return cls(location) + + raise Exception(f"cannot determine how to open location {location} for saving; this should never happen!?") + +### Implementation machinery for _load_databases + + +def _load_database(filename, traverse_yield_all, *, cache_size=None): + """Load file as a database - list of signatures, LCA, SBT, etc. + + Return Index object. + + This is an internal function used by other functions in sourmash_args. + """ + loaded = False + + # load plugins + plugin_fns = sourmash_plugins.get_load_from_functions() + + # aggregate with default load_from functions & sort by priority + load_from_functions = sorted(itertools.chain(_loader_functions, + plugin_fns)) + + # iterate through loader functions, sorted by priority; try them all. + # Catch ValueError & IndexNotLoaded but nothing else. + for (priority, desc, load_fn) in load_from_functions: + db = None + try: + debug_literal(f"_load_databases: trying loader fn - priority {priority} - '{desc}'") + db = load_fn(filename, + traverse_yield_all=traverse_yield_all, + cache_size=cache_size) + except (ValueError, IndexNotLoaded): + debug_literal(f"_load_databases: FAIL with ValueError: on fn {desc}.") + debug_literal(traceback.format_exc()) + debug_literal("(continuing past exception)") + + if db is not None: + loaded = True + debug_literal("_load_databases: success!") + break + + if loaded: + assert db is not None + return db + + raise ValueError(f"Error while reading signatures from '{filename}'.") + + +_loader_functions = [] +def add_loader(name, priority): + "decorator to add name/priority to _loader_functions" + def dec_priority(func): + _loader_functions.append((priority, name, func)) + return func + return dec_priority + + +@add_loader("load from stdin", 10) +def _load_stdin(filename, **kwargs): + "Load collection from .sig file streamed in via stdin" + db = None + if filename == '-': + # load as LinearIndex, then pass into MultiIndex to generate a + # manifest. + lidx = LinearIndex.load(sys.stdin, filename='-') + db = MultiIndex.load((lidx,), (None,), parent="-") + + return db + + +@add_loader("load from standalone manifest", 30) +def _load_standalone_manifest(filename, **kwargs): + from sourmash.index import StandaloneManifestIndex + + try: + idx = StandaloneManifestIndex.load(filename) + except gzip.BadGzipFile as exc: + raise IndexNotLoaded(exc) + + return idx + + +@add_loader("load from list of paths", 50) +def _multiindex_load_from_pathlist(filename, **kwargs): + "Load collection from a list of signature/database files" + db = MultiIndex.load_from_pathlist(filename) + + return db + + +@add_loader("load from path (file or directory)", 40) +def _multiindex_load_from_path(filename, **kwargs): + "Load collection from a directory." + traverse_yield_all = kwargs['traverse_yield_all'] + db = MultiIndex.load_from_path(filename, traverse_yield_all) + + return db + + +@add_loader("load SBT", 60) +def _load_sbt(filename, **kwargs): + "Load collection from an SBT." + cache_size = kwargs.get('cache_size') + + try: + db = load_sbt_index(filename, cache_size=cache_size) + except (FileNotFoundError, TypeError) as exc: + raise IndexNotLoaded(exc) + + return db + + +@add_loader("load revindex", 70) +def _load_revindex(filename, **kwargs): + "Load collection from an LCA database/reverse index." + db, _, _ = load_single_database(filename) + return db + + +@add_loader("load collection from sqlitedb", 20) +def _load_sqlite_db(filename, **kwargs): + return load_sqlite_index(filename) + + +@add_loader("load collection from zipfile", 80) +def _load_zipfile(filename, **kwargs): + "Load collection from a .zip file." + db = None + if filename.endswith('.zip'): + traverse_yield_all = kwargs['traverse_yield_all'] + try: + db = ZipFileLinearIndex.load(filename, + traverse_yield_all=traverse_yield_all) + except FileNotFoundError as exc: + # turn this into an IndexNotLoaded => proper exception handling by + # _load_database. + raise IndexNotLoaded(exc) + + return db + + +@add_loader("catch FASTA/FASTQ files and error", 1000) +def _error_on_fastaq(filename, **kwargs): + "This is a tail-end loader that checks for FASTA/FASTQ sequences => err." + success = False + try: + with screed.open(filename) as it: + _ = next(iter(it)) + + success = True + except: + pass + + if success: + raise Exception(f"Error while reading signatures from '{filename}' - got sequences instead! Is this a FASTA/FASTQ file?") + + +### Implementation machinery for SaveSignaturesToLocation + +class Base_SaveSignaturesToLocation: + "Base signature saving class. Track location (if any) and count." + def __init__(self, location): + self.location = location + self.count = 0 + + @classmethod + def matches(cls, location): + "returns True when this class should handle a specific location" + raise NotImplementedError + + def __repr__(self): + raise NotImplementedError + + def __len__(self): + return self.count + + def open(self): + pass + + def close(self): + pass + + def __enter__(self): + "provide context manager functionality" + self.open() + return self + + def __exit__(self, type, value, traceback): + "provide context manager functionality" + self.close() + + def add(self, ss): + self.count += 1 + + def add_many(self, sslist): + for ss in sslist: + self.add(ss) + + +def _get_signatures_from_rust(siglist): + # this function deals with a disconnect between the way Rust + # and Python handle signatures; Python expects one + # minhash (and hence one md5sum) per signature, while + # Rust supports multiple. For now, go through serializing + # and deserializing the signature! See issue #1167 for more. + json_str = sourmash.save_signatures(siglist) + for ss in sourmash.load_signatures(json_str): + yield ss + + +class SaveSignatures_NoOutput(Base_SaveSignaturesToLocation): + "Do not save signatures." + def __repr__(self): + return 'SaveSignatures_NoOutput()' + + @classmethod + def matches(cls, location): + return location is None + + def open(self): + pass + + def close(self): + pass + + +class SaveSignatures_Directory(Base_SaveSignaturesToLocation): + "Save signatures within a directory, using md5sum names." + def __init__(self, location): + super().__init__(location) + + def __repr__(self): + return f"SaveSignatures_Directory('{self.location}')" + + @classmethod + def matches(cls, location): + "anything ending in /" + if location: + return location.endswith('/') + + def close(self): + pass + + def open(self): + try: + os.mkdir(self.location) + except FileExistsError: + pass + except: + notify(f"ERROR: cannot create signature output directory '{self.location}'") + sys.exit(-1) + + def add(self, ss): + super().add(ss) + md5 = ss.md5sum() + + # don't overwrite even if duplicate md5sum + outname = os.path.join(self.location, f"{md5}.sig.gz") + if os.path.exists(outname): + i = 0 + while 1: + outname = os.path.join(self.location, f"{md5}_{i}.sig.gz") + if not os.path.exists(outname): + break + i += 1 + + with gzip.open(outname, "wb") as fp: + sigmod.save_signatures([ss], fp, compression=1) + + +class SaveSignatures_SqliteIndex(Base_SaveSignaturesToLocation): + "Save signatures within a directory, using md5sum names." + def __init__(self, location): + super().__init__(location) + self.location = location + self.idx = None + self.cursor = None + + @classmethod + def matches(cls, location): + "anything ending in .sqldb" + if location: + return location.endswith('.sqldb') + + def __repr__(self): + return f"SaveSignatures_SqliteIndex('{self.location}')" + + def close(self): + self.idx.commit() + self.cursor.execute('VACUUM') + self.idx.close() + + def open(self): + self.idx = SqliteIndex.create(self.location, append=True) + self.cursor = self.idx.cursor() + + def add(self, add_sig): + for ss in _get_signatures_from_rust([add_sig]): + super().add(ss) + self.idx.insert(ss, cursor=self.cursor, commit=False) + + # commit every 1000 signatures. + if self.count % 1000 == 0: + self.idx.commit() + + +class SaveSignatures_SigFile(Base_SaveSignaturesToLocation): + "Save signatures to a .sig JSON file." + def __init__(self, location): + super().__init__(location) + self.keep = [] + self.compress = 0 + if self.location.endswith('.gz'): + self.compress = 1 + + @classmethod + def matches(cls, location): + # match anything that is not None or "" + return bool(location) + + def __repr__(self): + return f"SaveSignatures_SigFile('{self.location}')" + + def open(self): + pass + + def close(self): + if self.location == '-': + sourmash.save_signatures(self.keep, sys.stdout) + else: + # text mode? encode in utf-8 + mode = "w" + encoding = 'utf-8' + + # compressed? bytes & binary. + if self.compress: + encoding = None + mode = "wb" + + with open(self.location, mode, encoding=encoding) as fp: + sourmash.save_signatures(self.keep, fp, + compression=self.compress) + + def add(self, ss): + super().add(ss) + self.keep.append(ss) + + +class SaveSignatures_ZipFile(Base_SaveSignaturesToLocation): + "Save compressed signatures in an uncompressed Zip file." + def __init__(self, location): + super().__init__(location) + self.storage = None + + @classmethod + def matches(cls, location): + "anything ending in .zip" + if location: + return location.endswith('.zip') + + def __repr__(self): + return f"SaveSignatures_ZipFile('{self.location}')" + + def close(self): + # finish constructing manifest object & save + manifest = CollectionManifest(self.manifest_rows) + manifest_name = "SOURMASH-MANIFEST.csv" + + manifest_fp = StringIO() + manifest.write_to_csv(manifest_fp, write_header=True) + manifest_data = manifest_fp.getvalue().encode("utf-8") + + self.storage.save(manifest_name, manifest_data, overwrite=True, + compress=True) + self.storage.flush() + self.storage.close() + + def open(self): + from .sbt_storage import ZipStorage + + do_create = True + if os.path.exists(self.location): + do_create = False + + storage = None + try: + storage = ZipStorage(self.location, mode="w") + except zipfile.BadZipFile: + pass + + if storage is None: + raise ValueError(f"File '{self.location}' cannot be opened as a zip file.") + + if not storage.subdir: + storage.subdir = 'signatures' + + # now, try to load manifest + try: + manifest_data = storage.load('SOURMASH-MANIFEST.csv') + except (FileNotFoundError, KeyError): + # if file already exists must have manifest... + if not do_create: + raise ValueError(f"Cannot add to existing zipfile '{self.location}' without a manifest") + self.manifest_rows = [] + else: + # success! decode manifest_data, create manifest rows => append. + manifest_data = manifest_data.decode('utf-8') + manifest_fp = StringIO(manifest_data) + manifest = CollectionManifest.load_from_csv(manifest_fp) + self.manifest_rows = list(manifest._select()) + + self.storage = storage + + def _exists(self, name): + try: + self.storage.load(name) + return True + except KeyError: + return False + + def add(self, add_sig): + if not self.storage: + raise ValueError("this output is not open") + + for ss in _get_signatures_from_rust([add_sig]): + buf = sigmod.save_signatures([ss], compression=1) + md5 = ss.md5sum() + + storage = self.storage + path = f'{storage.subdir}/{md5}.sig.gz' + location = storage.save(path, buf) + + # update manifest + row = CollectionManifest.make_manifest_row(ss, location, + include_signature=False) + self.manifest_rows.append(row) + super().add(ss) + + +_save_classes = [ + (10, SaveSignatures_NoOutput), + (20, SaveSignatures_Directory), + (30, SaveSignatures_ZipFile), + (40, SaveSignatures_SqliteIndex), + (1000, SaveSignatures_SigFile), +] diff --git a/src/sourmash/sig/__main__.py b/src/sourmash/sig/__main__.py index 48be17d9ad..7a757aa444 100644 --- a/src/sourmash/sig/__main__.py +++ b/src/sourmash/sig/__main__.py @@ -83,7 +83,7 @@ def cat(args): """ concatenate all signatures into one file. """ - set_quiet(args.quiet) + set_quiet(args.quiet, args.debug) moltype = sourmash_args.calculate_moltype(args) picklist = sourmash_args.load_picklist(args) pattern_search = sourmash_args.load_include_exclude_db_patterns(args) diff --git a/src/sourmash/sourmash_args.py b/src/sourmash/sourmash_args.py index 3a465b18c1..01fd86460c 100644 --- a/src/sourmash/sourmash_args.py +++ b/src/sourmash/sourmash_args.py @@ -22,45 +22,34 @@ * load_query_signature(filename, ...) -- load a single signature for query * traverse_find_sigs(filenames, ...) -- find all .sig and .sig.gz files * load_dbs_and_sigs(filenames, query, ...) -- load databases & signatures -* load_file_as_index(filename, ...) -- load a sourmash.Index class -* load_file_as_signatures(filename, ...) -- load a list of signatures * load_pathlist_from_file(filename) -- load a list of paths from a file * load_many_signatures(locations) -- load many signatures from many files * get_manifest(idx) -- retrieve or build a manifest from an Index * class SignatureLoadingProgress - signature loading progress bar +* load_file_as_signatures(filename, ...) -- load a list of signatures signature and file output functionality: -* SaveSignaturesToLocation(filename) - bulk signature output * class FileOutput - file output context manager that deals w/stdout well * class FileOutputCSV - file output context manager for CSV files """ import sys import os import csv -from enum import Enum -import traceback import gzip -from io import StringIO, TextIOWrapper +from io import TextIOWrapper import re import zipfile import contextlib - -import screed -import sourmash - -from sourmash.sbtmh import load_sbt_index -from sourmash.lca.lca_db import load_single_database -import sourmash.exceptions +import argparse from .logging import notify, error, debug_literal -from .index import (LinearIndex, ZipFileLinearIndex, MultiIndex) -from .index.sqlite_index import load_sqlite_index, SqliteIndex -from . import signature as sigmod +from .index import LinearIndex from .picklist import SignaturePicklist, PickStyle from .manifest import CollectionManifest -import argparse +from .save_load import (SaveSignaturesToLocation, load_file_as_index, + _load_database) DEFAULT_LOAD_K = 31 @@ -70,7 +59,7 @@ def check_scaled_bounds(arg): f = float(arg) if f < 0: - raise argparse.ArgumentTypeError(f"ERROR: scaled value must be positive") + raise argparse.ArgumentTypeError("ERROR: scaled value must be positive") if f < 100: notify('WARNING: scaled value should be >= 100. Continuing anyway.') if f > 1e6: @@ -82,7 +71,7 @@ def check_num_bounds(arg): f = int(arg) if f < 0: - raise argparse.ArgumentTypeError(f"ERROR: num value must be positive") + raise argparse.ArgumentTypeError("ERROR: num value must be positive") if f < 50: notify('WARNING: num value should be >= 50. Continuing anyway.') if f > 50000: @@ -352,211 +341,6 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *, return databases -def _load_stdin(filename, **kwargs): - "Load collection from .sig file streamed in via stdin" - db = None - if filename == '-': - # load as LinearIndex, then pass into MultiIndex to generate a - # manifest. - lidx = LinearIndex.load(sys.stdin, filename='-') - db = MultiIndex.load((lidx,), (None,), parent="-") - - return db - - -def _load_standalone_manifest(filename, **kwargs): - from sourmash.index import StandaloneManifestIndex - - try: - idx = StandaloneManifestIndex.load(filename) - except gzip.BadGzipFile as exc: - raise ValueError(exc) - - return idx - - -def _multiindex_load_from_pathlist(filename, **kwargs): - "Load collection from a list of signature/database files" - db = MultiIndex.load_from_pathlist(filename) - - return db - - -def _multiindex_load_from_path(filename, **kwargs): - "Load collection from a directory." - traverse_yield_all = kwargs['traverse_yield_all'] - db = MultiIndex.load_from_path(filename, traverse_yield_all) - - return db - - -def _load_sbt(filename, **kwargs): - "Load collection from an SBT." - cache_size = kwargs.get('cache_size') - - try: - db = load_sbt_index(filename, cache_size=cache_size) - except (FileNotFoundError, TypeError) as exc: - raise ValueError(exc) - - return db - - -def _load_revindex(filename, **kwargs): - "Load collection from an LCA database/reverse index." - db, _, _ = load_single_database(filename) - return db - - -def _load_sqlite_db(filename, **kwargs): - return load_sqlite_index(filename) - - -def _load_zipfile(filename, **kwargs): - "Load collection from a .zip file." - db = None - if filename.endswith('.zip'): - traverse_yield_all = kwargs['traverse_yield_all'] - try: - db = ZipFileLinearIndex.load(filename, - traverse_yield_all=traverse_yield_all) - except FileNotFoundError as exc: - # turn this into a ValueError => proper exception handling by - # _load_database. - raise ValueError(exc) - - return db - - -# all loader functions, in order. -_loader_functions = [ - ("load from stdin", _load_stdin), - ("load collection from sqlitedb", _load_sqlite_db), - ("load from standalone manifest", _load_standalone_manifest), - ("load from path (file or directory)", _multiindex_load_from_path), - ("load from file list", _multiindex_load_from_pathlist), - ("load SBT", _load_sbt), - ("load revindex", _load_revindex), - ("load collection from zipfile", _load_zipfile), - ] - - -def _load_database(filename, traverse_yield_all, *, cache_size=None): - """Load file as a database - list of signatures, LCA, SBT, etc. - - Return Index object. - - This is an internal function used by other functions in sourmash_args. - """ - loaded = False - - # iterate through loader functions, trying them all. Catch ValueError - # but nothing else. - for n, (desc, load_fn) in enumerate(_loader_functions): - try: - debug_literal(f"_load_databases: trying loader fn {n} '{desc}'") - db = load_fn(filename, - traverse_yield_all=traverse_yield_all, - cache_size=cache_size) - except ValueError: - debug_literal(f"_load_databases: FAIL on fn {n} {desc}.") - debug_literal(traceback.format_exc()) - - if db is not None: - loaded = True - debug_literal("_load_databases: success!") - break - - # check to see if it's a FASTA/FASTQ record (i.e. screed loadable) - # so we can provide a better error message to users. - if not loaded: - successful_screed_load = False - it = None - try: - # CTB: could be kind of time consuming for a big record, but at the - # moment screed doesn't expose format detection cleanly. - with screed.open(filename) as it: - _ = next(iter(it)) - successful_screed_load = True - except: - pass - - if successful_screed_load: - raise ValueError(f"Error while reading signatures from '{filename}' - got sequences instead! Is this a FASTA/FASTQ file?") - - if not loaded: - raise ValueError(f"Error while reading signatures from '{filename}'.") - - if loaded: # this is a bit redundant but safe > sorry - assert db is not None - - return db - - -def load_file_as_index(filename, *, yield_all_files=False): - """Load 'filename' as a database; generic database loader. - - If 'filename' contains an SBT or LCA indexed database, or a regular - Zip file, will return the appropriate objects. If a Zip file and - yield_all_files=True, will try to load all files within zip, not just - .sig files. - - If 'filename' is a JSON file containing one or more signatures, will - return an Index object containing those signatures. - - If 'filename' is a directory, will load *.sig underneath - this directory into an Index object. If yield_all_files=True, will - attempt to load all files. - """ - return _load_database(filename, yield_all_files) - - -def load_file_as_signatures(filename, *, select_moltype=None, ksize=None, - picklist=None, - yield_all_files=False, - progress=None, - pattern=None, - _use_manifest=True): - """Load 'filename' as a collection of signatures. Return an iterable. - - If 'filename' contains an SBT or LCA indexed database, or a regular - Zip file, will return a signatures() generator. If a Zip file and - yield_all_files=True, will try to load all files within zip, not just - .sig files. - - If 'filename' is a JSON file containing one or more signatures, will - return a list of those signatures. - - If 'filename' is a directory, will load *.sig - underneath this directory into a list of signatures. If - yield_all_files=True, will attempt to load all files. - - Applies selector function if select_moltype, ksize or picklist are given. - - 'pattern' is a function that returns True on matching values. - """ - if progress: - progress.notify(filename) - - db = _load_database(filename, yield_all_files) - - # test fixture ;) - if not _use_manifest and db.manifest: - db.manifest = None - - db = db.select(moltype=select_moltype, ksize=ksize) - - # apply pattern search & picklist - db = apply_picklist_and_pattern(db, picklist, pattern) - - loader = db.signatures() - - if progress is not None: - return progress.start_file(filename, loader) - else: - return loader - - def load_pathlist_from_file(filename): "Load a list-of-files text file." try: @@ -894,8 +678,6 @@ def get_manifest(idx, *, require=True, rebuild=False): In the case where `require=False` and a manifest cannot be built, may return None. Otherwise always returns a manifest. """ - from sourmash.index import CollectionManifest - m = idx.manifest # has one, and don't want to rebuild? easy! return! @@ -921,293 +703,48 @@ def get_manifest(idx, *, require=True, rebuild=False): return m -# -# enum and classes for saving signatures progressively -# - -def _get_signatures_from_rust(siglist): - # this deals with a disconnect between the way Rust - # and Python handle signatures; Python expects one - # minhash (and hence one md5sum) per signature, while - # Rust supports multiple. For now, go through serializing - # and deserializing the signature! See issue #1167 for more. - json_str = sourmash.save_signatures(siglist) - for ss in sourmash.load_signatures(json_str): - yield ss - - -class _BaseSaveSignaturesToLocation: - "Base signature saving class. Track location (if any) and count." - def __init__(self, location): - self.location = location - self.count = 0 - - def __repr__(self): - raise NotImplementedError - - def __len__(self): - return self.count - - def __enter__(self): - "provide context manager functionality" - self.open() - return self - - def __exit__(self, type, value, traceback): - "provide context manager functionality" - self.close() - - def add(self, ss): - self.count += 1 - - def add_many(self, sslist): - for ss in sslist: - self.add(ss) - - -class SaveSignatures_NoOutput(_BaseSaveSignaturesToLocation): - "Do not save signatures." - def __repr__(self): - return 'SaveSignatures_NoOutput()' - - def open(self): - pass - - def close(self): - pass - - -class SaveSignatures_Directory(_BaseSaveSignaturesToLocation): - "Save signatures within a directory, using md5sum names." - def __init__(self, location): - super().__init__(location) - - def __repr__(self): - return f"SaveSignatures_Directory('{self.location}')" - - def close(self): - pass - - def open(self): - try: - os.mkdir(self.location) - except FileExistsError: - pass - except: - notify(f"ERROR: cannot create signature output directory '{self.location}'") - sys.exit(-1) - - def add(self, ss): - super().add(ss) - md5 = ss.md5sum() - - # don't overwrite even if duplicate md5sum - outname = os.path.join(self.location, f"{md5}.sig.gz") - if os.path.exists(outname): - i = 0 - while 1: - outname = os.path.join(self.location, f"{md5}_{i}.sig.gz") - if not os.path.exists(outname): - break - i += 1 - - with gzip.open(outname, "wb") as fp: - sigmod.save_signatures([ss], fp, compression=1) - - -class SaveSignatures_SqliteIndex(_BaseSaveSignaturesToLocation): - "Save signatures within a directory, using md5sum names." - def __init__(self, location): - super().__init__(location) - self.location = location - self.idx = None - self.cursor = None - - def __repr__(self): - return f"SaveSignatures_SqliteIndex('{self.location}')" - - def close(self): - self.idx.commit() - self.cursor.execute('VACUUM') - self.idx.close() - - def open(self): - self.idx = SqliteIndex.create(self.location, append=True) - self.cursor = self.idx.cursor() - - def add(self, add_sig): - for ss in _get_signatures_from_rust([add_sig]): - super().add(ss) - self.idx.insert(ss, cursor=self.cursor, commit=False) - - # commit every 1000 signatures. - if self.count % 1000 == 0: - self.idx.commit() - - -class SaveSignatures_SigFile(_BaseSaveSignaturesToLocation): - "Save signatures to a .sig JSON file." - def __init__(self, location): - super().__init__(location) - self.keep = [] - self.compress = 0 - if self.location.endswith('.gz'): - self.compress = 1 - - def __repr__(self): - return f"SaveSignatures_SigFile('{self.location}')" - - def open(self): - pass - - def close(self): - if self.location == '-': - sourmash.save_signatures(self.keep, sys.stdout) - else: - # text mode? encode in utf-8 - mode = "w" - encoding = 'utf-8' - - # compressed? bytes & binary. - if self.compress: - encoding = None - mode = "wb" - - with open(self.location, mode, encoding=encoding) as fp: - sourmash.save_signatures(self.keep, fp, - compression=self.compress) - def add(self, ss): - super().add(ss) - self.keep.append(ss) - - -class SaveSignatures_ZipFile(_BaseSaveSignaturesToLocation): - "Save compressed signatures in an uncompressed Zip file." - def __init__(self, location): - super().__init__(location) - self.storage = None - - def __repr__(self): - return f"SaveSignatures_ZipFile('{self.location}')" +def load_file_as_signatures(filename, *, select_moltype=None, ksize=None, + picklist=None, + yield_all_files=False, + progress=None, + pattern=None, + _use_manifest=True): + """Load 'filename' as a collection of signatures. Return an iterable. - def close(self): - # finish constructing manifest object & save - manifest = CollectionManifest(self.manifest_rows) - manifest_name = f"SOURMASH-MANIFEST.csv" + If 'filename' contains an SBT or LCA indexed database, or a regular + Zip file, will return a signatures() generator. If a Zip file and + yield_all_files=True, will try to load all files within zip, not just + .sig files. - manifest_fp = StringIO() - manifest.write_to_csv(manifest_fp, write_header=True) - manifest_data = manifest_fp.getvalue().encode("utf-8") + If 'filename' is a JSON file containing one or more signatures, will + return a list of those signatures. - self.storage.save(manifest_name, manifest_data, overwrite=True, - compress=True) - self.storage.flush() - self.storage.close() + If 'filename' is a directory, will load *.sig + underneath this directory into a list of signatures. If + yield_all_files=True, will attempt to load all files. - def open(self): - from .sbt_storage import ZipStorage + Applies selector function if select_moltype, ksize or picklist are given. - do_create = True - if os.path.exists(self.location): - do_create = False + 'pattern' is a function that returns True on matching values. + """ + if progress: + progress.notify(filename) - storage = None - try: - storage = ZipStorage(self.location, mode="w") - except zipfile.BadZipFile: - pass + db = _load_database(filename, yield_all_files) - if storage is None: - raise ValueError(f"File '{self.location}' cannot be opened as a zip file.") + # test fixture ;) + if not _use_manifest and db.manifest: + db.manifest = None - if not storage.subdir: - storage.subdir = 'signatures' + db = db.select(moltype=select_moltype, ksize=ksize) - # now, try to load manifest - try: - manifest_data = storage.load('SOURMASH-MANIFEST.csv') - except (FileNotFoundError, KeyError): - # if file already exists must have manifest... - if not do_create: - raise ValueError(f"Cannot add to existing zipfile '{self.location}' without a manifest") - self.manifest_rows = [] - else: - # success! decode manifest_data, create manifest rows => append. - manifest_data = manifest_data.decode('utf-8') - manifest_fp = StringIO(manifest_data) - manifest = CollectionManifest.load_from_csv(manifest_fp) - self.manifest_rows = list(manifest._select()) + # apply pattern search & picklist + db = apply_picklist_and_pattern(db, picklist, pattern) - self.storage = storage + loader = db.signatures() - def _exists(self, name): - try: - self.storage.load(name) - return True - except KeyError: - return False - - def add(self, add_sig): - if not self.storage: - raise ValueError("this output is not open") - - for ss in _get_signatures_from_rust([add_sig]): - buf = sigmod.save_signatures([ss], compression=1) - md5 = ss.md5sum() - - storage = self.storage - path = f'{storage.subdir}/{md5}.sig.gz' - location = storage.save(path, buf) - - # update manifest - row = CollectionManifest.make_manifest_row(ss, location, - include_signature=False) - self.manifest_rows.append(row) - super().add(ss) - - -class SigFileSaveType(Enum): - NO_OUTPUT = 0 - SIGFILE = 1 - SIGFILE_GZ = 2 - DIRECTORY = 3 - ZIPFILE = 4 - SQLITEDB = 5 - -_save_classes = { - SigFileSaveType.NO_OUTPUT: SaveSignatures_NoOutput, - SigFileSaveType.SIGFILE: SaveSignatures_SigFile, - SigFileSaveType.SIGFILE_GZ: SaveSignatures_SigFile, - SigFileSaveType.DIRECTORY: SaveSignatures_Directory, - SigFileSaveType.ZIPFILE: SaveSignatures_ZipFile, - SigFileSaveType.SQLITEDB: SaveSignatures_SqliteIndex, -} - - -def SaveSignaturesToLocation(filename, *, force_type=None): - """Create and return an appropriate object for progressive saving of - signatures.""" - save_type = None - if not force_type: - if filename is None: - save_type = SigFileSaveType.NO_OUTPUT - elif filename.endswith('/'): - save_type = SigFileSaveType.DIRECTORY - elif filename.endswith('.gz'): - save_type = SigFileSaveType.SIGFILE_GZ - elif filename.endswith('.zip'): - save_type = SigFileSaveType.ZIPFILE - elif filename.endswith('.sqldb'): - save_type = SigFileSaveType.SQLITEDB - else: - # default to SIGFILE intentionally! - save_type = SigFileSaveType.SIGFILE + if progress is not None: + return progress.start_file(filename, loader) else: - save_type = force_type - - cls = _save_classes.get(save_type) - if cls is None: - raise Exception("invalid save type; this should never happen!?") - - return cls(filename) + return loader diff --git a/tests/test_api.py b/tests/test_api.py index 73f9ffc7a4..ccaf321df6 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -69,7 +69,7 @@ def test_load_fasta_as_signature(): # try loading a fasta file - should fail with informative exception testfile = utils.get_test_data('short.fa') - with pytest.raises(ValueError) as exc: + with pytest.raises(Exception) as exc: idx = sourmash.load_file_as_index(testfile) print(exc.value) diff --git a/tests/test_index_protocol.py b/tests/test_index_protocol.py index e69a9b0c68..4a6672408e 100644 --- a/tests/test_index_protocol.py +++ b/tests/test_index_protocol.py @@ -75,7 +75,7 @@ def build_sbt_index_save_load(runtmp): def build_zipfile_index(runtmp): - from sourmash.sourmash_args import SaveSignatures_ZipFile + from sourmash.save_load import SaveSignatures_ZipFile location = runtmp.output('index.zip') with SaveSignatures_ZipFile(location) as save_sigs: diff --git a/tests/test_plugin_framework.py b/tests/test_plugin_framework.py new file mode 100644 index 0000000000..76b22bb730 --- /dev/null +++ b/tests/test_plugin_framework.py @@ -0,0 +1,261 @@ +""" +Test the plugin framework in sourmash.plugins, which uses importlib.metadata +entrypoints. + +CTB TODO: +* check name? +""" + +import pytest +import sourmash +from sourmash.logging import set_quiet + +import sourmash_tst_utils as utils +from sourmash import plugins +from sourmash.index import LinearIndex +from sourmash.save_load import (Base_SaveSignaturesToLocation, + SaveSignaturesToLocation) + + +class FakeEntryPoint: + """ + A class that stores a name and an object to be returned on 'load()'. + Mocks the EntryPoint class used by importlib.metadata. + """ + def __init__(self, name, load_obj): + self.name = name + self.load_obj = load_obj + + def load(self): + return self.load_obj + +# +# Test basic features of the load_from plugin hook. +# + +class Test_EntryPointBasics_LoadFrom: + def get_some_sigs(self, location, *args, **kwargs): + ss2 = utils.get_test_data('2.fa.sig') + ss47 = utils.get_test_data('47.fa.sig') + ss63 = utils.get_test_data('63.fa.sig') + + sig2 = sourmash.load_one_signature(ss2, ksize=31) + sig47 = sourmash.load_one_signature(ss47, ksize=31) + sig63 = sourmash.load_one_signature(ss63, ksize=31) + + lidx = LinearIndex([sig2, sig47, sig63], location) + + return lidx + get_some_sigs.priority = 1 + + def setup_method(self): + self.saved_plugins = plugins._plugin_load_from + plugins._plugin_load_from = [FakeEntryPoint('test_load', self.get_some_sigs)] + + def teardown_method(self): + plugins._plugin_load_from = self.saved_plugins + + def test_load_1(self): + ps = list(plugins.get_load_from_functions()) + assert len(ps) == 1 + + def test_load_2(self, runtmp): + fake_location = runtmp.output('passed-through location') + idx = sourmash.load_file_as_index(fake_location) + print(idx, idx.location) + + assert len(idx) == 3 + assert idx.location == fake_location + + +class Test_EntryPoint_LoadFrom_Priority: + def get_some_sigs(self, location, *args, **kwargs): + ss2 = utils.get_test_data('2.fa.sig') + ss47 = utils.get_test_data('47.fa.sig') + ss63 = utils.get_test_data('63.fa.sig') + + sig2 = sourmash.load_one_signature(ss2, ksize=31) + sig47 = sourmash.load_one_signature(ss47, ksize=31) + sig63 = sourmash.load_one_signature(ss63, ksize=31) + + lidx = LinearIndex([sig2, sig47, sig63], location) + + return lidx + get_some_sigs.priority = 5 + + def set_called_flag_1(self, location, *args, **kwargs): + # high priority 1, raise ValueError + print('setting flag 1') + self.was_called_flag_1 = True + raise ValueError + set_called_flag_1.priority = 1 + + def set_called_flag_2(self, location, *args, **kwargs): + # high priority 2, return None + print('setting flag 2') + self.was_called_flag_2 = True + + return None + set_called_flag_2.priority = 2 + + def set_called_flag_3(self, location, *args, **kwargs): + # lower priority 10, should not be called + print('setting flag 3') + self.was_called_flag_3 = True + + return None + set_called_flag_3.priority = 10 + + def setup_method(self): + self.saved_plugins = plugins._plugin_load_from + plugins._plugin_load_from = [ + FakeEntryPoint('test_load', self.get_some_sigs), + FakeEntryPoint('test_load_2', self.set_called_flag_1), + FakeEntryPoint('test_load_3', self.set_called_flag_2), + FakeEntryPoint('test_load_4', self.set_called_flag_3) + ] + self.was_called_flag_1 = False + self.was_called_flag_2 = False + self.was_called_flag_3 = False + + def teardown_method(self): + plugins._plugin_load_from = self.saved_plugins + + def test_load_1(self): + ps = list(plugins.get_load_from_functions()) + assert len(ps) == 4 + + assert not self.was_called_flag_1 + assert not self.was_called_flag_2 + assert not self.was_called_flag_3 + + def test_load_2(self, runtmp): + fake_location = runtmp.output('passed-through location') + idx = sourmash.load_file_as_index(fake_location) + print(idx, idx.location) + + assert len(idx) == 3 + assert idx.location == fake_location + + assert self.was_called_flag_1 + assert self.was_called_flag_2 + assert not self.was_called_flag_3 + + +# +# Test basic features of the save_to plugin hook. +# + +class FakeSaveClass(Base_SaveSignaturesToLocation): + """ + A fake save class that just records what was sent to it. + """ + priority = 50 + + def __init__(self, location): + super().__init__(location) + self.keep = [] + + @classmethod + def matches(cls, location): + if location: + return location.endswith('.this-is-a-test') + + def add(self, ss): + super().add(ss) + self.keep.append(ss) + + +class FakeSaveClass_HighPriority(FakeSaveClass): + priority = 1 + + +class Test_EntryPointBasics_SaveTo: + # test the basics + def setup_method(self): + self.saved_plugins = plugins._plugin_save_to + plugins._plugin_save_to = [FakeEntryPoint('test_save', FakeSaveClass)] + + def teardown_method(self): + plugins._plugin_save_to = self.saved_plugins + + def test_save_1(self): + ps = list(plugins.get_save_to_functions()) + print(ps) + assert len(ps) == 1 + + def test_save_2(self, runtmp): + # load some signatures to save + ss2 = utils.get_test_data('2.fa.sig') + ss47 = utils.get_test_data('47.fa.sig') + ss63 = utils.get_test_data('63.fa.sig') + + sig2 = sourmash.load_one_signature(ss2, ksize=31) + sig47 = sourmash.load_one_signature(ss47, ksize=31) + sig63 = sourmash.load_one_signature(ss63, ksize=31) + + # build a fake location that matches the FakeSaveClass + # extension + fake_location = runtmp.output('out.this-is-a-test') + + # this should use the plugin architecture to return an object + # of type FakeSaveClass, with the three signatures in it. + x = SaveSignaturesToLocation(fake_location) + with x as save_sig: + save_sig.add(sig2) + save_sig.add(sig47) + save_sig.add(sig63) + + print(len(x)) + print(x.keep) + + assert isinstance(x, FakeSaveClass) + assert x.keep == [sig2, sig47, sig63] + + +class Test_EntryPointPriority_SaveTo: + # test that priority is observed + + def setup_method(self): + self.saved_plugins = plugins._plugin_save_to + plugins._plugin_save_to = [ + FakeEntryPoint('test_save', FakeSaveClass), + FakeEntryPoint('test_save2', FakeSaveClass_HighPriority), + ] + + def teardown_method(self): + plugins._plugin_save_to = self.saved_plugins + + def test_save_1(self): + ps = list(plugins.get_save_to_functions()) + print(ps) + assert len(ps) == 2 + + def test_save_2(self, runtmp): + # load some signatures to save + ss2 = utils.get_test_data('2.fa.sig') + ss47 = utils.get_test_data('47.fa.sig') + ss63 = utils.get_test_data('63.fa.sig') + + sig2 = sourmash.load_one_signature(ss2, ksize=31) + sig47 = sourmash.load_one_signature(ss47, ksize=31) + sig63 = sourmash.load_one_signature(ss63, ksize=31) + + # build a fake location that matches the FakeSaveClass + # extension + fake_location = runtmp.output('out.this-is-a-test') + + # this should use the plugin architecture to return an object + # of type FakeSaveClass, with the three signatures in it. + x = SaveSignaturesToLocation(fake_location) + with x as save_sig: + save_sig.add(sig2) + save_sig.add(sig47) + save_sig.add(sig63) + + print(len(x)) + print(x.keep) + + assert isinstance(x, FakeSaveClass_HighPriority) + assert x.keep == [sig2, sig47, sig63] + assert x.priority == 1 diff --git a/tests/test_sourmash_sketch.py b/tests/test_sourmash_sketch.py index fc4cb2373a..c4f5ac14fe 100644 --- a/tests/test_sourmash_sketch.py +++ b/tests/test_sourmash_sketch.py @@ -235,7 +235,7 @@ def test_dna_multiple_ksize(): assert not params.hp assert not params.protein - from sourmash.sourmash_args import _get_signatures_from_rust + from sourmash.save_load import _get_signatures_from_rust siglist = factory() ksizes = set()