From 9b04b3cc5ebf71a4372555a88974d4a2535a213d Mon Sep 17 00:00:00 2001
From: David Lougheed
Date: Mon, 23 Sep 2024 20:12:28 -0400
Subject: [PATCH] feat(call): correctly alter/remove alt anchor base in VCF output

---
 strkit/call/call_locus.py | 32 ++++++++++++++++++++++++--------
 strkit/call/output/vcf.py | 39 ++++++++++++++++++++++++++++++++-------
 strkit/call/types.py      |  5 ++++-
 3 files changed, 60 insertions(+), 16 deletions(-)

diff --git a/strkit/call/call_locus.py b/strkit/call/call_locus.py
index 5b6014f..300e976 100644
--- a/strkit/call/call_locus.py
+++ b/strkit/call/call_locus.py
@@ -633,6 +633,7 @@ def call_alleles_with_incorporated_snvs(
 
     # Cluster reads together using the distance matrix, which incorporates SNV and possibly copy number information.
     cluster_labels, cluster_indices = _agg_clust_alleles_by_dm(n_alleles, dm)
+    del dm
 
     cluster_reads: list[tuple[ReadDict, ...]] = []
     cns: list[NDArray[np.int32]] = []
@@ -1276,6 +1277,17 @@ def get_read_length_partition_mean(p_idx: int) -> float:
         # TODO: re-examine weighting to possibly incorporate chance of drawing read large enough
         read_weight = (mean_containing_size + tr_len_w_flank - 2) / (mean_containing_size - tr_len_w_flank + 1)
 
+        # ---
+
+        read_start_anchor: str = ""
+        if consensus:
+            anchor_pair_idx, anchor_pair_found = find_pair_by_ref_pos(r_coords, left_coord_adj - 1, 0)
+            if anchor_pair_found:
+                read_start_anchor = qs[q_coords[anchor_pair_idx]:left_flank_end]
+            # otherwise, leave as blank - anchor base deleted
+
+        # ---
+
         crs_cir = chimeric_read_status[rn] == 3  # Chimera within the TR region, indicating a potential large expansion
         read_dict[rn] = read_dict_entry = {
             "s": "-" if segment.is_reverse else "+",
@@ -1288,7 +1300,7 @@ def get_read_length_partition_mean(p_idx: int) -> float:
         read_dict_extra[rn] = read_extra_entry = {
             "_ref_start": segment_start,
             "_ref_end": segment_end,
-            **({"_tr_seq": tr_read_seq} if consensus else {}),
+            **({"_start_anchor": read_start_anchor, "_tr_seq": tr_read_seq} if consensus else {}),
         }
 
         # Reads can show up more than once - TODO - cache this information across loci
@@ -1342,8 +1354,7 @@ def get_read_length_partition_mean(p_idx: int) -> float:
 
     n_reads_in_dict: int = len(read_dict)
 
     locus_result.update({
-        # TODO: alt anchors:
-        **({"ref_start_anchor": ref_left_flank_seq[-1], "ref_seq": ref_seq} if consensus else {}),
+        **({"ref_start_anchor": ref_left_flank_seq[-1].upper(), "ref_seq": ref_seq} if consensus else {}),
         "reads": read_dict,
     })
@@ -1506,7 +1517,10 @@ def get_read_length_partition_mean(p_idx: int) -> float:
     # don't know how re-sampling has occurred.
 
     call_peak_n_reads: list[int] = []
     peak_kmers: list[Counter] = [Counter() for _ in range(call_modal_n or 0)]
+    call_seqs: list[tuple[str, ConsensusMethod]] = []
+    call_anchor_seqs: list[tuple[str, ConsensusMethod]] = []
+
     if read_peaks_called := call_modal_n and call_modal_n <= 2:
         peaks: NDArray[np.float_] = call_peaks[:call_modal_n]
         stdevs: NDArray[np.float_] = call_stdevs[:call_modal_n]
@@ -1576,16 +1590,18 @@ def get_read_length_partition_mean(p_idx: int) -> float:
         call_99_cis = None
 
     if call_data and consensus:
-        call_seqs.extend(
-            map(
+        def _consensi_for_key(k: str):
+            return map(
                 lambda a: consensus_seq(
-                    list(map(lambda rr: read_dict_extra[rr]["_tr_seq"], a)),
+                    list(map(lambda rr: read_dict_extra[rr][k], a)),
                     logger_, max_mdn_poa_length,
                 ),
                 allele_reads,
             )
-        )
+
+        call_seqs.extend(_consensi_for_key("_tr_seq"))
+        call_anchor_seqs.extend(_consensi_for_key("_start_anchor"))
 
     peak_data = {
         "means": call_peaks,
         "stdevs": call_stdevs,
@@ -1594,7 +1610,7 @@ def get_read_length_partition_mean(p_idx: int) -> float:
         "modal_n": call_modal_n,
         "n_reads": call_peak_n_reads,
         **({"kmers": list(map(dict, peak_kmers))} if count_kmers in ("peak", "both") else {}),
-        **({"seqs": call_seqs} if consensus else {}),
+        **({"seqs": call_seqs, "start_anchor_seqs": call_anchor_seqs} if consensus else {}),
     } if call_data else None
 
     assign_time = time.perf_counter() - assign_start_time
diff --git a/strkit/call/output/vcf.py b/strkit/call/output/vcf.py
index 2994773..209c283 100644
--- a/strkit/call/output/vcf.py
+++ b/strkit/call/output/vcf.py
@@ -3,7 +3,7 @@
 import pathlib
 import pysam
 
-from os.path import commonprefix
+# from os.path import commonprefix
 from typing import Optional
 
 from strkit.utils import cat_strs, is_none
@@ -120,26 +120,47 @@ def output_contig_vcf_lines(
 
         res_reads = result["reads"]
         res_peaks = result["peaks"] or {}
+
         peak_seqs: list[str] = list(map(idx_0_getter, res_peaks.get("seqs", [])))
+        peak_start_anchor_seqs: list[str] = list(map(idx_0_getter, res_peaks.get("start_anchor_seqs", [])))
+
         if any(map(is_none, peak_seqs)):  # Occurs when no consensus for one of the peaks
             logger.error(f"Encountered None in results[{result_idx}].peaks.seqs: {peak_seqs}")
             continue
 
+        if any(map(is_none, peak_start_anchor_seqs)):  # Occurs when no consensus for one of the peaks
+            logger.error(f"Encountered None in results[{result_idx}].peaks.start_anchor_seqs: {peak_start_anchor_seqs}")
+            continue
+
         seqs = tuple(map(str.upper, peak_seqs))
+        seqs_with_anchors = tuple(zip(seqs, tuple(map(str.upper, peak_start_anchor_seqs))))
+
         if 0 < len(seqs) < n_alleles:
             seqs = tuple([seqs[0]] * n_alleles)
+            seqs_with_anchors = tuple([seqs_with_anchors[0]] * n_alleles)
 
-        seq_alts = sorted(set(filter(lambda c: c != ref_seq, seqs)))
-        common_suffix_idx = -1 * len(commonprefix(tuple(map(_reversed_str, (ref_seq, *seqs)))))
+        seq_alts = sorted(
+            set(filter(lambda c: not (c[0] == ref_seq and c[1] == ref_start_anchor), seqs_with_anchors)),
+            key=lambda x: x[0]
+        )
+
+        # common_suffix_idx = -1 * len(commonprefix(tuple(map(_reversed_str, (ref_seq, *seqs)))))
 
         call = result["call"]
         call_95_cis = result["call_95_cis"]
 
-        seq_alleles_raw: tuple[Optional[str], ...] = (ref_seq, *(seq_alts or (None,))) if call is not None else (".",)
-        seq_alleles: list[str] = [ref_start_anchor + (ref_seq[:common_suffix_idx] if common_suffix_idx else ref_seq)]
+        seq_alleles_raw: tuple[Optional[str], ...] = (
+            ((ref_seq, ref_start_anchor), *(seq_alts or (None,)))
+            if call is not None
+            else ()
+        )
+        # seq_alleles: list[str] = [ref_start_anchor + (ref_seq[:common_suffix_idx] if common_suffix_idx else ref_seq)]
+        seq_alleles: list[str] = [ref_start_anchor + ref_seq]
 
         if call is not None and seq_alts:
-            seq_alleles.extend(ref_start_anchor + (a[:common_suffix_idx] if common_suffix_idx else a) for a in seq_alts)
+            # seq_alleles.extend(a[1] + (a[0][:common_suffix_idx] if common_suffix_idx else a[0]) for a in seq_alts)
+            # If we have a complete deletion, including the anchor, use a symbolic allele meaning "upstream deletion"
+            seq_alleles.extend((a[1] + a[0] if a[1] or a[0] else "*") for a in seq_alts)
         else:
             seq_alleles.append(".")
 
@@ -155,8 +176,12 @@ def output_contig_vcf_lines(
         vr.info[VCF_INFO_MOTIF] = result["motif"]
         vr.info[VCF_INFO_REFMC] = result["ref_cn"]
 
-        vr.samples[sample_id]["GT"] = tuple(map(seq_alleles_raw.index, seqs)) if call is not None and seqs \
+        vr.samples[sample_id]["GT"] = (
+            tuple(map(seq_alleles_raw.index, seqs_with_anchors))
+            if call is not None and seqs
             else _blank_entry(n_alleles)
+        )
+        del seq_alleles_raw
 
         if am := result.get("assign_method"):
             vr.samples[sample_id]["PM"] = am
diff --git a/strkit/call/types.py b/strkit/call/types.py
index 9c87c27..204baed 100644
--- a/strkit/call/types.py
+++ b/strkit/call/types.py
@@ -58,7 +58,10 @@ class ReadDictExtra(TypedDict, total=False):
 
     _ref_start: int  # Read start in ref coordinates
     _ref_end: int  # Read end in ref coordinates
 
-    _tr_seq: str  # Tandem repeat sequence... only added if consensus is being calculated
+    # BEGIN: only added if consensus is being calculated
+    _start_anchor: str  # Left anchor for calculated allele sequence (usually 1 base)
+    _tr_seq: str  # Tandem repeat sequence
+    # END: only added if consensus is being calculated
 
     # Below are only added if SNVs are being incorporated:
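The allele-construction rule introduced in the strkit/call/output/vcf.py hunk above can be sketched in isolation. The sketch below is a hypothetical, self-contained illustration rather than strkit's actual API (build_seq_alleles and its parameters are invented names): each output allele is the start-anchor base concatenated with the consensus repeat sequence, a fully deleted allele (no anchor base and no repeat sequence) becomes the symbolic "*" ("upstream deletion") allele, and genotype indices are looked up against the raw (sequence, anchor) pairs rather than the rendered strings.

from typing import Optional


def build_seq_alleles(
    ref_seq: str,
    ref_start_anchor: str,
    seqs_with_anchors: list[tuple[str, str]],  # one (consensus seq, start anchor) pair per called peak
    call_made: bool,
) -> tuple[list[str], list[Optional[tuple[str, str]]]]:
    # Alts are the (seq, anchor) pairs that differ from the reference pair, ordered by sequence.
    seq_alts = sorted(
        set(filter(lambda c: not (c[0] == ref_seq and c[1] == ref_start_anchor), seqs_with_anchors)),
        key=lambda x: x[0],
    )

    # Raw alleles keep the (seq, anchor) pairing so GT indices can be looked up afterwards.
    seq_alleles_raw: list[Optional[tuple[str, str]]] = (
        [(ref_seq, ref_start_anchor), *(seq_alts or (None,))] if call_made else []
    )

    # REF is always anchor base + reference repeat sequence.
    seq_alleles: list[str] = [ref_start_anchor + ref_seq]
    if call_made and seq_alts:
        # A complete deletion (no anchor, no repeat sequence) becomes the symbolic "*" allele.
        seq_alleles.extend((a[1] + a[0] if a[1] or a[0] else "*") for a in seq_alts)
    else:
        seq_alleles.append(".")

    return seq_alleles, seq_alleles_raw


# One allele matches the reference; the other deletes both the anchor base and the repeat.
alleles, raw = build_seq_alleles("CAGCAG", "A", [("CAGCAG", "A"), ("", "")], call_made=True)
print(alleles)                                               # ['ACAGCAG', '*']
print([raw.index(p) for p in [("CAGCAG", "A"), ("", "")]])   # GT indices: [0, 1]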