Skip to content

Commit

Permalink
add param record_history
Browse files Browse the repository at this point in the history
  • Loading branch information
LaraFuhrmann committed Nov 27, 2024
1 parent 739eeab commit 491469f
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 27 deletions.
63 changes: 37 additions & 26 deletions viloca/local_haplotype_inference/use_quality_scores/cavi.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def multistart_cavi(
reads_log_error_proba,
n_starts,
output_dir,
convergence_threshold
convergence_threshold,
record_history
):

pool = mp.Pool(mp.cpu_count())
Expand All @@ -51,7 +52,8 @@ def multistart_cavi(
reads_log_error_proba,
start,
output_dir,
convergence_threshold
convergence_threshold,
record_history
),
callback=collect_result,
)
Expand All @@ -74,6 +76,7 @@ def run_cavi(
start_id,
output_dir,
convergence_threshold,
record_history,
):
"""
Runs cavi (coordinate ascent variational inference).
Expand All @@ -86,12 +89,6 @@ def run_cavi(
"alphabet": alphabet,
}

history_alpha = []
history_mean_log_pi = []
history_mean_log_gamma = []
history_mean_cluster = []
history_elbo = []

state_init_dict = initialization.draw_init_state(
n_cluster, alpha0, alphabet, reads_list, reference_binary
)
Expand All @@ -104,10 +101,11 @@ def run_cavi(
}
)

history_alpha = [state_init_dict["alpha"]]
history_mean_log_pi = [state_init_dict["mean_log_pi"]]
history_mean_log_gamma = [state_init_dict["mean_log_gamma"]]
history_mean_cluster = [state_init_dict["mean_cluster"]]
if record_history:
history_alpha = [state_init_dict["alpha"]]
history_mean_log_pi = [state_init_dict["mean_log_pi"]]
history_mean_log_gamma = [state_init_dict["mean_log_gamma"]]
history_mean_cluster = [state_init_dict["mean_cluster"]]
history_elbo = []

# Iteratively update mean values
Expand Down Expand Up @@ -141,11 +139,13 @@ def run_cavi(
state_curr_dict,
)

if iter % 2 == 0:
if (iter % 2 == 0) and record_history:
history_elbo.append(elbo)
history_mean_log_pi.append(state_curr_dict["mean_log_pi"])
history_mean_log_gamma.append(state_curr_dict["mean_log_gamma"])
history_mean_cluster.append(state_curr_dict["mean_cluster"])
else:
history_elbo.append(elbo)

if iter > 1:
if np.isnan(elbo):
Expand All @@ -167,19 +167,30 @@ def run_cavi(

state_curr_dict.update({"elbo": elbo})

dict_result.update(
{
"exit_message": exit_message,
"n_iterations": iter,
"converged": converged,
"elbo": elbo,
"history_elbo": history_elbo,
"history_alpha": history_alpha,
"history_mean_log_pi": history_mean_log_pi,
"history_mean_log_gamma": history_mean_log_gamma,
"history_mean_cluster": history_mean_cluster,
}
)
if record_history:
dict_result.update(
{
"exit_message": exit_message,
"n_iterations": iter,
"converged": converged,
"elbo": elbo,
"history_elbo": history_elbo,
"history_alpha": history_alpha,
"history_mean_log_pi": history_mean_log_pi,
"history_mean_log_gamma": history_mean_log_gamma,
"history_mean_cluster": history_mean_cluster,
}
)
else:
dict_result.update(
{
"exit_message": exit_message,
"n_iterations": iter,
"converged": converged,
"elbo": elbo,
"history_elbo": history_elbo,
}
)
dict_result.update(state_curr_dict)

result = (state_curr_dict, dict_result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def main(
reads_log_error_proba,
n_starts,
output_name,
convergence_threshold
convergence_threshold,
record_history
)

else:
Expand All @@ -83,6 +84,7 @@ def main(
0,
output_name,
convergence_threshold,
record_history
)
]

Expand Down

0 comments on commit 491469f

Please sign in to comment.