diff --git a/doc/dev_plugins.md b/doc/dev_plugins.md
new file mode 100644
index 0000000000..f8eddfee63
--- /dev/null
+++ b/doc/dev_plugins.md
@@ -0,0 +1,75 @@
+# sourmash plugins via Python entry points
+
+As of version 4.7.0, sourmash has experimental support for Python
+plugins to load and save signatures in different ways (e.g. file
+formats, RPC servers, databases, etc.). This support is provided via
+the "entry points" mechanism supplied by
+[`importlib.metadata`](https://docs.python.org/3/library/importlib.metadata.html)
+and documented
+[here](https://setuptools.pypa.io/en/latest/userguide/entry_point.html).
+
+```{note}
+Note: The plugin API is _not_ finalized or subject to semantic
+versioning just yet! Please subscribe to
+[sourmash#1353](https://github.com/sourmash-bio/sourmash/issues/1353)
+if you want to keep up to date on plugin support.
+```
+
+You can define entry points in the `pyproject.toml` file
+like so:
+
+```
+[project.entry-points."sourmash.load_from"]
+a_reader = "module_name:load_sketches"
+
+[project.entry-points."sourmash.save_to"]
+a_writer = "module_name:SaveSignatures_WriteFile"
+```
+
+Here, `module_name` should be the name of the module to import.
+`load_sketches` should be a function that takes a location along with
+arbitrary keyword arguments and returns an `Index` object
+(e.g. `LinearIndex` for a collection of in-memory
+signatures). `SaveSignatures_WriteFile` should be a class that
+subclasses `BaseSave_SignaturesToLocation` and implements its own
+mechanisms of saving signatures. See the `sourmash.save_load` module
+for saving and loading code already used in sourmash.
+
+Note that if the function or class has a `priority` attribute, this will
+be used to determine the order in which the plugins are called.
+
+The `name` attribute of the plugin (`a_reader` and `a_writer` in
+`pyproject.toml`, above) is only used in debugging.
+
+## Templates and examples
+
+If you want to create your own plug-in, you can start with the
+[sourmash_plugin_template](https://github.com/sourmash-bio/sourmash_plugin_template) repo.
+
+Some (early stage) plugins are also available as examples:
+
+* [sourmash-bio/sourmash_plugin_load_urls](https://github.com/sourmash-bio/sourmash_plugin_load_urls) - load signatures and CSV manifests via [fsspec](https://filesystem-spec.readthedocs.io/).
+* [sourmash-bio/sourmash_plugin_avro](https://github.com/sourmash-bio/sourmash_plugin_avro) - use [Apache Avro](https://avro.apache.org/) as a serialization format.
+
+## Debugging plugins
+
+`sourmash sig cat -o ` is a simple way to
+invoke a `save_to` plugin. Use `-d` to turn on debugging output.
+
+`sourmash sig describe ` is a simple way to invoke
+a `load_from` plugin. Use `-d` to turn on debugging output.
+
+## Semantic versioning and listing sourmash as a dependency
+
+Plugins should probably list sourmash as a dependency for installation.
+
+Once plugins are officially supported by sourmash, the plugin API will
+be under [semantic versioning constraints](https://semver.org/). That
+means that you should constrain plugins to depend on sourmash only up
+to the next major version, e.g. sourmash v5.
+
+Specifically, we suggest placing something like:
+```
+dependencies = ['sourmash>=4.8.0,<5']
+```
+in your `pyproject.toml` file.
diff --git a/doc/developer.md b/doc/developer.md
index 1e4d4dbbeb..b2a968de31 100644
--- a/doc/developer.md
+++ b/doc/developer.md
@@ -1,3 +1,7 @@
+```{contents} Contents
+:depth: 3
+```
+
# Developer information
## Development environment
@@ -280,7 +284,7 @@ Some installation issues can be solved by simply removing the intermediate build
make clean
```
-## Contents
+## Additional developer-focused documents
```{toctree}
:maxdepth: 2
@@ -289,4 +293,6 @@ release
requirements
storage
release-notes/releases
+dev_plugins
```
+
diff --git a/src/sourmash/cli/sig/cat.py b/src/sourmash/cli/sig/cat.py
index d251db0b4e..ed85932f5f 100644
--- a/src/sourmash/cli/sig/cat.py
+++ b/src/sourmash/cli/sig/cat.py
@@ -31,6 +31,10 @@ def subparser(subparsers):
'-q', '--quiet', action='store_true',
help='suppress non-error output'
)
+ subparser.add_argument(
+ '-d', '--debug', action='store_true',
+ help='provide debugging output'
+ )
subparser.add_argument(
'-o', '--output', metavar='FILE', default='-',
help='output signature to this file (default stdout)'
diff --git a/src/sourmash/exceptions.py b/src/sourmash/exceptions.py
index 4895c50947..b2f18c12d2 100644
--- a/src/sourmash/exceptions.py
+++ b/src/sourmash/exceptions.py
@@ -25,6 +25,11 @@ def __init__(self):
SourmashError.__init__(self, "This index format is not supported in this version of sourmash")
+class IndexNotLoaded(SourmashError):
+ def __init__(self, msg):
+ SourmashError.__init__(self, f"Cannot load sourmash index: {str(msg)}")
+
+
def _make_error(error_name, base=SourmashError, code=None):
class Exc(base):
pass
diff --git a/src/sourmash/plugins.py b/src/sourmash/plugins.py
new file mode 100644
index 0000000000..2a786d6d24
--- /dev/null
+++ b/src/sourmash/plugins.py
@@ -0,0 +1,66 @@
+"""
+Support for plugins to sourmash via importlib.metadata entrypoints.
+
+Plugin entry point names:
+* 'sourmash.load_from' - Index class loading.
+* 'sourmash.save_to' - Signature saving.
+* 'sourmash.picklist_filters' - extended Picklist functionality.
+
+CTB TODO:
+
+* consider using something other than 'name' for loader fn name. Maybe __doc__?
+* try implement picklist plugin?
+"""
+
+DEFAULT_LOAD_FROM_PRIORITY = 99
+DEFAULT_SAVE_TO_PRIORITY = 99
+
+from .logging import debug_literal
+
+# cover for older versions of Python that don't support selection on load
+# (the 'group=' below).
+from importlib.metadata import entry_points
+
+# load 'load_from' entry points. NOTE: this executes on import of this module.
+try:
+ _plugin_load_from = entry_points(group='sourmash.load_from')
+except TypeError:
+ from importlib_metadata import entry_points
+ _plugin_load_from = entry_points(group='sourmash.load_from')
+
+# load 'save_to' entry points as well.
+_plugin_save_to = entry_points(group='sourmash.save_to')
+
+
+def get_load_from_functions():
+ "Load the 'load_from' plugins and yield tuples (priority, name, fn)."
+ debug_literal(f"load_from plugins: {_plugin_load_from}")
+
+ # Load each plugin,
+ for plugin in _plugin_load_from:
+ loader_fn = plugin.load()
+
+ # get 'priority' if it is available
+ priority = getattr(loader_fn, 'priority', DEFAULT_LOAD_FROM_PRIORITY)
+
+ # retrieve name (which is specified by plugin?)
+ name = plugin.name
+ debug_literal(f"plugins.load_from_functions: got '{name}', priority={priority}")
+ yield priority, name, loader_fn
+
+
+def get_save_to_functions():
+ "Load the 'save_to' plugins and yield tuples (priority, fn)."
+ debug_literal(f"save_to plugins: {_plugin_save_to}")
+
+ # Load each plugin,
+ for plugin in _plugin_save_to:
+ save_cls = plugin.load()
+
+ # get 'priority' if it is available
+ priority = getattr(save_cls, 'priority', DEFAULT_SAVE_TO_PRIORITY)
+
+ # retrieve name (which is specified by plugin?)
+ name = plugin.name
+ debug_literal(f"plugins.save_to_functions: got '{name}', priority={priority}")
+ yield priority, save_cls
diff --git a/src/sourmash/save_load.py b/src/sourmash/save_load.py
new file mode 100644
index 0000000000..bb842bd02e
--- /dev/null
+++ b/src/sourmash/save_load.py
@@ -0,0 +1,530 @@
+"""
+Index object/sigfile loading and signature saving code.
+
+This is the middleware code responsible for loading and saving signatures
+in a variety of ways.
+
+---
+
+Command-line functionality goes in sourmash_args.py.
+
+Low-level JSON reading/writing is in signature.py.
+
+Index objects are implemented in the index submodule.
+
+Public API:
+
+* load_file_as_index(filename, ...) -- load a sourmash.Index class
+* SaveSignaturesToLocation(filename) - bulk signature output
+
+APIs for plugins to use:
+
+* class Base_SaveSignaturesToLocation - to implement a new output method.
+
+CTB TODO:
+* consider replacing ValueError with IndexNotLoaded in the future.
+"""
+import sys
+import os
+import gzip
+from io import StringIO
+import zipfile
+import itertools
+import traceback
+
+import screed
+import sourmash
+
+from . import plugins as sourmash_plugins
+from .logging import notify, debug_literal
+from .exceptions import IndexNotLoaded
+
+from .index.sqlite_index import load_sqlite_index, SqliteIndex
+from .sbtmh import load_sbt_index
+from .lca.lca_db import load_single_database
+from . import signature as sigmod
+from .index import (LinearIndex, ZipFileLinearIndex, MultiIndex)
+from .manifest import CollectionManifest
+
+
+def load_file_as_index(filename, *, yield_all_files=False):
+ """Load 'filename' as a database; generic database loader.
+
+ If 'filename' contains an SBT or LCA indexed database, or a regular
+ Zip file, will return the appropriate objects. If a Zip file and
+ yield_all_files=True, will try to load all files within zip, not just
+ .sig files.
+
+ If 'filename' is a JSON file containing one or more signatures, will
+ return an Index object containing those signatures.
+
+ If 'filename' is a directory, will load *.sig underneath
+ this directory into an Index object. If yield_all_files=True, will
+ attempt to load all files.
+ """
+ return _load_database(filename, yield_all_files)
+
+
+def SaveSignaturesToLocation(location):
+ """
+ Provides a context manager that saves signatures in various output formats.
+
+ Usage:
+
+ with SaveSignaturesToLocation(filename_or_location) as save_sigs:
+ save_sigs.add(sig_obj)
+ """
+ save_list = itertools.chain(_save_classes,
+ sourmash_plugins.get_save_to_functions())
+ for priority, cls in sorted(save_list, key=lambda x:x[0]):
+ debug_literal(f"trying to match save function {cls}, priority={priority}")
+
+ if cls.matches(location):
+ debug_literal(f"{cls} is a match!")
+ return cls(location)
+
+ raise Exception(f"cannot determine how to open location {location} for saving; this should never happen!?")
+
+### Implementation machinery for _load_databases
+
+
+def _load_database(filename, traverse_yield_all, *, cache_size=None):
+ """Load file as a database - list of signatures, LCA, SBT, etc.
+
+ Return Index object.
+
+ This is an internal function used by other functions in sourmash_args.
+ """
+ loaded = False
+
+ # load plugins
+ plugin_fns = sourmash_plugins.get_load_from_functions()
+
+ # aggregate with default load_from functions & sort by priority
+ load_from_functions = sorted(itertools.chain(_loader_functions,
+ plugin_fns))
+
+ # iterate through loader functions, sorted by priority; try them all.
+ # Catch ValueError & IndexNotLoaded but nothing else.
+ for (priority, desc, load_fn) in load_from_functions:
+ db = None
+ try:
+ debug_literal(f"_load_databases: trying loader fn - priority {priority} - '{desc}'")
+ db = load_fn(filename,
+ traverse_yield_all=traverse_yield_all,
+ cache_size=cache_size)
+ except (ValueError, IndexNotLoaded):
+ debug_literal(f"_load_databases: FAIL with ValueError: on fn {desc}.")
+ debug_literal(traceback.format_exc())
+ debug_literal("(continuing past exception)")
+
+ if db is not None:
+ loaded = True
+ debug_literal("_load_databases: success!")
+ break
+
+ if loaded:
+ assert db is not None
+ return db
+
+ raise ValueError(f"Error while reading signatures from '{filename}'.")
+
+
+_loader_functions = []
+def add_loader(name, priority):
+ "decorator to add name/priority to _loader_functions"
+ def dec_priority(func):
+ _loader_functions.append((priority, name, func))
+ return func
+ return dec_priority
+
+
+@add_loader("load from stdin", 10)
+def _load_stdin(filename, **kwargs):
+ "Load collection from .sig file streamed in via stdin"
+ db = None
+ if filename == '-':
+ # load as LinearIndex, then pass into MultiIndex to generate a
+ # manifest.
+ lidx = LinearIndex.load(sys.stdin, filename='-')
+ db = MultiIndex.load((lidx,), (None,), parent="-")
+
+ return db
+
+
+@add_loader("load from standalone manifest", 30)
+def _load_standalone_manifest(filename, **kwargs):
+ from sourmash.index import StandaloneManifestIndex
+
+ try:
+ idx = StandaloneManifestIndex.load(filename)
+ except gzip.BadGzipFile as exc:
+ raise IndexNotLoaded(exc)
+
+ return idx
+
+
+@add_loader("load from list of paths", 50)
+def _multiindex_load_from_pathlist(filename, **kwargs):
+ "Load collection from a list of signature/database files"
+ db = MultiIndex.load_from_pathlist(filename)
+
+ return db
+
+
+@add_loader("load from path (file or directory)", 40)
+def _multiindex_load_from_path(filename, **kwargs):
+ "Load collection from a directory."
+ traverse_yield_all = kwargs['traverse_yield_all']
+ db = MultiIndex.load_from_path(filename, traverse_yield_all)
+
+ return db
+
+
+@add_loader("load SBT", 60)
+def _load_sbt(filename, **kwargs):
+ "Load collection from an SBT."
+ cache_size = kwargs.get('cache_size')
+
+ try:
+ db = load_sbt_index(filename, cache_size=cache_size)
+ except (FileNotFoundError, TypeError) as exc:
+ raise IndexNotLoaded(exc)
+
+ return db
+
+
+@add_loader("load revindex", 70)
+def _load_revindex(filename, **kwargs):
+ "Load collection from an LCA database/reverse index."
+ db, _, _ = load_single_database(filename)
+ return db
+
+
+@add_loader("load collection from sqlitedb", 20)
+def _load_sqlite_db(filename, **kwargs):
+ return load_sqlite_index(filename)
+
+
+@add_loader("load collection from zipfile", 80)
+def _load_zipfile(filename, **kwargs):
+ "Load collection from a .zip file."
+ db = None
+ if filename.endswith('.zip'):
+ traverse_yield_all = kwargs['traverse_yield_all']
+ try:
+ db = ZipFileLinearIndex.load(filename,
+ traverse_yield_all=traverse_yield_all)
+ except FileNotFoundError as exc:
+ # turn this into an IndexNotLoaded => proper exception handling by
+ # _load_database.
+ raise IndexNotLoaded(exc)
+
+ return db
+
+
+@add_loader("catch FASTA/FASTQ files and error", 1000)
+def _error_on_fastaq(filename, **kwargs):
+ "This is a tail-end loader that checks for FASTA/FASTQ sequences => err."
+ success = False
+ try:
+ with screed.open(filename) as it:
+ _ = next(iter(it))
+
+ success = True
+ except:
+ pass
+
+ if success:
+ raise Exception(f"Error while reading signatures from '{filename}' - got sequences instead! Is this a FASTA/FASTQ file?")
+
+
+### Implementation machinery for SaveSignaturesToLocation
+
+class Base_SaveSignaturesToLocation:
+ "Base signature saving class. Track location (if any) and count."
+ def __init__(self, location):
+ self.location = location
+ self.count = 0
+
+ @classmethod
+ def matches(cls, location):
+ "returns True when this class should handle a specific location"
+ raise NotImplementedError
+
+ def __repr__(self):
+ raise NotImplementedError
+
+ def __len__(self):
+ return self.count
+
+ def open(self):
+ pass
+
+ def close(self):
+ pass
+
+ def __enter__(self):
+ "provide context manager functionality"
+ self.open()
+ return self
+
+ def __exit__(self, type, value, traceback):
+ "provide context manager functionality"
+ self.close()
+
+ def add(self, ss):
+ self.count += 1
+
+ def add_many(self, sslist):
+ for ss in sslist:
+ self.add(ss)
+
+
+def _get_signatures_from_rust(siglist):
+ # this function deals with a disconnect between the way Rust
+ # and Python handle signatures; Python expects one
+ # minhash (and hence one md5sum) per signature, while
+ # Rust supports multiple. For now, go through serializing
+ # and deserializing the signature! See issue #1167 for more.
+ json_str = sourmash.save_signatures(siglist)
+ for ss in sourmash.load_signatures(json_str):
+ yield ss
+
+
+class SaveSignatures_NoOutput(Base_SaveSignaturesToLocation):
+ "Do not save signatures."
+ def __repr__(self):
+ return 'SaveSignatures_NoOutput()'
+
+ @classmethod
+ def matches(cls, location):
+ return location is None
+
+ def open(self):
+ pass
+
+ def close(self):
+ pass
+
+
+class SaveSignatures_Directory(Base_SaveSignaturesToLocation):
+ "Save signatures within a directory, using md5sum names."
+ def __init__(self, location):
+ super().__init__(location)
+
+ def __repr__(self):
+ return f"SaveSignatures_Directory('{self.location}')"
+
+ @classmethod
+ def matches(cls, location):
+ "anything ending in /"
+ if location:
+ return location.endswith('/')
+
+ def close(self):
+ pass
+
+ def open(self):
+ try:
+ os.mkdir(self.location)
+ except FileExistsError:
+ pass
+ except:
+ notify(f"ERROR: cannot create signature output directory '{self.location}'")
+ sys.exit(-1)
+
+ def add(self, ss):
+ super().add(ss)
+ md5 = ss.md5sum()
+
+ # don't overwrite even if duplicate md5sum
+ outname = os.path.join(self.location, f"{md5}.sig.gz")
+ if os.path.exists(outname):
+ i = 0
+ while 1:
+ outname = os.path.join(self.location, f"{md5}_{i}.sig.gz")
+ if not os.path.exists(outname):
+ break
+ i += 1
+
+ with gzip.open(outname, "wb") as fp:
+ sigmod.save_signatures([ss], fp, compression=1)
+
+
+class SaveSignatures_SqliteIndex(Base_SaveSignaturesToLocation):
+ "Save signatures within a directory, using md5sum names."
+ def __init__(self, location):
+ super().__init__(location)
+ self.location = location
+ self.idx = None
+ self.cursor = None
+
+ @classmethod
+ def matches(cls, location):
+ "anything ending in .sqldb"
+ if location:
+ return location.endswith('.sqldb')
+
+ def __repr__(self):
+ return f"SaveSignatures_SqliteIndex('{self.location}')"
+
+ def close(self):
+ self.idx.commit()
+ self.cursor.execute('VACUUM')
+ self.idx.close()
+
+ def open(self):
+ self.idx = SqliteIndex.create(self.location, append=True)
+ self.cursor = self.idx.cursor()
+
+ def add(self, add_sig):
+ for ss in _get_signatures_from_rust([add_sig]):
+ super().add(ss)
+ self.idx.insert(ss, cursor=self.cursor, commit=False)
+
+ # commit every 1000 signatures.
+ if self.count % 1000 == 0:
+ self.idx.commit()
+
+
+class SaveSignatures_SigFile(Base_SaveSignaturesToLocation):
+ "Save signatures to a .sig JSON file."
+ def __init__(self, location):
+ super().__init__(location)
+ self.keep = []
+ self.compress = 0
+ if self.location.endswith('.gz'):
+ self.compress = 1
+
+ @classmethod
+ def matches(cls, location):
+ # match anything that is not None or ""
+ return bool(location)
+
+ def __repr__(self):
+ return f"SaveSignatures_SigFile('{self.location}')"
+
+ def open(self):
+ pass
+
+ def close(self):
+ if self.location == '-':
+ sourmash.save_signatures(self.keep, sys.stdout)
+ else:
+ # text mode? encode in utf-8
+ mode = "w"
+ encoding = 'utf-8'
+
+ # compressed? bytes & binary.
+ if self.compress:
+ encoding = None
+ mode = "wb"
+
+ with open(self.location, mode, encoding=encoding) as fp:
+ sourmash.save_signatures(self.keep, fp,
+ compression=self.compress)
+
+ def add(self, ss):
+ super().add(ss)
+ self.keep.append(ss)
+
+
+class SaveSignatures_ZipFile(Base_SaveSignaturesToLocation):
+ "Save compressed signatures in an uncompressed Zip file."
+ def __init__(self, location):
+ super().__init__(location)
+ self.storage = None
+
+ @classmethod
+ def matches(cls, location):
+ "anything ending in .zip"
+ if location:
+ return location.endswith('.zip')
+
+ def __repr__(self):
+ return f"SaveSignatures_ZipFile('{self.location}')"
+
+ def close(self):
+ # finish constructing manifest object & save
+ manifest = CollectionManifest(self.manifest_rows)
+ manifest_name = "SOURMASH-MANIFEST.csv"
+
+ manifest_fp = StringIO()
+ manifest.write_to_csv(manifest_fp, write_header=True)
+ manifest_data = manifest_fp.getvalue().encode("utf-8")
+
+ self.storage.save(manifest_name, manifest_data, overwrite=True,
+ compress=True)
+ self.storage.flush()
+ self.storage.close()
+
+ def open(self):
+ from .sbt_storage import ZipStorage
+
+ do_create = True
+ if os.path.exists(self.location):
+ do_create = False
+
+ storage = None
+ try:
+ storage = ZipStorage(self.location, mode="w")
+ except zipfile.BadZipFile:
+ pass
+
+ if storage is None:
+ raise ValueError(f"File '{self.location}' cannot be opened as a zip file.")
+
+ if not storage.subdir:
+ storage.subdir = 'signatures'
+
+ # now, try to load manifest
+ try:
+ manifest_data = storage.load('SOURMASH-MANIFEST.csv')
+ except (FileNotFoundError, KeyError):
+ # if file already exists must have manifest...
+ if not do_create:
+ raise ValueError(f"Cannot add to existing zipfile '{self.location}' without a manifest")
+ self.manifest_rows = []
+ else:
+ # success! decode manifest_data, create manifest rows => append.
+ manifest_data = manifest_data.decode('utf-8')
+ manifest_fp = StringIO(manifest_data)
+ manifest = CollectionManifest.load_from_csv(manifest_fp)
+ self.manifest_rows = list(manifest._select())
+
+ self.storage = storage
+
+ def _exists(self, name):
+ try:
+ self.storage.load(name)
+ return True
+ except KeyError:
+ return False
+
+ def add(self, add_sig):
+ if not self.storage:
+ raise ValueError("this output is not open")
+
+ for ss in _get_signatures_from_rust([add_sig]):
+ buf = sigmod.save_signatures([ss], compression=1)
+ md5 = ss.md5sum()
+
+ storage = self.storage
+ path = f'{storage.subdir}/{md5}.sig.gz'
+ location = storage.save(path, buf)
+
+ # update manifest
+ row = CollectionManifest.make_manifest_row(ss, location,
+ include_signature=False)
+ self.manifest_rows.append(row)
+ super().add(ss)
+
+
+_save_classes = [
+ (10, SaveSignatures_NoOutput),
+ (20, SaveSignatures_Directory),
+ (30, SaveSignatures_ZipFile),
+ (40, SaveSignatures_SqliteIndex),
+ (1000, SaveSignatures_SigFile),
+]
diff --git a/src/sourmash/sig/__main__.py b/src/sourmash/sig/__main__.py
index 48be17d9ad..7a757aa444 100644
--- a/src/sourmash/sig/__main__.py
+++ b/src/sourmash/sig/__main__.py
@@ -83,7 +83,7 @@ def cat(args):
"""
concatenate all signatures into one file.
"""
- set_quiet(args.quiet)
+ set_quiet(args.quiet, args.debug)
moltype = sourmash_args.calculate_moltype(args)
picklist = sourmash_args.load_picklist(args)
pattern_search = sourmash_args.load_include_exclude_db_patterns(args)
diff --git a/src/sourmash/sourmash_args.py b/src/sourmash/sourmash_args.py
index 3a465b18c1..01fd86460c 100644
--- a/src/sourmash/sourmash_args.py
+++ b/src/sourmash/sourmash_args.py
@@ -22,45 +22,34 @@
* load_query_signature(filename, ...) -- load a single signature for query
* traverse_find_sigs(filenames, ...) -- find all .sig and .sig.gz files
* load_dbs_and_sigs(filenames, query, ...) -- load databases & signatures
-* load_file_as_index(filename, ...) -- load a sourmash.Index class
-* load_file_as_signatures(filename, ...) -- load a list of signatures
* load_pathlist_from_file(filename) -- load a list of paths from a file
* load_many_signatures(locations) -- load many signatures from many files
* get_manifest(idx) -- retrieve or build a manifest from an Index
* class SignatureLoadingProgress - signature loading progress bar
+* load_file_as_signatures(filename, ...) -- load a list of signatures
signature and file output functionality:
-* SaveSignaturesToLocation(filename) - bulk signature output
* class FileOutput - file output context manager that deals w/stdout well
* class FileOutputCSV - file output context manager for CSV files
"""
import sys
import os
import csv
-from enum import Enum
-import traceback
import gzip
-from io import StringIO, TextIOWrapper
+from io import TextIOWrapper
import re
import zipfile
import contextlib
-
-import screed
-import sourmash
-
-from sourmash.sbtmh import load_sbt_index
-from sourmash.lca.lca_db import load_single_database
-import sourmash.exceptions
+import argparse
from .logging import notify, error, debug_literal
-from .index import (LinearIndex, ZipFileLinearIndex, MultiIndex)
-from .index.sqlite_index import load_sqlite_index, SqliteIndex
-from . import signature as sigmod
+from .index import LinearIndex
from .picklist import SignaturePicklist, PickStyle
from .manifest import CollectionManifest
-import argparse
+from .save_load import (SaveSignaturesToLocation, load_file_as_index,
+ _load_database)
DEFAULT_LOAD_K = 31
@@ -70,7 +59,7 @@ def check_scaled_bounds(arg):
f = float(arg)
if f < 0:
- raise argparse.ArgumentTypeError(f"ERROR: scaled value must be positive")
+ raise argparse.ArgumentTypeError("ERROR: scaled value must be positive")
if f < 100:
notify('WARNING: scaled value should be >= 100. Continuing anyway.')
if f > 1e6:
@@ -82,7 +71,7 @@ def check_num_bounds(arg):
f = int(arg)
if f < 0:
- raise argparse.ArgumentTypeError(f"ERROR: num value must be positive")
+ raise argparse.ArgumentTypeError("ERROR: num value must be positive")
if f < 50:
notify('WARNING: num value should be >= 50. Continuing anyway.')
if f > 50000:
@@ -352,211 +341,6 @@ def load_dbs_and_sigs(filenames, query, is_similarity_query, *,
return databases
-def _load_stdin(filename, **kwargs):
- "Load collection from .sig file streamed in via stdin"
- db = None
- if filename == '-':
- # load as LinearIndex, then pass into MultiIndex to generate a
- # manifest.
- lidx = LinearIndex.load(sys.stdin, filename='-')
- db = MultiIndex.load((lidx,), (None,), parent="-")
-
- return db
-
-
-def _load_standalone_manifest(filename, **kwargs):
- from sourmash.index import StandaloneManifestIndex
-
- try:
- idx = StandaloneManifestIndex.load(filename)
- except gzip.BadGzipFile as exc:
- raise ValueError(exc)
-
- return idx
-
-
-def _multiindex_load_from_pathlist(filename, **kwargs):
- "Load collection from a list of signature/database files"
- db = MultiIndex.load_from_pathlist(filename)
-
- return db
-
-
-def _multiindex_load_from_path(filename, **kwargs):
- "Load collection from a directory."
- traverse_yield_all = kwargs['traverse_yield_all']
- db = MultiIndex.load_from_path(filename, traverse_yield_all)
-
- return db
-
-
-def _load_sbt(filename, **kwargs):
- "Load collection from an SBT."
- cache_size = kwargs.get('cache_size')
-
- try:
- db = load_sbt_index(filename, cache_size=cache_size)
- except (FileNotFoundError, TypeError) as exc:
- raise ValueError(exc)
-
- return db
-
-
-def _load_revindex(filename, **kwargs):
- "Load collection from an LCA database/reverse index."
- db, _, _ = load_single_database(filename)
- return db
-
-
-def _load_sqlite_db(filename, **kwargs):
- return load_sqlite_index(filename)
-
-
-def _load_zipfile(filename, **kwargs):
- "Load collection from a .zip file."
- db = None
- if filename.endswith('.zip'):
- traverse_yield_all = kwargs['traverse_yield_all']
- try:
- db = ZipFileLinearIndex.load(filename,
- traverse_yield_all=traverse_yield_all)
- except FileNotFoundError as exc:
- # turn this into a ValueError => proper exception handling by
- # _load_database.
- raise ValueError(exc)
-
- return db
-
-
-# all loader functions, in order.
-_loader_functions = [
- ("load from stdin", _load_stdin),
- ("load collection from sqlitedb", _load_sqlite_db),
- ("load from standalone manifest", _load_standalone_manifest),
- ("load from path (file or directory)", _multiindex_load_from_path),
- ("load from file list", _multiindex_load_from_pathlist),
- ("load SBT", _load_sbt),
- ("load revindex", _load_revindex),
- ("load collection from zipfile", _load_zipfile),
- ]
-
-
-def _load_database(filename, traverse_yield_all, *, cache_size=None):
- """Load file as a database - list of signatures, LCA, SBT, etc.
-
- Return Index object.
-
- This is an internal function used by other functions in sourmash_args.
- """
- loaded = False
-
- # iterate through loader functions, trying them all. Catch ValueError
- # but nothing else.
- for n, (desc, load_fn) in enumerate(_loader_functions):
- try:
- debug_literal(f"_load_databases: trying loader fn {n} '{desc}'")
- db = load_fn(filename,
- traverse_yield_all=traverse_yield_all,
- cache_size=cache_size)
- except ValueError:
- debug_literal(f"_load_databases: FAIL on fn {n} {desc}.")
- debug_literal(traceback.format_exc())
-
- if db is not None:
- loaded = True
- debug_literal("_load_databases: success!")
- break
-
- # check to see if it's a FASTA/FASTQ record (i.e. screed loadable)
- # so we can provide a better error message to users.
- if not loaded:
- successful_screed_load = False
- it = None
- try:
- # CTB: could be kind of time consuming for a big record, but at the
- # moment screed doesn't expose format detection cleanly.
- with screed.open(filename) as it:
- _ = next(iter(it))
- successful_screed_load = True
- except:
- pass
-
- if successful_screed_load:
- raise ValueError(f"Error while reading signatures from '{filename}' - got sequences instead! Is this a FASTA/FASTQ file?")
-
- if not loaded:
- raise ValueError(f"Error while reading signatures from '{filename}'.")
-
- if loaded: # this is a bit redundant but safe > sorry
- assert db is not None
-
- return db
-
-
-def load_file_as_index(filename, *, yield_all_files=False):
- """Load 'filename' as a database; generic database loader.
-
- If 'filename' contains an SBT or LCA indexed database, or a regular
- Zip file, will return the appropriate objects. If a Zip file and
- yield_all_files=True, will try to load all files within zip, not just
- .sig files.
-
- If 'filename' is a JSON file containing one or more signatures, will
- return an Index object containing those signatures.
-
- If 'filename' is a directory, will load *.sig underneath
- this directory into an Index object. If yield_all_files=True, will
- attempt to load all files.
- """
- return _load_database(filename, yield_all_files)
-
-
-def load_file_as_signatures(filename, *, select_moltype=None, ksize=None,
- picklist=None,
- yield_all_files=False,
- progress=None,
- pattern=None,
- _use_manifest=True):
- """Load 'filename' as a collection of signatures. Return an iterable.
-
- If 'filename' contains an SBT or LCA indexed database, or a regular
- Zip file, will return a signatures() generator. If a Zip file and
- yield_all_files=True, will try to load all files within zip, not just
- .sig files.
-
- If 'filename' is a JSON file containing one or more signatures, will
- return a list of those signatures.
-
- If 'filename' is a directory, will load *.sig
- underneath this directory into a list of signatures. If
- yield_all_files=True, will attempt to load all files.
-
- Applies selector function if select_moltype, ksize or picklist are given.
-
- 'pattern' is a function that returns True on matching values.
- """
- if progress:
- progress.notify(filename)
-
- db = _load_database(filename, yield_all_files)
-
- # test fixture ;)
- if not _use_manifest and db.manifest:
- db.manifest = None
-
- db = db.select(moltype=select_moltype, ksize=ksize)
-
- # apply pattern search & picklist
- db = apply_picklist_and_pattern(db, picklist, pattern)
-
- loader = db.signatures()
-
- if progress is not None:
- return progress.start_file(filename, loader)
- else:
- return loader
-
-
def load_pathlist_from_file(filename):
"Load a list-of-files text file."
try:
@@ -894,8 +678,6 @@ def get_manifest(idx, *, require=True, rebuild=False):
In the case where `require=False` and a manifest cannot be built,
may return None. Otherwise always returns a manifest.
"""
- from sourmash.index import CollectionManifest
-
m = idx.manifest
# has one, and don't want to rebuild? easy! return!
@@ -921,293 +703,48 @@ def get_manifest(idx, *, require=True, rebuild=False):
return m
-#
-# enum and classes for saving signatures progressively
-#
-
-def _get_signatures_from_rust(siglist):
- # this deals with a disconnect between the way Rust
- # and Python handle signatures; Python expects one
- # minhash (and hence one md5sum) per signature, while
- # Rust supports multiple. For now, go through serializing
- # and deserializing the signature! See issue #1167 for more.
- json_str = sourmash.save_signatures(siglist)
- for ss in sourmash.load_signatures(json_str):
- yield ss
-
-
-class _BaseSaveSignaturesToLocation:
- "Base signature saving class. Track location (if any) and count."
- def __init__(self, location):
- self.location = location
- self.count = 0
-
- def __repr__(self):
- raise NotImplementedError
-
- def __len__(self):
- return self.count
-
- def __enter__(self):
- "provide context manager functionality"
- self.open()
- return self
-
- def __exit__(self, type, value, traceback):
- "provide context manager functionality"
- self.close()
-
- def add(self, ss):
- self.count += 1
-
- def add_many(self, sslist):
- for ss in sslist:
- self.add(ss)
-
-
-class SaveSignatures_NoOutput(_BaseSaveSignaturesToLocation):
- "Do not save signatures."
- def __repr__(self):
- return 'SaveSignatures_NoOutput()'
-
- def open(self):
- pass
-
- def close(self):
- pass
-
-
-class SaveSignatures_Directory(_BaseSaveSignaturesToLocation):
- "Save signatures within a directory, using md5sum names."
- def __init__(self, location):
- super().__init__(location)
-
- def __repr__(self):
- return f"SaveSignatures_Directory('{self.location}')"
-
- def close(self):
- pass
-
- def open(self):
- try:
- os.mkdir(self.location)
- except FileExistsError:
- pass
- except:
- notify(f"ERROR: cannot create signature output directory '{self.location}'")
- sys.exit(-1)
-
- def add(self, ss):
- super().add(ss)
- md5 = ss.md5sum()
-
- # don't overwrite even if duplicate md5sum
- outname = os.path.join(self.location, f"{md5}.sig.gz")
- if os.path.exists(outname):
- i = 0
- while 1:
- outname = os.path.join(self.location, f"{md5}_{i}.sig.gz")
- if not os.path.exists(outname):
- break
- i += 1
-
- with gzip.open(outname, "wb") as fp:
- sigmod.save_signatures([ss], fp, compression=1)
-
-
-class SaveSignatures_SqliteIndex(_BaseSaveSignaturesToLocation):
- "Save signatures within a directory, using md5sum names."
- def __init__(self, location):
- super().__init__(location)
- self.location = location
- self.idx = None
- self.cursor = None
-
- def __repr__(self):
- return f"SaveSignatures_SqliteIndex('{self.location}')"
-
- def close(self):
- self.idx.commit()
- self.cursor.execute('VACUUM')
- self.idx.close()
-
- def open(self):
- self.idx = SqliteIndex.create(self.location, append=True)
- self.cursor = self.idx.cursor()
-
- def add(self, add_sig):
- for ss in _get_signatures_from_rust([add_sig]):
- super().add(ss)
- self.idx.insert(ss, cursor=self.cursor, commit=False)
-
- # commit every 1000 signatures.
- if self.count % 1000 == 0:
- self.idx.commit()
-
-
-class SaveSignatures_SigFile(_BaseSaveSignaturesToLocation):
- "Save signatures to a .sig JSON file."
- def __init__(self, location):
- super().__init__(location)
- self.keep = []
- self.compress = 0
- if self.location.endswith('.gz'):
- self.compress = 1
-
- def __repr__(self):
- return f"SaveSignatures_SigFile('{self.location}')"
-
- def open(self):
- pass
-
- def close(self):
- if self.location == '-':
- sourmash.save_signatures(self.keep, sys.stdout)
- else:
- # text mode? encode in utf-8
- mode = "w"
- encoding = 'utf-8'
-
- # compressed? bytes & binary.
- if self.compress:
- encoding = None
- mode = "wb"
-
- with open(self.location, mode, encoding=encoding) as fp:
- sourmash.save_signatures(self.keep, fp,
- compression=self.compress)
- def add(self, ss):
- super().add(ss)
- self.keep.append(ss)
-
-
-class SaveSignatures_ZipFile(_BaseSaveSignaturesToLocation):
- "Save compressed signatures in an uncompressed Zip file."
- def __init__(self, location):
- super().__init__(location)
- self.storage = None
-
- def __repr__(self):
- return f"SaveSignatures_ZipFile('{self.location}')"
+def load_file_as_signatures(filename, *, select_moltype=None, ksize=None,
+ picklist=None,
+ yield_all_files=False,
+ progress=None,
+ pattern=None,
+ _use_manifest=True):
+ """Load 'filename' as a collection of signatures. Return an iterable.
- def close(self):
- # finish constructing manifest object & save
- manifest = CollectionManifest(self.manifest_rows)
- manifest_name = f"SOURMASH-MANIFEST.csv"
+ If 'filename' contains an SBT or LCA indexed database, or a regular
+ Zip file, will return a signatures() generator. If a Zip file and
+ yield_all_files=True, will try to load all files within zip, not just
+ .sig files.
- manifest_fp = StringIO()
- manifest.write_to_csv(manifest_fp, write_header=True)
- manifest_data = manifest_fp.getvalue().encode("utf-8")
+ If 'filename' is a JSON file containing one or more signatures, will
+ return a list of those signatures.
- self.storage.save(manifest_name, manifest_data, overwrite=True,
- compress=True)
- self.storage.flush()
- self.storage.close()
+ If 'filename' is a directory, will load *.sig
+ underneath this directory into a list of signatures. If
+ yield_all_files=True, will attempt to load all files.
- def open(self):
- from .sbt_storage import ZipStorage
+ Applies selector function if select_moltype, ksize or picklist are given.
- do_create = True
- if os.path.exists(self.location):
- do_create = False
+ 'pattern' is a function that returns True on matching values.
+ """
+ if progress:
+ progress.notify(filename)
- storage = None
- try:
- storage = ZipStorage(self.location, mode="w")
- except zipfile.BadZipFile:
- pass
+ db = _load_database(filename, yield_all_files)
- if storage is None:
- raise ValueError(f"File '{self.location}' cannot be opened as a zip file.")
+ # test fixture ;)
+ if not _use_manifest and db.manifest:
+ db.manifest = None
- if not storage.subdir:
- storage.subdir = 'signatures'
+ db = db.select(moltype=select_moltype, ksize=ksize)
- # now, try to load manifest
- try:
- manifest_data = storage.load('SOURMASH-MANIFEST.csv')
- except (FileNotFoundError, KeyError):
- # if file already exists must have manifest...
- if not do_create:
- raise ValueError(f"Cannot add to existing zipfile '{self.location}' without a manifest")
- self.manifest_rows = []
- else:
- # success! decode manifest_data, create manifest rows => append.
- manifest_data = manifest_data.decode('utf-8')
- manifest_fp = StringIO(manifest_data)
- manifest = CollectionManifest.load_from_csv(manifest_fp)
- self.manifest_rows = list(manifest._select())
+ # apply pattern search & picklist
+ db = apply_picklist_and_pattern(db, picklist, pattern)
- self.storage = storage
+ loader = db.signatures()
- def _exists(self, name):
- try:
- self.storage.load(name)
- return True
- except KeyError:
- return False
-
- def add(self, add_sig):
- if not self.storage:
- raise ValueError("this output is not open")
-
- for ss in _get_signatures_from_rust([add_sig]):
- buf = sigmod.save_signatures([ss], compression=1)
- md5 = ss.md5sum()
-
- storage = self.storage
- path = f'{storage.subdir}/{md5}.sig.gz'
- location = storage.save(path, buf)
-
- # update manifest
- row = CollectionManifest.make_manifest_row(ss, location,
- include_signature=False)
- self.manifest_rows.append(row)
- super().add(ss)
-
-
-class SigFileSaveType(Enum):
- NO_OUTPUT = 0
- SIGFILE = 1
- SIGFILE_GZ = 2
- DIRECTORY = 3
- ZIPFILE = 4
- SQLITEDB = 5
-
-_save_classes = {
- SigFileSaveType.NO_OUTPUT: SaveSignatures_NoOutput,
- SigFileSaveType.SIGFILE: SaveSignatures_SigFile,
- SigFileSaveType.SIGFILE_GZ: SaveSignatures_SigFile,
- SigFileSaveType.DIRECTORY: SaveSignatures_Directory,
- SigFileSaveType.ZIPFILE: SaveSignatures_ZipFile,
- SigFileSaveType.SQLITEDB: SaveSignatures_SqliteIndex,
-}
-
-
-def SaveSignaturesToLocation(filename, *, force_type=None):
- """Create and return an appropriate object for progressive saving of
- signatures."""
- save_type = None
- if not force_type:
- if filename is None:
- save_type = SigFileSaveType.NO_OUTPUT
- elif filename.endswith('/'):
- save_type = SigFileSaveType.DIRECTORY
- elif filename.endswith('.gz'):
- save_type = SigFileSaveType.SIGFILE_GZ
- elif filename.endswith('.zip'):
- save_type = SigFileSaveType.ZIPFILE
- elif filename.endswith('.sqldb'):
- save_type = SigFileSaveType.SQLITEDB
- else:
- # default to SIGFILE intentionally!
- save_type = SigFileSaveType.SIGFILE
+ if progress is not None:
+ return progress.start_file(filename, loader)
else:
- save_type = force_type
-
- cls = _save_classes.get(save_type)
- if cls is None:
- raise Exception("invalid save type; this should never happen!?")
-
- return cls(filename)
+ return loader
diff --git a/tests/test_api.py b/tests/test_api.py
index 73f9ffc7a4..ccaf321df6 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -69,7 +69,7 @@ def test_load_fasta_as_signature():
# try loading a fasta file - should fail with informative exception
testfile = utils.get_test_data('short.fa')
- with pytest.raises(ValueError) as exc:
+ with pytest.raises(Exception) as exc:
idx = sourmash.load_file_as_index(testfile)
print(exc.value)
diff --git a/tests/test_index_protocol.py b/tests/test_index_protocol.py
index e69a9b0c68..4a6672408e 100644
--- a/tests/test_index_protocol.py
+++ b/tests/test_index_protocol.py
@@ -75,7 +75,7 @@ def build_sbt_index_save_load(runtmp):
def build_zipfile_index(runtmp):
- from sourmash.sourmash_args import SaveSignatures_ZipFile
+ from sourmash.save_load import SaveSignatures_ZipFile
location = runtmp.output('index.zip')
with SaveSignatures_ZipFile(location) as save_sigs:
diff --git a/tests/test_plugin_framework.py b/tests/test_plugin_framework.py
new file mode 100644
index 0000000000..76b22bb730
--- /dev/null
+++ b/tests/test_plugin_framework.py
@@ -0,0 +1,261 @@
+"""
+Test the plugin framework in sourmash.plugins, which uses importlib.metadata
+entrypoints.
+
+CTB TODO:
+* check name?
+"""
+
+import pytest
+import sourmash
+from sourmash.logging import set_quiet
+
+import sourmash_tst_utils as utils
+from sourmash import plugins
+from sourmash.index import LinearIndex
+from sourmash.save_load import (Base_SaveSignaturesToLocation,
+ SaveSignaturesToLocation)
+
+
+class FakeEntryPoint:
+ """
+ A class that stores a name and an object to be returned on 'load()'.
+ Mocks the EntryPoint class used by importlib.metadata.
+ """
+ def __init__(self, name, load_obj):
+ self.name = name
+ self.load_obj = load_obj
+
+ def load(self):
+ return self.load_obj
+
+#
+# Test basic features of the load_from plugin hook.
+#
+
+class Test_EntryPointBasics_LoadFrom:
+ def get_some_sigs(self, location, *args, **kwargs):
+ ss2 = utils.get_test_data('2.fa.sig')
+ ss47 = utils.get_test_data('47.fa.sig')
+ ss63 = utils.get_test_data('63.fa.sig')
+
+ sig2 = sourmash.load_one_signature(ss2, ksize=31)
+ sig47 = sourmash.load_one_signature(ss47, ksize=31)
+ sig63 = sourmash.load_one_signature(ss63, ksize=31)
+
+ lidx = LinearIndex([sig2, sig47, sig63], location)
+
+ return lidx
+ get_some_sigs.priority = 1
+
+ def setup_method(self):
+ self.saved_plugins = plugins._plugin_load_from
+ plugins._plugin_load_from = [FakeEntryPoint('test_load', self.get_some_sigs)]
+
+ def teardown_method(self):
+ plugins._plugin_load_from = self.saved_plugins
+
+ def test_load_1(self):
+ ps = list(plugins.get_load_from_functions())
+ assert len(ps) == 1
+
+ def test_load_2(self, runtmp):
+ fake_location = runtmp.output('passed-through location')
+ idx = sourmash.load_file_as_index(fake_location)
+ print(idx, idx.location)
+
+ assert len(idx) == 3
+ assert idx.location == fake_location
+
+
+class Test_EntryPoint_LoadFrom_Priority:
+ def get_some_sigs(self, location, *args, **kwargs):
+ ss2 = utils.get_test_data('2.fa.sig')
+ ss47 = utils.get_test_data('47.fa.sig')
+ ss63 = utils.get_test_data('63.fa.sig')
+
+ sig2 = sourmash.load_one_signature(ss2, ksize=31)
+ sig47 = sourmash.load_one_signature(ss47, ksize=31)
+ sig63 = sourmash.load_one_signature(ss63, ksize=31)
+
+ lidx = LinearIndex([sig2, sig47, sig63], location)
+
+ return lidx
+ get_some_sigs.priority = 5
+
+ def set_called_flag_1(self, location, *args, **kwargs):
+ # high priority 1, raise ValueError
+ print('setting flag 1')
+ self.was_called_flag_1 = True
+ raise ValueError
+ set_called_flag_1.priority = 1
+
+ def set_called_flag_2(self, location, *args, **kwargs):
+ # high priority 2, return None
+ print('setting flag 2')
+ self.was_called_flag_2 = True
+
+ return None
+ set_called_flag_2.priority = 2
+
+ def set_called_flag_3(self, location, *args, **kwargs):
+ # lower priority 10, should not be called
+ print('setting flag 3')
+ self.was_called_flag_3 = True
+
+ return None
+ set_called_flag_3.priority = 10
+
+ def setup_method(self):
+ self.saved_plugins = plugins._plugin_load_from
+ plugins._plugin_load_from = [
+ FakeEntryPoint('test_load', self.get_some_sigs),
+ FakeEntryPoint('test_load_2', self.set_called_flag_1),
+ FakeEntryPoint('test_load_3', self.set_called_flag_2),
+ FakeEntryPoint('test_load_4', self.set_called_flag_3)
+ ]
+ self.was_called_flag_1 = False
+ self.was_called_flag_2 = False
+ self.was_called_flag_3 = False
+
+ def teardown_method(self):
+ plugins._plugin_load_from = self.saved_plugins
+
+ def test_load_1(self):
+ ps = list(plugins.get_load_from_functions())
+ assert len(ps) == 4
+
+ assert not self.was_called_flag_1
+ assert not self.was_called_flag_2
+ assert not self.was_called_flag_3
+
+ def test_load_2(self, runtmp):
+ fake_location = runtmp.output('passed-through location')
+ idx = sourmash.load_file_as_index(fake_location)
+ print(idx, idx.location)
+
+ assert len(idx) == 3
+ assert idx.location == fake_location
+
+ assert self.was_called_flag_1
+ assert self.was_called_flag_2
+ assert not self.was_called_flag_3
+
+
+#
+# Test basic features of the save_to plugin hook.
+#
+
+class FakeSaveClass(Base_SaveSignaturesToLocation):
+ """
+ A fake save class that just records what was sent to it.
+ """
+ priority = 50
+
+ def __init__(self, location):
+ super().__init__(location)
+ self.keep = []
+
+ @classmethod
+ def matches(cls, location):
+ if location:
+ return location.endswith('.this-is-a-test')
+
+ def add(self, ss):
+ super().add(ss)
+ self.keep.append(ss)
+
+
+class FakeSaveClass_HighPriority(FakeSaveClass):
+ priority = 1
+
+
+class Test_EntryPointBasics_SaveTo:
+ # test the basics
+ def setup_method(self):
+ self.saved_plugins = plugins._plugin_save_to
+ plugins._plugin_save_to = [FakeEntryPoint('test_save', FakeSaveClass)]
+
+ def teardown_method(self):
+ plugins._plugin_save_to = self.saved_plugins
+
+ def test_save_1(self):
+ ps = list(plugins.get_save_to_functions())
+ print(ps)
+ assert len(ps) == 1
+
+ def test_save_2(self, runtmp):
+ # load some signatures to save
+ ss2 = utils.get_test_data('2.fa.sig')
+ ss47 = utils.get_test_data('47.fa.sig')
+ ss63 = utils.get_test_data('63.fa.sig')
+
+ sig2 = sourmash.load_one_signature(ss2, ksize=31)
+ sig47 = sourmash.load_one_signature(ss47, ksize=31)
+ sig63 = sourmash.load_one_signature(ss63, ksize=31)
+
+ # build a fake location that matches the FakeSaveClass
+ # extension
+ fake_location = runtmp.output('out.this-is-a-test')
+
+ # this should use the plugin architecture to return an object
+ # of type FakeSaveClass, with the three signatures in it.
+ x = SaveSignaturesToLocation(fake_location)
+ with x as save_sig:
+ save_sig.add(sig2)
+ save_sig.add(sig47)
+ save_sig.add(sig63)
+
+ print(len(x))
+ print(x.keep)
+
+ assert isinstance(x, FakeSaveClass)
+ assert x.keep == [sig2, sig47, sig63]
+
+
+class Test_EntryPointPriority_SaveTo:
+ # test that priority is observed
+
+ def setup_method(self):
+ self.saved_plugins = plugins._plugin_save_to
+ plugins._plugin_save_to = [
+ FakeEntryPoint('test_save', FakeSaveClass),
+ FakeEntryPoint('test_save2', FakeSaveClass_HighPriority),
+ ]
+
+ def teardown_method(self):
+ plugins._plugin_save_to = self.saved_plugins
+
+ def test_save_1(self):
+ ps = list(plugins.get_save_to_functions())
+ print(ps)
+ assert len(ps) == 2
+
+ def test_save_2(self, runtmp):
+ # load some signatures to save
+ ss2 = utils.get_test_data('2.fa.sig')
+ ss47 = utils.get_test_data('47.fa.sig')
+ ss63 = utils.get_test_data('63.fa.sig')
+
+ sig2 = sourmash.load_one_signature(ss2, ksize=31)
+ sig47 = sourmash.load_one_signature(ss47, ksize=31)
+ sig63 = sourmash.load_one_signature(ss63, ksize=31)
+
+ # build a fake location that matches the FakeSaveClass
+ # extension
+ fake_location = runtmp.output('out.this-is-a-test')
+
+ # this should use the plugin architecture to return an object
+ # of type FakeSaveClass, with the three signatures in it.
+ x = SaveSignaturesToLocation(fake_location)
+ with x as save_sig:
+ save_sig.add(sig2)
+ save_sig.add(sig47)
+ save_sig.add(sig63)
+
+ print(len(x))
+ print(x.keep)
+
+ assert isinstance(x, FakeSaveClass_HighPriority)
+ assert x.keep == [sig2, sig47, sig63]
+ assert x.priority == 1
diff --git a/tests/test_sourmash_sketch.py b/tests/test_sourmash_sketch.py
index fc4cb2373a..c4f5ac14fe 100644
--- a/tests/test_sourmash_sketch.py
+++ b/tests/test_sourmash_sketch.py
@@ -235,7 +235,7 @@ def test_dna_multiple_ksize():
assert not params.hp
assert not params.protein
- from sourmash.sourmash_args import _get_signatures_from_rust
+ from sourmash.save_load import _get_signatures_from_rust
siglist = factory()
ksizes = set()