From 9e6a0bfa42f1b8157f301fe200b39e4dcac4d169 Mon Sep 17 00:00:00 2001 From: David Lougheed Date: Wed, 12 Jun 2024 16:47:36 -0400 Subject: [PATCH] refact(call): move call_alleles_with_gmm procedure to function --- strkit/call/call_locus.py | 63 ++++++++++++++++++++++++--------------- 1 file changed, 39 insertions(+), 24 deletions(-) diff --git a/strkit/call/call_locus.py b/strkit/call/call_locus.py index c3654c2..ca51ef8 100644 --- a/strkit/call/call_locus.py +++ b/strkit/call/call_locus.py @@ -240,12 +240,48 @@ def calculate_read_distance( return distance_matrix +def call_alleles_with_gmm( + params: CallParams, + n_alleles: int, + read_dict: dict[str, ReadDict], + assign_method: str, + # --- + rng: np.random.Generator, + # --- + logger_: logging.Logger, + locus_log_str: str, +) -> CallDict | dict: + # Dicts are ordered in Python; very nice :) + rdvs = tuple(read_dict.values()) + read_cns = np.fromiter(map(cn_getter, rdvs), dtype=np.int_) + read_weights = np.fromiter(map(weight_getter, rdvs), dtype=np.float_) + read_weights /= read_weights.sum() # Normalize to probabilities + + logger_.debug(f"{locus_log_str} - assigning alleles using {assign_method} method with {read_cns.shape[0]} reads") + + return call_alleles( + read_cns, (), + read_weights, (), + params=params, + min_reads=params.min_reads, + n_alleles=n_alleles, + separate_strands=False, + read_bias_corr_min=0, # TODO: parametrize + gm_filter_factor=3, # TODO: parametrize + seed=get_new_seed(rng), + logger_=logger_, + debug_str=locus_log_str, + ) or {} # Still false-y + + def call_alleles_with_haplotags( params: CallParams, haplotags: list[int], ps_id: int, read_dict_items: tuple[tuple[str, ReadDict], ...], # We could derive this again, but we already have before... + # --- rng: np.random.Generator, + # --- logger_: logging.Logger, locus_log_str: str, ) -> Optional[dict]: @@ -266,7 +302,7 @@ def call_alleles_with_haplotags( # Calculate weights array ws = np.fromiter(map(weight_getter, crs), dtype=np.float_) - c_ws.append(ws / np.sum(ws)) + c_ws.append(ws / ws.sum()) hp_reads.append(crs) @@ -710,7 +746,7 @@ def call_alleles_with_incorporated_snvs( # TODO: Readjust peak weights when combining or don't include # Make peak weights sum to 1 - "peak_weights": peak_weights_pre_adj / np.sum(peak_weights_pre_adj), + "peak_weights": peak_weights_pre_adj / peak_weights_pre_adj.sum(), "peak_stdevs": np.concatenate(tuple(cc["peak_stdevs"] for cc in cdd_ordered), axis=0), "modal_n_peaks": n_alleles, # n. of alleles = n. of peaks always -- if we phased using SNVs @@ -1382,28 +1418,7 @@ def call_locus( single_or_dist_assign: bool = assign_method in ("single", "dist") if single_or_dist_assign: # Didn't use SNVs, so call the 'old-fashioned' way - using only copy number - # Dicts are ordered in Python; very nice :) - rdvs = tuple(read_dict.values()) - rcns = tuple(map(cn_getter, rdvs)) - read_cns = np.fromiter(rcns, dtype=np.int_) - read_weights = np.fromiter(map(weight_getter, rdvs), dtype=np.float_) - read_weights = read_weights / np.sum(read_weights) # Normalize to probabilities - - logger_.debug(f"{locus_log_str} - assigning alleles using {assign_method} method with {len(rcns)} reads") - - call_data = call_alleles( - read_cns, (), - read_weights, (), - params=params, - min_reads=params.min_reads, - n_alleles=n_alleles, - separate_strands=False, - read_bias_corr_min=0, # TODO: parametrize - gm_filter_factor=3, # TODO: parametrize - seed=get_new_seed(rng), - logger_=logger_, - debug_str=locus_log_str, - ) or {} # Still false-y + call_data = call_alleles_with_gmm(params, n_alleles, read_dict, assign_method, rng, logger_, locus_log_str) allele_time = (datetime.now() - allele_start_time).total_seconds()