From d7cb150575ba91bbfb6a5483e6e92ed15ddd7b5d Mon Sep 17 00:00:00 2001 From: Simon Date: Fri, 1 Sep 2023 14:53:43 -0400 Subject: [PATCH] fix num samples outputted equal to user input (#162) --- src/dnadiffusion/utils/sample_util.py | 36 +++++++++++++-------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/dnadiffusion/utils/sample_util.py b/src/dnadiffusion/utils/sample_util.py index 003c51cd..4a1718ca 100644 --- a/src/dnadiffusion/utils/sample_util.py +++ b/src/dnadiffusion/utils/sample_util.py @@ -38,26 +38,10 @@ def create_sample( seqs_to_df[en] = [convert_to_seq(x, nucleotides) for x in step] final_sequences.append(pd.DataFrame(seqs_to_df)) - # Saving dataframe containing sequences for each timestep - pd.concat(final_sequences, ignore_index=True).to_csv( - f"final_{conditional_numeric_to_tag[group_number]}.txt", - header=True, - sep="\t", - index=False, - ) - return - - elif save_dataframe: - seqs_list = [] + if save_dataframe: # Only using the last timestep for en, step in enumerate(sampled_images[-1]): - seqs_list.append(convert_to_seq(step, nucleotides)) - - # saving list of sequences to txt file - with open(f"final_{conditional_numeric_to_tag[group_number]}.txt", "w") as f: - f.write("\n".join(seqs_list)) - return - + final_sequences.append(convert_to_seq(step, nucleotides)) else: for n_b, x in enumerate(sampled_images[-1]): seq_final = f">seq_test_{n_a}_{n_b}\n" + "".join( @@ -65,6 +49,22 @@ def create_sample( ) final_sequences.append(seq_final) + if save_timesteps: + # Saving dataframe containing sequences for each timestep + pd.concat(final_sequences, ignore_index=True).to_csv( + f"final_{conditional_numeric_to_tag[group_number]}.txt", + header=True, + sep="\t", + index=False, + ) + return + + if save_dataframe: + # Saving list of sequences to txt file + with open(f"final_{conditional_numeric_to_tag[group_number]}.txt", "w") as f: + f.write("\n".join(final_sequences)) + return + save_motifs_syn = open("synthetic_motifs.fasta", "w") save_motifs_syn.write("\n".join(final_sequences)) save_motifs_syn.close()