From 491469f08d494acfb660f5b9069de92617983a03 Mon Sep 17 00:00:00 2001 From: LaraFuhrmann <55209716+LaraFuhrmann@users.noreply.github.com> Date: Wed, 27 Nov 2024 17:32:52 +0100 Subject: [PATCH] add param record_history --- .../use_quality_scores/cavi.py | 63 +++++++++++-------- .../use_quality_scores/run_dpm_mfa.py | 4 +- 2 files changed, 40 insertions(+), 27 deletions(-) diff --git a/viloca/local_haplotype_inference/use_quality_scores/cavi.py b/viloca/local_haplotype_inference/use_quality_scores/cavi.py index 3ff3c58..7ff06a3 100644 --- a/viloca/local_haplotype_inference/use_quality_scores/cavi.py +++ b/viloca/local_haplotype_inference/use_quality_scores/cavi.py @@ -33,7 +33,8 @@ def multistart_cavi( reads_log_error_proba, n_starts, output_dir, - convergence_threshold + convergence_threshold, + record_history ): pool = mp.Pool(mp.cpu_count()) @@ -51,7 +52,8 @@ def multistart_cavi( reads_log_error_proba, start, output_dir, - convergence_threshold + convergence_threshold, + record_history ), callback=collect_result, ) @@ -74,6 +76,7 @@ def run_cavi( start_id, output_dir, convergence_threshold, + record_history, ): """ Runs cavi (coordinate ascent variational inference). @@ -86,12 +89,6 @@ def run_cavi( "alphabet": alphabet, } - history_alpha = [] - history_mean_log_pi = [] - history_mean_log_gamma = [] - history_mean_cluster = [] - history_elbo = [] - state_init_dict = initialization.draw_init_state( n_cluster, alpha0, alphabet, reads_list, reference_binary ) @@ -104,10 +101,11 @@ def run_cavi( } ) - history_alpha = [state_init_dict["alpha"]] - history_mean_log_pi = [state_init_dict["mean_log_pi"]] - history_mean_log_gamma = [state_init_dict["mean_log_gamma"]] - history_mean_cluster = [state_init_dict["mean_cluster"]] + if record_history: + history_alpha = [state_init_dict["alpha"]] + history_mean_log_pi = [state_init_dict["mean_log_pi"]] + history_mean_log_gamma = [state_init_dict["mean_log_gamma"]] + history_mean_cluster = [state_init_dict["mean_cluster"]] history_elbo = [] # Iteratively update mean values @@ -141,11 +139,13 @@ def run_cavi( state_curr_dict, ) - if iter % 2 == 0: + if (iter % 2 == 0) and record_history: history_elbo.append(elbo) history_mean_log_pi.append(state_curr_dict["mean_log_pi"]) history_mean_log_gamma.append(state_curr_dict["mean_log_gamma"]) history_mean_cluster.append(state_curr_dict["mean_cluster"]) + else: + history_elbo.append(elbo) if iter > 1: if np.isnan(elbo): @@ -167,19 +167,30 @@ def run_cavi( state_curr_dict.update({"elbo": elbo}) - dict_result.update( - { - "exit_message": exit_message, - "n_iterations": iter, - "converged": converged, - "elbo": elbo, - "history_elbo": history_elbo, - "history_alpha": history_alpha, - "history_mean_log_pi": history_mean_log_pi, - "history_mean_log_gamma": history_mean_log_gamma, - "history_mean_cluster": history_mean_cluster, - } - ) + if record_history: + dict_result.update( + { + "exit_message": exit_message, + "n_iterations": iter, + "converged": converged, + "elbo": elbo, + "history_elbo": history_elbo, + "history_alpha": history_alpha, + "history_mean_log_pi": history_mean_log_pi, + "history_mean_log_gamma": history_mean_log_gamma, + "history_mean_cluster": history_mean_cluster, + } + ) + else: + dict_result.update( + { + "exit_message": exit_message, + "n_iterations": iter, + "converged": converged, + "elbo": elbo, + "history_elbo": history_elbo, + } + ) dict_result.update(state_curr_dict) result = (state_curr_dict, dict_result) diff --git a/viloca/local_haplotype_inference/use_quality_scores/run_dpm_mfa.py b/viloca/local_haplotype_inference/use_quality_scores/run_dpm_mfa.py index 2d8ccda..8f2b3e1 100644 --- a/viloca/local_haplotype_inference/use_quality_scores/run_dpm_mfa.py +++ b/viloca/local_haplotype_inference/use_quality_scores/run_dpm_mfa.py @@ -66,7 +66,8 @@ def main( reads_log_error_proba, n_starts, output_name, - convergence_threshold + convergence_threshold, + record_history ) else: @@ -83,6 +84,7 @@ def main( 0, output_name, convergence_threshold, + record_history ) ]