Skip to content

Commit

Permalink
merge: Support sequences
Browse files Browse the repository at this point in the history
Add sequence support in addition to the existing metadata support.
SeqKit is used to deduplicate across sequence files. Duplicates within
an individual sequence file are not supported. Those are checked by
reading IDs using read_sequences.
  • Loading branch information
victorlin committed Oct 23, 2024
1 parent 518483b commit 34dfb46
Show file tree
Hide file tree
Showing 3 changed files with 354 additions and 7 deletions.
250 changes: 243 additions & 7 deletions augur/merge.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
"""
Merge two or more metadata tables into one.
Merge two or more datasets into one.
Tables must be given unique names to identify them in the output and are
merged in the order given.
Datasets can consist of metadata and/or sequence files.
Metadata
========
Metadata tables must be given unique names to identify them in the output and
are merged in the order given.
Rows are joined by id (e.g. "strain" or "name" or other
--metadata-id-columns), and ids must be unique within an input table (i.e.
Expand Down Expand Up @@ -34,6 +39,20 @@
future. The SQLite 3 CLI, sqlite3, must be available. If it's not on PATH (or
you want to use a version different from what's on PATH), set the SQLITE3
environment variable to path of the desired sqlite3 executable.
Sequences
=========
Sequence files are merged in the order given. Naming files is optional. If
names are provided, they will be checked against any named metadata for order
and IDs.
SeqKit is used behind the scenes to implement the merge, but, at least for now,
this should be considered an implementation detail that may change in the
future. The CLI program seqkit must be available. If it's not on PATH (or
you want to use a version different from what's on PATH), set the SEQKIT
environment variable to path of the desired seqkit executable.
"""
import argparse
import os
Expand All @@ -44,14 +63,15 @@
from itertools import starmap
from shlex import quote as shquote
from shutil import which
from tempfile import mkstemp
from tempfile import mkstemp, NamedTemporaryFile
from textwrap import dedent
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union

from augur.argparse_ import ExtendOverwriteDefault, SKIP_AUTO_DEFAULT_IN_HELP
from augur.errors import AugurError
from augur.io.metadata import DEFAULT_DELIMITERS, DEFAULT_ID_COLUMNS, Metadata
from augur.io.print import print_err, print_debug, _n
from augur.io.sequences import read_sequences
from augur.utils import first_line


Expand All @@ -74,6 +94,9 @@
augur = f"augur"


SEQUENCE_ID_COLUMN = 'id'


class Database:
fd: int
"""Database file descriptor"""
Expand Down Expand Up @@ -105,13 +128,41 @@ def cleanup(self):
os.unlink(self.path)


class UnnamedFile:
table_name: str
"""Generated SQLite table name for this file, based on *path*."""

path: str


class NamedFile:
name: str
"""User-provided descriptive name for this file."""

table_name: str
"""Generated SQLite table name for this file, based on *name*."""

path: str


class NamedSequenceFile(NamedFile):
def __init__(self, name: str, path: str):
self.name = name
self.path = path
self.table_name = f"sequences_{self.name}"

def __repr__(self):
return f"<NamedSequenceFile {self.name}={self.path}>"


class UnnamedSequenceFile(UnnamedFile):
def __init__(self, path: str):
self.path = path
self.table_name = f"sequences_{re.sub(r'[^a-zA-Z0-9]', '_', os.path.basename(self.path))}"

def __repr__(self):
return f"<NamedSequenceFile {self.name}={self.path}>"


class NamedMetadata(Metadata, NamedFile):
def __init__(self, name: str, *args, **kwargs):
Expand All @@ -128,6 +179,7 @@ def register_parser(parent_subparsers):

input_group = parser.add_argument_group("inputs", "options related to input")
input_group.add_argument("--metadata", nargs="+", action="extend", metavar="NAME=FILE", help="Required. Metadata table names and file paths. Names are arbitrary monikers used solely for referring to the associated input file in other arguments and in output column names. Paths must be to seekable files, not unseekable streams. Compressed files are supported." + SKIP_AUTO_DEFAULT_IN_HELP)
input_group.add_argument("--sequences", nargs="+", action="extend", metavar="[NAME=]FILE", help="Sequence files, optionally named for validation with named metadata. Compressed files are supported." + SKIP_AUTO_DEFAULT_IN_HELP)

input_group.add_argument("--metadata-id-columns", default=DEFAULT_ID_COLUMNS, nargs="+", action=ExtendOverwriteDefault, metavar="[TABLE=]COLUMN", help=f"Possible metadata column names containing identifiers, considered in the order given. Columns will be considered for all metadata tables by default. Table-specific column names may be given using the same names assigned in --metadata. Only one ID column will be inferred for each table. (default: {' '.join(map(shquote_humanized, DEFAULT_ID_COLUMNS))})" + SKIP_AUTO_DEFAULT_IN_HELP)
input_group.add_argument("--metadata-delimiters", default=DEFAULT_DELIMITERS, nargs="+", action=ExtendOverwriteDefault, metavar="[TABLE=]CHARACTER", help=f"Possible field delimiters to use for reading metadata tables, considered in the order given. Delimiters will be considered for all metadata tables by default. Table-specific delimiters may be given using the same names assigned in --metadata. Only one delimiter will be inferred for each table. (default: {' '.join(map(shquote_humanized, DEFAULT_DELIMITERS))})" + SKIP_AUTO_DEFAULT_IN_HELP)
Expand All @@ -136,23 +188,26 @@ def register_parser(parent_subparsers):
output_group.add_argument('--output-metadata', metavar="FILE", help="Required. Merged metadata as TSV. Compressed files are supported." + SKIP_AUTO_DEFAULT_IN_HELP)
output_group.add_argument('--source-columns', metavar="TEMPLATE", help=f"Template with which to generate names for the columns (described above) identifying the source of each row's data. Must contain a literal placeholder, {{NAME}}, which stands in for the metadata table names assigned in --metadata. (default: disabled)" + SKIP_AUTO_DEFAULT_IN_HELP)
output_group.add_argument('--no-source-columns', dest="source_columns", action="store_const", const=None, help=f"Suppress generated columns (described above) identifying the source of each row's data. This is the default behaviour, but it may be made explicit or used to override a previous --source-columns." + SKIP_AUTO_DEFAULT_IN_HELP)
output_group.add_argument('--output-sequences', metavar="FILE", help="Required. Merged sequences as FASTA. Compressed files are supported." + SKIP_AUTO_DEFAULT_IN_HELP)
output_group.add_argument('--quiet', action="store_true", default=False, help="Suppress informational and warning messages normally written to stderr. (default: disabled)" + SKIP_AUTO_DEFAULT_IN_HELP)

return parser


def validate_arguments(args):
# These will make more sense when sequence support is added.
if not args.metadata:
if not any((args.metadata, args.sequences)):
raise AugurError("At least one input must be specified.")
if not args.output_metadata:
if not any((args.output_metadata, args.output_sequences)):
raise AugurError("At least one output must be specified.")

if args.metadata and not len(args.metadata) >= 2:
raise AugurError(f"At least two metadata inputs are required for merging.")

if args.output_metadata and not args.metadata:
raise AugurError("--output-metadata requires --metadata.")
if args.output_sequences and not args.sequences:
raise AugurError("--output-sequences requires --sequences.")


def run(args: argparse.Namespace):
Expand All @@ -168,16 +223,83 @@ def run(args: argparse.Namespace):
metadata: Optional[List[NamedMetadata]] = None
output_columns: Optional[Columns] = None
output_source_column: Optional[Callable[[str], str]] = None
sequences: Optional[List[Union[NamedSequenceFile, UnnamedSequenceFile]]] = None
named_sequences: Optional[List[NamedSequenceFile]] = None

if args.metadata:
metadata = get_metadata(args.metadata, args.metadata_id_columns, args.metadata_delimiters)

if args.sequences:
sequences = list(get_sequences(args.sequences))

# Perform checks on file names.

named_sequences = [s for s in sequences if isinstance(s, NamedSequenceFile)]

if unnamed_sequences := [s for s in sequences if isinstance(s, UnnamedSequenceFile)]:
for x in unnamed_sequences:
print_info(f"WARNING: Sequence file {x.path!r} is unnamed. Skipping validation with metadata.")

if metadata and named_sequences:
metadata_order = [m.name for m in metadata]
sequences_order = [s.name for s in named_sequences]

if metadata_order != sequences_order:
raise AugurError(f"Order of inputs differs between named metadata {metadata_order!r} and named sequences {sequences_order!r}.")

# FIXME: add easy clarification that requires a few more conditions:
# ERROR: Sequence file 'c=c.fasta' does not have a corresponding metadata file.


# Load data.

if metadata:
load_metadata(db, metadata)
if sequences:
load_sequences(db, sequences)


# Perform checks on file contents.

if metadata and named_sequences:
metadata_by_name = {m.name: m for m in metadata}
sequences_by_name = {s.name: s for s in named_sequences}

for name in sorted(metadata_by_name.keys() & sequences_by_name.keys()):
m = metadata_by_name[name]
s = sequences_by_name[name]

# FIXME: import this at the top-level, taking into account the existing sqlite3 function at that level
import sqlite3
with sqlite3.connect(db.path) as connection:
connection.row_factory = sqlite3.Row

metadata_ids = {x[m.id_column] for x in
connection.execute(f"""select {sqlite_quote_id(m.id_column)}
from {sqlite_quote_id(m.table_name)}
""")}

sequence_ids = {x[SEQUENCE_ID_COLUMN] for x in
connection.execute(f"""select {sqlite_quote_id(SEQUENCE_ID_COLUMN)}
from {sqlite_quote_id(s.table_name)}
""")}

for x in sorted(metadata_ids - sequence_ids):
print_info(f"WARNING: Sequence {x!r} in {m.path!r} is missing from {s.path!r}. Outputs may continue to be mismatched.")
for x in sorted(sequence_ids - metadata_ids):
print_info(f"WARNING: Sequence {x!r} in {s.path!r} is missing from {m.path!r}. Outputs may continue to be mismatched.")


# Write outputs.

if args.output_metadata:
output_source_column = get_output_source_column(args.source_columns, metadata)
output_columns = get_output_columns(metadata)
merge_metadata(db, metadata, output_columns, args.output_metadata, output_source_column)

if args.output_sequences:
merge_sequences(sequences, args.output_sequences)


def get_metadata(
input_metadata: Sequence[str],
Expand Down Expand Up @@ -407,6 +529,89 @@ def merge_metadata(
db.cleanup()


def get_sequences(input_sequences: List[str]):
try:
sequences = parse_inputs(input_sequences)
except InvalidNamedInputError as e:
raise AugurError(dedent(f"""\
Input filenames cannot start with '='.
The following {_n("input starts", "inputs start", len(e.invalid))} with '=':
{indented_list([repr(x) for x in e.invalid], ' ' + ' ')}
""")) from e
except DuplicateInputNameError as e:
raise AugurError(dedent(f"""\
Sequence input names must be unique.
The following {_n("name was", "names were", len(e.duplicates))} used more than once:
{indented_list([repr(x) for x in e.duplicates], ' ' + ' ')}
""")) from e

for name, path in sequences:
if name == "":
yield UnnamedSequenceFile(path)
else:
yield NamedSequenceFile(name, path)


def load_sequences(db: Database, sequences: List[Union[NamedSequenceFile, UnnamedSequenceFile]]):
for s in sequences:
print_info(f"Reading sequence IDs from {s.path!r}…")
ids = [seq.id for seq in read_sequences(s.path)]

if duplicates := [item for item, count in count_unique(ids) if count > 1]:
raise AugurError(f"The following entries are duplicated in {s.path!r}:\n" + "\n".join(duplicates))

# FIXME: Skip for unnamed sequences? This is only used for named sequences to check against paired metadata.
db.run(f"create table {sqlite_quote_id(s.table_name)} ({sqlite_quote_id(SEQUENCE_ID_COLUMN)} text);")
values = ", ".join([f"('{id}')" for id in ids])
db.run(f"insert into {sqlite_quote_id(s.table_name)} ({sqlite_quote_id(SEQUENCE_ID_COLUMN)}) values {values};")

db.run(f'create unique index {sqlite_quote_id(f"{s.table_name}_id")} on {sqlite_quote_id(s.table_name)}({sqlite_quote_id(SEQUENCE_ID_COLUMN)});')


def cat(filepath: str):
cat = "cat"

if filepath.endswith(".gz"):
cat = "gzcat"
if filepath.endswith(".xz"):
cat = "xzcat"
if filepath.endswith(".zst"):
cat = "zstdcat"

return cat, filepath


def merge_sequences(
sequences: List[Union[NamedSequenceFile, UnnamedSequenceFile]],
output_sequences: str,
):
# Confirm that seqkit is installed.
if which("seqkit") is None:
raise AugurError("'seqkit' is not installed! This is required to merge sequences.")

with NamedTemporaryFile() as temp_file:
with open(temp_file.name, 'w') as f:
# Reversed because seqkit rmdup keeps the first entry but this command
# should keep the last entry.
for s in reversed(sequences):
print_info(f"Reading sequences from {s.path!r}…")
subprocess.Popen(cat(s.path), stdout=f)

print_info(f"Merging sequences and writing to {output_sequences!r}…")
process = seqkit('rmdup', temp_file.name, stdout=subprocess.PIPE)

# FIXME: handle `-` better
if output_sequences == "-":
sys.stdout.write(process.stdout)
else:
with open(output_sequences, "w") as f:
f.write(process.stdout)


def sqlite3(*args, **kwargs):
"""
Internal helper for invoking ``sqlite3``, the SQLite CLI program.
Expand Down Expand Up @@ -477,6 +682,37 @@ def sqlite_quote_string(x):
return "'" + x.replace("'", "''") + "'"


def seqkit(*args, **kwargs):
"""
Internal helper for invoking ``seqkit``, the SeqKit CLI program.
"""
seqkit = os.environ.get("SEQKIT", which("seqkit"))

if not seqkit:
raise AugurError(dedent(f"""\
Unable to find the program `seqkit`. Is it installed?
In order to use `augur merge`, the SeqKit CLI must be installed
separately. It is typically provided by a Nextstrain runtime.
"""))

argv = [seqkit, *args]

print_debug(f"running {argv!r}")
proc = subprocess.run(argv, encoding="utf-8", text=True, **kwargs)

try:
proc.check_returncode()
except subprocess.CalledProcessError as err:
raise SeqKitError(f"seqkit invocation failed") from err

return proc


class SeqKitError(Exception):
pass


def pairs(xs: Iterable[str]) -> Iterable[Tuple[str, str]]:
"""
Split an iterable of ``k=v`` strings into an iterable of ``(k,v)`` tuples.
Expand Down
Loading

0 comments on commit 34dfb46

Please sign in to comment.