Skip to content

Commit

Permalink
WIP strain segment
Browse files Browse the repository at this point in the history
  • Loading branch information
jameshadfield committed Apr 16, 2024
1 parent 018f454 commit 29e6385
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 25 deletions.
95 changes: 70 additions & 25 deletions Snakefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Full list of subtypes this workflow supports; currently restricted to h5nx.
# SUBTYPES = ["h5nx","h5n1","h9n2","h7n9"]
SUBTYPES = ["h5nx"]
SEGMENTS = ["pb2", "pb1", "pa", "ha", "np", "na", "mp", "ns"]

# Run snakemake ... --config 'same_strains_per_segment=True' to use the same
# strains for all segments (within a given subtype)
SAME_STRAINS = bool(config.get('same_strains_per_segment', False))

path_to_fauna = '../fauna'

rule all:
Expand All @@ -24,14 +28,18 @@ def download_by(w):
return(db[w.subtype])

def metadata_by_wildcards(w):
    """Return the metadata input appropriate for the subtype wildcard.

    h5 subtypes use the metadata with the extra clade column added;
    h7n9 / h9n2 use the parsed metadata directly.
    """
    # TODO XXX segment-count metadata won't work for h7n9, h9n2
    # rules.add_segment_sequence_counts.output.metadata if (SAME_STRAINS and w.segment=='ha') else rules.parse.output.metadata,
    clade_annotated = rules.add_h5_clade.output.metadata
    parsed = rules.parse.output.metadata
    lookup = {
        "h5n1": clade_annotated,
        "h5nx": clade_annotated,
        "h7n9": parsed,
        "h9n2": parsed,
    }
    return lookup[w.subtype]

def group_by(w):
    """Return the `augur filter --group-by` columns for the given subtype.

    Raises KeyError for subtypes outside the supported set.
    """
    # NOTE: the stale unconditional `return 'region year'` (a leftover from an
    # earlier revision) has been removed — it made this mapping unreachable.
    gb = {'h5nx': 'subtype country year', 'h5n1': 'region country year', 'h7n9': 'division year', 'h9n2': 'country year'}
    return gb[w.subtype]

def sequences_per_group(w):
    """Return the `augur filter --sequences-per-group` value for the subtype.

    Raises KeyError for subtypes outside the supported set.
    """
    # NOTE: the stale unconditional `return '1' # TODO XXX` (a leftover from an
    # earlier revision) has been removed — it made this mapping unreachable.
    spg = {'h5nx': '5', 'h5n1': '10', 'h7n9': '70', 'h9n2': '10'}
    return spg[w.subtype]

Expand Down Expand Up @@ -86,10 +94,28 @@ rule parse:
--prettify-fields {params.prettify_fields}
"""

# For the HA metadata file (for each subtype) add a column "n_segments"
# which reports how many segments have sequence data (no QC performed).
# Because this rule requires every segment's metadata as input it forces
# the download & parsing of all segments for a given subtype.
rule add_segment_sequence_counts:
    input:
        # metadata TSVs for all segments of this subtype (double braces keep
        # the subtype wildcard; `expand` fills in each segment)
        segments = expand("results/metadata_{{subtype}}_{segment}.tsv", segment=SEGMENTS),
        metadata = "results/metadata_{subtype}_ha.tsv",
    output:
        metadata = "results/metadata-segments_{subtype}_ha.tsv"
    shell:
        """
        python scripts/add_segment_counts.py \
            --segments {input.segments} \
            --metadata {input.metadata} \
            --output {output.metadata}
        """

rule add_h5_clade:
message: "Adding in a column for h5 clade numbering"
input:
metadata = rules.parse.output.metadata,
# metadata = rules.parse.output.metadata,
metadata = lambda w: rules.add_segment_sequence_counts.output.metadata if (SAME_STRAINS and w.segment=='ha') else rules.parse.output.metadata,
clades_file = files.clades_file
output:
metadata= "results/metadata-with-clade_{subtype}_{segment}.tsv"
Expand All @@ -101,41 +127,60 @@ rule add_h5_clade:
--clades {input.clades_file}
"""

def _filter_params(wildcards, input, output, threads, resources):
"""
Generate the arguments to `augur filter`. When we are running independent analyses
(i.e. not using the same strains for each segment), then we generate a full set of
filter parameters here.
When we are using the same sequences for each segment, then for HA we use a full
filter call and for the rest of the segments we filter to the strains chosen for HA
NOTE: we could move the functions `group_by` etc into this function if that was
clearer.
"""
# if input.strains then we just restrict to those
if input.strains:
return f"--exclude-all --include {input.strains}"

# If SAME_STRAINS then we also want to filter to strains with all 8 segments
segments = f"n_segments!={len(SEGMENTS)}" if SAME_STRAINS else '';

# formulate our typical filtering parameters
cmd = f" --group-by {group_by(wildcards)}"
cmd += f" --sequences-per-group {sequences_per_group(wildcards)}"
cmd += f" --min-date {min_date(wildcards)}"
cmd += f" --exclude-where host=laboratoryderived host=ferret host=unknown host=other country=? region=? {segments}"
cmd += f" --min-length {min_length(wildcards)}"
cmd += f" --non-nucleotide"
return cmd

rule filter:
message:
"""
Filtering to
- {params.sequences_per_group} sequence(s) per {params.group_by!s}
- excluding strains in {input.exclude}
- samples with missing region and country metadata
- excluding strains prior to {params.min_date}
"""
"""
Filtering to
- {params.sequences_per_group} sequence(s) per {params.group_by!s}
- excluding strains in {input.exclude}
- samples with missing region and country metadata
- excluding strains prior to {params.min_date}
"""
input:
sequences = rules.parse.output.sequences,
metadata = metadata_by_wildcards,
exclude = files.dropped_strains
exclude = files.dropped_strains,
strains = lambda w: f"results/filtered_{w.subtype}_ha.txt" if (SAME_STRAINS and w.segment!='ha') else [],
output:
sequences = "results/filtered_{subtype}_{segment}.fasta"
sequences = "results/filtered_{subtype}_{segment}.fasta",
strains = "results/filtered_{subtype}_{segment}.txt",
params:
group_by = group_by,
sequences_per_group = sequences_per_group,
min_date = min_date,
min_length = min_length,
exclude_where = "host=laboratoryderived host=ferret host=unknown host=other country=? region=?"

args = _filter_params,
shell:
"""
augur filter \
--sequences {input.sequences} \
--metadata {input.metadata} \
--exclude {input.exclude} \
--output {output.sequences} \
--group-by {params.group_by} \
--sequences-per-group {params.sequences_per_group} \
--min-date {params.min_date} \
--exclude-where {params.exclude_where} \
--min-length {params.min_length} \
--non-nucleotide
--output-sequences {output.sequences} \
--output-strains {output.strains} \
{params.args}
"""

rule align:
Expand Down
65 changes: 65 additions & 0 deletions scripts/add_segment_counts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
Takes in a set of metadata TSVs corresponding to segments (i.e. typically 8 TSVs)
and adds a column to the input `--metadata` TSV with the number of segments
that strain appears in.
"""

import argparse
import csv
from collections import defaultdict

def collect_args():
    """Define and parse this script's command-line interface."""
    p = argparse.ArgumentParser()
    p.add_argument('--segments', type=str, nargs='+',
                   help='Metadata TSVs for all segments')
    p.add_argument('--metadata', type=str,
                   help='Metadata TSV which will be amended for output. Must also appear in --segments.')
    p.add_argument('--output', type=str,
                   help='metadata file name')
    return p.parse_args()

def read_metadata(fname, strains_only=False):
    """Read a (tab-separated) metadata file with a 'strain' column.

    Parameters
    ----------
    fname : str
        path to the metadata TSV
    strains_only : bool
        if True, collect only the strain names (saves memory for the
        segment files we don't re-write)

    Returns
    -------
    set or tuple
        ``strains_only=True``  -> set of strain names
        ``strains_only=False`` -> (strains, fieldnames, rows)
    """
    strains = set()
    rows = []
    # newline='' is the csv-module-documented way to open files for csv objects
    with open(fname, newline='') as csvfile:
        reader = csv.DictReader(csvfile, delimiter='\t')
        for row in reader:
            strains.add(row['strain'])
            if not strains_only:
                rows.append(row)
    if strains_only:
        return strains
    return (strains, reader.fieldnames, rows)

def summary(strain_count):
    """Print basic statistics about how many segments each strain appears in.

    Parameters
    ----------
    strain_count : dict
        mapping of strain name -> number of segment files it appears in
    """
    print("Num strains observed (across all segments): ", len(strain_count))
    # Size the histogram from the data so that runs with more than 8 segment
    # files no longer raise IndexError (the original hard-coded [0]*9).
    # Output is unchanged for the usual <=8-segment case.
    max_n = max(strain_count.values(), default=0)
    counts = [0] * (max(max_n, 8) + 1)  # 1-indexed
    for n in strain_count.values():
        counts[n] += 1
    for i in range(1, len(counts)):
        print(f"Num strains observed in {i} segments: ", counts[i])


if __name__ == "__main__":
    args = collect_args()

    # Count, for every strain, how many of the segment metadata files it
    # appears in. For the file we will amend (--metadata) keep the full rows
    # and header so they can be re-written with the extra column.
    strain_count = defaultdict(int)
    fieldnames, rows = None, None
    for fname in args.segments:
        if fname == args.metadata:
            _strains, fieldnames, rows = read_metadata(fname)
        else:
            _strains = read_metadata(fname, strains_only=True)
        for s in _strains:
            strain_count[s] += 1
    summary(strain_count)

    # Previously a NameError would surface here if --metadata wasn't listed in
    # --segments; fail with an explicit message instead.
    if fieldnames is None:
        raise SystemExit("ERROR: the --metadata file must also appear in --segments")

    # append count to data for output
    column = "n_segments"
    fieldnames.append(column)
    for row in rows:
        row[column] = strain_count[row['strain']]

    # newline='' lets csv control line endings (avoids \r\r\n rows on Windows)
    with open(args.output, 'w', newline='') as fh:
        writer = csv.DictWriter(fh, fieldnames=fieldnames, delimiter='\t')
        writer.writeheader()
        for row in rows:
            writer.writerow(row)
    print("Output written to", args.output)

0 comments on commit 29e6385

Please sign in to comment.