From 29e63851b2c571c7784448461098201744da64f8 Mon Sep 17 00:00:00 2001 From: james hadfield Date: Tue, 16 Apr 2024 22:00:01 +1200 Subject: [PATCH] WIP strain segment --- Snakefile | 95 ++++++++++++++++++++++++++--------- scripts/add_segment_counts.py | 65 ++++++++++++++++++++++++ 2 files changed, 135 insertions(+), 25 deletions(-) create mode 100644 scripts/add_segment_counts.py diff --git a/Snakefile b/Snakefile index 118ac80..87057d6 100644 --- a/Snakefile +++ b/Snakefile @@ -1,6 +1,10 @@ -SUBTYPES = ["h5nx","h5n1","h9n2","h7n9"] +# SUBTYPES = ["h5nx","h5n1","h9n2","h7n9"] +SUBTYPES = ["h5nx"] SEGMENTS = ["pb2", "pb1", "pa", "ha","np", "na", "mp", "ns"] +# Run snakemake ... --config 'same_strains_per_segment=True' to use the strains for all segments (within a given subtype) +SAME_STRAINS = bool(config.get('same_strains_per_segment', False)) + path_to_fauna = '../fauna' rule all: @@ -24,14 +28,18 @@ def download_by(w): return(db[w.subtype]) def metadata_by_wildcards(w): + # TODO XXX my stuff won't work for h7n9, h9n2 + # rules.add_segment_sequence_counts.output.metadata if (SAME_STRAINS and w.segment=='ha') else rules.parse.output.metadata, md = {"h5n1": rules.add_h5_clade.output.metadata, "h5nx": rules.add_h5_clade.output.metadata, "h7n9": rules.parse.output.metadata, "h9n2": rules.parse.output.metadata} return(md[w.subtype]) def group_by(w): + return 'region year' gb = {'h5nx': 'subtype country year','h5n1': 'region country year', 'h7n9': 'division year', 'h9n2': 'country year'} return gb[w.subtype] def sequences_per_group(w): + return '1' # TODO XXX spg = {'h5nx': '5','h5n1': '10', 'h7n9': '70', 'h9n2': '10'} return spg[w.subtype] @@ -86,10 +94,28 @@ rule parse: --prettify-fields {params.prettify_fields} """ +# For the HA metadata file (for each subtype) add a column "n_segments" +# which reports how many segments have sequence data (no QC performed) +# This will force the download & parsing of all segments for a given subtype +rule add_segment_sequence_counts: 
+    input:
+        segments = expand("results/metadata_{{subtype}}_{segment}.tsv", segment=SEGMENTS),
+        metadata = "results/metadata_{subtype}_ha.tsv",
+    output:
+        metadata = "results/metadata-segments_{subtype}_ha.tsv"
+    shell:
+        """
+        python scripts/add_segment_counts.py \
+            --segments {input.segments} \
+            --metadata {input.metadata} \
+            --output {output.metadata}
+        """
+
 rule add_h5_clade:
     message: "Adding in a column for h5 clade numbering"
     input:
-        metadata = rules.parse.output.metadata,
+        # metadata = rules.parse.output.metadata,
+        metadata = lambda w: rules.add_segment_sequence_counts.output.metadata if (SAME_STRAINS and w.segment=='ha') else rules.parse.output.metadata,
         clades_file = files.clades_file
     output:
         metadata= "results/metadata-with-clade_{subtype}_{segment}.tsv"
@@ -101,41 +127,60 @@ rule add_h5_clade:
         --clades {input.clades_file}
         """
 
+def _filter_params(wildcards, input, output, threads, resources):
+    """
+    Generate the arguments to `augur filter`. When we are running independent analyses
+    (i.e. not using the same strains for each segment), then we generate a full set of
+    filter parameters here.
+    When we are using the same sequences for each segment, then for HA we use a full
+    filter call and for the rest of the segments we filter to the strains chosen for HA
+
+    NOTE: we could move the functions `group_by` etc into this function if that was
+    clearer.
+    """
+    # A strains file is only supplied for non-HA segments under SAME_STRAINS: restrict to exactly those strains.
+    if input.strains:
+        return f"--exclude-all --include {input.strains}"
+
+    # If SAME_STRAINS then we also want to filter to strains with all 8 segments
+    segments = f"n_segments!={len(SEGMENTS)}" if SAME_STRAINS else ''
+
+    # formulate our typical filtering parameters
+    cmd  = f" --group-by {group_by(wildcards)}"
+    cmd += f" --sequences-per-group {sequences_per_group(wildcards)}"
+    cmd += f" --min-date {min_date(wildcards)}"
+    cmd += f" --exclude-where host=laboratoryderived host=ferret host=unknown host=other country=? region=? 
{segments}"
+    cmd += f" --min-length {min_length(wildcards)}"
+    cmd += f" --non-nucleotide"
+    return cmd
+
 rule filter:
-    message:
-        """
-        Filtering to
-        - {params.sequences_per_group} sequence(s) per {params.group_by!s}
-        - excluding strains in {input.exclude}
-        - samples with missing region and country metadata
-        - excluding strains prior to {params.min_date}
-        """
+    """
+    Filter sequences & metadata. For independent analyses this applies the
+    usual group-by / sequences-per-group / min-date / min-length filters
+    (built by _filter_params); when `same_strains_per_segment` is set,
+    non-HA segments are instead restricted to the strains already chosen
+    for the HA segment. NB: plain rule docstrings, unlike `message:`, are
+    not wildcard/params-formatted by Snakemake, so keep this text literal.
+    """
     input:
         sequences = rules.parse.output.sequences,
         metadata = metadata_by_wildcards,
-        exclude = files.dropped_strains
+        exclude = files.dropped_strains,
+        strains = lambda w: f"results/filtered_{w.subtype}_ha.txt" if (SAME_STRAINS and w.segment!='ha') else [],
     output:
-        sequences = "results/filtered_{subtype}_{segment}.fasta"
+        sequences = "results/filtered_{subtype}_{segment}.fasta",
+        strains = "results/filtered_{subtype}_{segment}.txt",
     params:
-        group_by = group_by,
-        sequences_per_group = sequences_per_group,
-        min_date = min_date,
-        min_length = min_length,
-        exclude_where = "host=laboratoryderived host=ferret host=unknown host=other country=? region=?" 
- + args = _filter_params, shell: """ augur filter \ --sequences {input.sequences} \ --metadata {input.metadata} \ --exclude {input.exclude} \ - --output {output.sequences} \ - --group-by {params.group_by} \ - --sequences-per-group {params.sequences_per_group} \ - --min-date {params.min_date} \ - --exclude-where {params.exclude_where} \ - --min-length {params.min_length} \ - --non-nucleotide + --output-sequences {output.sequences} \ + --output-strains {output.strains} \ + {params.args} """ rule align: diff --git a/scripts/add_segment_counts.py b/scripts/add_segment_counts.py new file mode 100644 index 0000000..27184fd --- /dev/null +++ b/scripts/add_segment_counts.py @@ -0,0 +1,65 @@ +""" +Takes in a set of metadata TSVs corresponding to segments (i.e. typically 8 TSVs) +and adds a column to the input `--metadata` TSV with the number of segments +that strain appears in. +""" + +import argparse +import csv +from collections import defaultdict + +def collect_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--segments', type=str, nargs='+', help='Metadata TSVs for all segments') + parser.add_argument('--metadata', type=str, help='Metadata TSV which will be amended for output. Must also appear in --segments.') + parser.add_argument('--output', type=str, help='metadata file name') + return parser.parse_args() + +def read_metadata(fname, strains_only=False): + strains = set() + rows = [] + with open(fname) as csvfile: + reader = csv.DictReader(csvfile, delimiter='\t') + for row in reader: + strains.add(row['strain']) + if not strains_only: + rows.append(row) + if strains_only: + return strains + return (strains, reader.fieldnames, rows) + +def summary(strain_count): + ## Print some basic stats! 
+    print("Num strains observed (across all segments): ", len(strain_count))
+    counts = [0] * 9  # counts[i] = strains seen in exactly i segment files; slot 0 unused
+    for n in strain_count.values():
+        counts[n] += 1
+    for i in range(1, 9):
+        print(f"Num strains observed in {i} segments: ", counts[i])
+
+
+if __name__ == "__main__":
+    args = collect_args()
+    # Tally, per strain, how many of the segment metadata files it appears in.
+    strain_count = defaultdict(int)
+    for fname in args.segments:
+        if fname == args.metadata:
+            _strains, fieldnames, rows = read_metadata(fname)
+        else:
+            _strains = read_metadata(fname, strains_only=True)
+        for s in _strains:
+            strain_count[s] += 1
+    summary(strain_count)
+
+    # append count to data for output
+    column = "n_segments"
+    fieldnames.append(column)
+    for row in rows:
+        row[column] = strain_count[row['strain']]
+
+    with open(args.output, 'w', newline='') as fh:
+        writer = csv.DictWriter(fh, fieldnames=fieldnames, delimiter='\t')
+        writer.writeheader()
+        for row in rows:
+            writer.writerow(row)
+    print("Output written to", args.output)