Skip to content

Commit

Permalink
feat(mi): sequence-wise MI calculations
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlougheed committed Dec 9, 2024
1 parent ad9015f commit 21d333b
Show file tree
Hide file tree
Showing 7 changed files with 241 additions and 92 deletions.
4 changes: 2 additions & 2 deletions strkit/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
CALLER_HIPSTR = "hipstr"
CALLER_LONGTR = "longtr"
CALLER_GANGSTR = "gangstr"
CALLER_GENERIC_VCF_AL = "generic-vcf-al"
CALLER_GENERIC_VCF = "generic-vcf"
CALLER_REPEATHMM = "repeathmm"
CALLER_STRAGLR = "straglr"
CALLER_STRKIT = "strkit"
Expand Down Expand Up @@ -48,7 +48,7 @@
MI_CALLERS = (
CALLER_EXPANSIONHUNTER,
CALLER_GANGSTR,
CALLER_GENERIC_VCF_AL,
CALLER_GENERIC_VCF,
CALLER_LONGTR,
CALLER_REPEATHMM,
CALLER_STRAGLR,
Expand Down
2 changes: 1 addition & 1 deletion strkit/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def _exec_mi(p_args) -> None:
calc_classes: dict[str, Type[BaseCalculator]] = {
c.CALLER_EXPANSIONHUNTER: ExpansionHunterCalculator,
c.CALLER_GANGSTR: GangSTRCalculator,
c.CALLER_GENERIC_VCF_AL: GenericVCFLengthCalculator,
c.CALLER_GENERIC_VCF: GenericVCFLengthCalculator,
c.CALLER_LONGTR: GenericVCFLengthCalculator,
c.CALLER_REPEATHMM: RepeatHMMCalculator,
c.CALLER_STRAGLR: StraglrCalculator,
Expand Down
49 changes: 39 additions & 10 deletions strkit/mi/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import uuid
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Optional
from typing import Any, Optional, Union

from strkit.logger import get_main_logger
from .intervals import (
Expand Down Expand Up @@ -165,11 +165,21 @@ def get_trio_contigs(self, include_sex_chromosomes: bool = False) -> set:
def calculate_contig(self, contig: str) -> MIContigResult:
    """Build and return a new MIContigResult for ``contig``.

    This implementation only constructs the result object; it adds nothing
    to it.
    NOTE(review): presumably caller-specific calculators override this to
    populate the result - confirm against the subclasses.
    """
    contig_result = MIContigResult(contig)
    return contig_result

@staticmethod
def _updated_mi_res(res: Optional[float], v: Union[int, float, None]) -> Optional[float]:
return None if v is None else ((res or 0) + v)

def calculate(self, included_contigs: set) -> Optional[MIResult]:
# copy number
res: float = 0
res_pm1: float = 0
res_95_ci: Optional[float] = None
res_99_ci: Optional[float] = None
# sequence
res_seq: Optional[float] = None
res_sl: Optional[float] = None
res_sl_pm1: Optional[float] = None

n_total: int = 0

contig_results = []
Expand All @@ -180,12 +190,23 @@ def calculate(self, included_contigs: set) -> Optional[MIResult]:

contig_result = self.calculate_contig(contig)
contig_results.append(contig_result)

r, nm = contig_result.process_loci(calculate_non_matching=self.test_to_perform == "none")
value, value_pm1, value_95_ci, value_99_ci = r
res += value
res_pm1 += value_pm1
res_95_ci = None if value_95_ci is None else ((res_95_ci or 0) + value_95_ci)
res_99_ci = None if value_99_ci is None else ((res_99_ci or 0) + value_99_ci)

value_95_ci = r["ci_95"]
value_99_ci = r["ci_99"]
value_seq = r["seq"]
value_sl = r["sl"]
value_sl_pm1 = r["sl_pm1"]

res += r["strict"]
res_pm1 += r["pm1"]
res_95_ci = self._updated_mi_res(res_95_ci, value_95_ci)
res_99_ci = self._updated_mi_res(res_99_ci, value_99_ci)
res_seq = self._updated_mi_res(res_seq, value_seq)
res_sl = self._updated_mi_res(res_sl, value_sl)
res_sl_pm1 = self._updated_mi_res(res_sl_pm1, value_sl_pm1)

n_total += len(contig_result)
output_loci.extend(nm)

Expand All @@ -202,12 +223,20 @@ def calculate(self, included_contigs: set) -> Optional[MIResult]:
res_pm1 /= n_total
res_95_ci = None if res_95_ci is None else (res_95_ci / n_total)
res_99_ci = None if res_99_ci is None else (res_99_ci / n_total)
# Normalise each sequence-based total by n_total; a total that is still None
# stays None. Each guard must test the same variable it divides: the original
# guarded res_sl_pm1 on res_sl, which could divide None or drop a valid value.
res_seq = None if res_seq is None else (res_seq / n_total)
res_sl = None if res_sl is None else (res_sl / n_total)
res_sl_pm1 = None if res_sl_pm1 is None else (res_sl_pm1 / n_total)

mi_res = MIResult(
res,
res_pm1,
res_95_ci,
res_99_ci,
{
"strict": res,
"pm1": res_pm1,
"ci_95": res_95_ci,
"ci_99": res_99_ci,
"seq": res_seq,
"sl": res_sl,
"sl_pm1": res_sl_pm1,
},
contig_results,
output_loci,
self._widen,
Expand Down
18 changes: 8 additions & 10 deletions strkit/mi/generic_vcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def _get_sample_contigs(self) -> tuple[set, set, set]:
return contigs

def calculate_contig(self, contig: str) -> MIContigResult:
cr = MIContigResult(contig)
cr = MIContigResult(contig, includes_seq=True)

mvf = pysam.VariantFile(str(self._mother_call_file))
fvf = pysam.VariantFile(str(self._father_call_file))
Expand Down Expand Up @@ -64,15 +64,12 @@ def calculate_contig(self, contig: str) -> MIContigResult:
ms = mv.samples[self._mother_id or 0]
fs = fv.samples[self._father_id or 0]

c_gt = (
tuple(sorted(round(len(cv.alleles[g]) / motif_len) for g in cs["GT"])) if None not in cs["GT"] else None
)
m_gt = (
tuple(sorted(round(len(mv.alleles[g]) / motif_len) for g in ms["GT"])) if None not in ms["GT"] else None
)
f_gt = (
tuple(sorted(round(len(fv.alleles[g]) / motif_len) for g in fs["GT"])) if None not in fs["GT"] else None
)
c_seq_gt = tuple(sorted((cv.alleles[g] for g in cs["GT"]), key=len)) if None not in cs["GT"] else None
c_gt = tuple(round(len(a) / motif_len) for a in c_seq_gt) if c_seq_gt is not None else None
m_seq_gt = tuple(sorted((mv.alleles[g] for g in ms["GT"]), key=len)) if None not in ms["GT"] else None
m_gt = tuple(round(len(a) / motif_len) for a in m_seq_gt) if m_seq_gt is not None else None
f_seq_gt = tuple(sorted((fv.alleles[g] for g in fs["GT"]), key=len)) if None not in fs["GT"] else None
f_gt = tuple(round(len(a) / motif_len) for a in f_seq_gt) if f_seq_gt is not None else None

if c_gt is None or m_gt is None or f_gt is None:
# None call in VCF, skip this call
Expand All @@ -85,6 +82,7 @@ def calculate_contig(self, contig: str) -> MIContigResult:
motif=motif,

child_gt=c_gt, mother_gt=m_gt, father_gt=f_gt,
child_seq_gt=c_seq_gt, mother_seq_gt=m_seq_gt, father_seq_gt=f_seq_gt,
))

return cr
Loading

0 comments on commit 21d333b

Please sign in to comment.