From 06f239b5bde630ea5a18b968b70f63cc046a175b Mon Sep 17 00:00:00 2001 From: LaraFuhrmann <55209716+LaraFuhrmann@users.noreply.github.com> Date: Thu, 28 Nov 2024 15:06:47 +0100 Subject: [PATCH] unique modus and save less for learn_error_params --- .../learn_error_params/cavi.py | 10 +++++----- .../learn_error_params/run_dpm_mfa.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/viloca/local_haplotype_inference/learn_error_params/cavi.py b/viloca/local_haplotype_inference/learn_error_params/cavi.py index a88036a..26d4e5b 100644 --- a/viloca/local_haplotype_inference/learn_error_params/cavi.py +++ b/viloca/local_haplotype_inference/learn_error_params/cavi.py @@ -130,8 +130,8 @@ def run_cavi( converged = False elbo = 0 state_curr_dict = state_init_dict - k = 0 - while converged is False: + min_number_iterations = 10 + while (converged is False) or (iter < min_number_iterations): if iter <= 1: digamma_alpha_sum = digamma(state_curr_dict["alpha"].sum(axis=0)) @@ -160,7 +160,7 @@ def run_cavi( state_init_dict, state_curr_dict, ) - + if iter % 2 == 0: history_elbo.append(elbo) history_mean_log_pi.append(state_curr_dict["mean_log_pi"]) @@ -184,11 +184,11 @@ def run_cavi( break elif np.abs(elbo - history_elbo[-2]) < 1e-03: converged = True - k += 1 + iter += 1 message = "ELBO converged." exitflag = 0 else: - k = 0 + iter = 0 # if k%10==0: # every 10th parameter set is saved to history state_curr_dict.update({"elbo": elbo}) diff --git a/viloca/local_haplotype_inference/learn_error_params/run_dpm_mfa.py b/viloca/local_haplotype_inference/learn_error_params/run_dpm_mfa.py index ffdb8a9..dc47248 100644 --- a/viloca/local_haplotype_inference/learn_error_params/run_dpm_mfa.py +++ b/viloca/local_haplotype_inference/learn_error_params/run_dpm_mfa.py @@ -29,7 +29,7 @@ def gzip_file(f_name): return f_out.name -def main(freads_in, fref_in, output_dir, n_starts, K, alpha0, alphabet="ACGT-", unique_modus=False): +def main(freads_in, fref_in, output_dir, n_starts, K, alpha0, alphabet="ACGT-", unique_modus=True): window_id = freads_in.split("/")[-1][:-4] # freads_in is absolute path