diff --git a/strkit/call/repeats.py b/strkit/call/repeats.py index 1bc1d20..8466a83 100644 --- a/strkit/call/repeats.py +++ b/strkit/call/repeats.py @@ -1,10 +1,8 @@ -import math import parasail from functools import lru_cache -from typing import Literal, Union +from typing import Literal, Optional, Union -from strkit.utils import sign from .align_matrix import dna_matrix, indel_penalty, match_score from .utils import idx_1_getter @@ -76,47 +74,60 @@ def get_repeat_count( score_diff = abs(start_score - max_init_score) / max_init_score - if score_diff == 0: - return (start_count, start_score), 1 - elif score_diff < 0.05: # TODO: parametrize + if score_diff < 0.05: # TODO: parametrize # If we're very close to the maximum, explore less. local_search_range = 1 elif score_diff < 0.1: local_search_range = 2 - sizes_and_scores: dict[int, int] = {start_count: start_score} - n_scores: int = 1 + explored_sizes: set[int] = {start_count} + best_size: int = start_count + best_score: int = start_score + n_explored: int = 1 to_explore: list[tuple[int, Literal[-1, 1]]] = [(start_count - 1, -1), (start_count + 1, 1)] - while to_explore and n_scores < max_iters: + while to_explore and n_explored < max_iters: size_to_explore, direction = to_explore.pop() if size_to_explore < 0: continue - szs: list[tuple[int, int]] = [] + best_size_this_round: Optional[int] = None + best_score_this_round: int = -99999999999 start_size = max(size_to_explore - (local_search_range if direction == -1 else 0), 0) end_size = size_to_explore + (local_search_range if direction == 1 else 0) for i in range(start_size, end_size + 1): - if i not in sizes_and_scores: + if i not in explored_sizes: # Generate a candidate TR tract by copying the provided motif 'i' times & score it # Separate this from the .get() to postpone computation to until we need it - sizes_and_scores[i] = score_candidate(db_seq_profile, motif, i, flank_left_seq, flank_right_seq) - n_scores += 1 + explored_sizes.add(i) + i_score = score_candidate(db_seq_profile, motif, i, flank_left_seq, flank_right_seq) - szs.append((i, sizes_and_scores[i])) + if best_size_this_round is None or i_score > best_score_this_round: + best_size_this_round = i + best_score_this_round = i_score - mv: tuple[int, int] = max(szs, key=idx_1_getter) - if mv[0] > size_to_explore and (new_rc := mv[0] + 1) not in sizes_and_scores: - if new_rc >= 0: - to_explore.append((new_rc, 1)) - elif mv[0] < size_to_explore and (new_rc := mv[0] - 1) not in sizes_and_scores: - if new_rc >= 0: - to_explore.append((new_rc, -1)) + n_explored += 1 - # noinspection PyTypeChecker - return max(sizes_and_scores.items(), key=idx_1_getter), len(sizes_and_scores) + if best_size_this_round: + if best_size_this_round > size_to_explore and (new_rc := best_size_this_round + 1) not in explored_sizes: + if new_rc >= 0: + to_explore.append((new_rc, 1)) + elif best_size_this_round < size_to_explore and (new_rc := best_size_this_round - 1) not in explored_sizes: + if new_rc >= 0: + to_explore.append((new_rc, -1)) + + # If this round is the best we've got so far, update the record size/score for the final return + if best_score_this_round > best_score: + best_size = best_size_this_round + best_score = best_score_this_round + + if local_search_range > 1 and abs(best_score - max_init_score) / max_init_score < 0.05: + # reduce search range as we approach an optimum + local_search_range = 1 + + return (best_size, best_score), n_explored def get_ref_repeat_count( @@ -191,12 +202,14 @@ def get_ref_repeat_count( # Ignore negative differences (contractions vs TRF definition), but follow expansions # TODO: Should we incorporate contractions? How would that work? - l_offset = sign(rev_top_res[1][1]) * math.floor(abs(rev_top_res[1][1]) / motif_size) * motif_size - r_offset = sign(fwd_top_res[1][1]) * math.floor(abs(fwd_top_res[1][1]) / motif_size) * motif_size + l_offset = rev_top_res[1][1] + r_offset = fwd_top_res[1][1] if l_offset > 0: - flank_left_seq = flank_left_seq[:-1*l_offset] + tr_seq = flank_left_seq[-1*l_offset:] + tr_seq # first, move a chunk of the left flank to the TR seq + flank_left_seq = flank_left_seq[:-1*l_offset] # then, remove that chunk from the left flank if r_offset > 0: + tr_seq = tr_seq + flank_right_seq[:r_offset] # same, but for the right flank flank_right_seq = flank_right_seq[r_offset:] # ------------------------------------------------------------------------------------------------------------------