diff --git a/augur/merge.py b/augur/merge.py index 616174691..ce8f57928 100644 --- a/augur/merge.py +++ b/augur/merge.py @@ -79,6 +79,9 @@ _n = gettext.NullTranslations().ngettext +SEQUENCE_ID_COLUMN = 'id' + + class Database: fd: int """Database file descriptor""" @@ -120,14 +123,14 @@ class NamedFile: path: str -class NamedSequences(NamedFile): +class NamedSequenceFile(NamedFile): def __init__(self, name: str, path: str): self.name = name self.path = path self.table_name = f"sequences_{self.name}" def __repr__(self): - return f"" + return f"" class NamedMetadata(Metadata, NamedFile): def __init__(self, name: str, *args, **kwargs): @@ -198,9 +201,14 @@ def run(args): sequences = get_sequences(args.sequences) load_sequences(db, sequences) - # FIXME: check that entries in input match - # WARNING: Sequence 'XXX' in a.tsv is missing from a.fasta. It will not be present in any output. - # WARNING: Sequence 'YYY' in b.fasta is missing from b.csv. It will not be present in any output. + metadata_by_name = {m.name: m for m in metadata} + sequences_by_name = {s.name: s for s in sequences} + + for name in metadata_by_name.keys() & sequences_by_name.keys(): + # FIXME: check that entries in input match + # WARNING: Sequence 'XXX' in a.tsv is missing from a.fasta. It will not be present in any output. + # WARNING: Sequence 'YYY' in b.fasta is missing from b.csv. It will not be present in any output. + ... if args.output_metadata: merge_metadata(db, metadata, output_columns, args.output_metadata) @@ -378,19 +386,25 @@ def merge_metadata( def get_sequences(input_sequences): # Validate arguments sequences = parse_named_inputs(input_sequences) + + # FIXME: support unnamed inputs + # table name can be based on filename + random characters - return [NamedSequences(name, path) for name, path in sequences] + return [NamedSequenceFile(name, path) for name, path in sequences] + +def load_sequences(db: Database, sequences: List[NamedSequenceFile]): + for s in sequences: + ids = [seq.id for seq in read_sequences(s.path)] -def load_sequences(db: Database, sequences: List[NamedSequences]): - for sequence_input in sequences: - ids = [seq.id for seq in read_sequences(sequence_input.path)] + if duplicates := [item for item, count in count_unique(ids) if count > 1]: + raise AugurError(f"The following entries are duplicated in {s.path!r}:\n" + "\n".join(duplicates)) - # FIXME: error on duplicates within a single sequence file + db.run(f"create table {sqlite_quote_id(s.table_name)} ({sqlite_quote_id(SEQUENCE_ID_COLUMN)} text);") + values = ", ".join([f"('{id}')" for id in ids]) + db.run(f"insert into {sqlite_quote_id(s.table_name)} ({sqlite_quote_id(SEQUENCE_ID_COLUMN)}) values {values};") - db.run(f"CREATE TABLE {sequence_input.table_name} (id TEXT)") - values = ", ".join([f"('{seq_id}')" for seq_id in ids]) - db.run(f"INSERT INTO {sequence_input.table_name} (id) VALUES {values}") + db.run(f'create unique index {sqlite_quote_id(f"{s.table_name}_id")} on {sqlite_quote_id(s.table_name)}({sqlite_quote_id(SEQUENCE_ID_COLUMN)});') # FIXME: return a list of arguments and don't use shell @@ -409,7 +423,7 @@ def cat(filepath: str): def merge_sequences( db: Database, - sequences: List[NamedSequences], + sequences: List[NamedSequenceFile], output_sequences: str, ): # Confirm that seqkit is installed.