Skip to content

Commit

Permalink
Fix number of output samples to match user input (#162)
Browse files — browse the repository at this point in the history
  • Loading branch information
ssenan authored Sep 1, 2023
1 parent 4c31977 commit d7cb150
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions src/dnadiffusion/utils/sample_util.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -38,33 +38,33 @@ def create_sample(
seqs_to_df[en] = [convert_to_seq(x, nucleotides) for x in step]
final_sequences.append(pd.DataFrame(seqs_to_df))

# Saving dataframe containing sequences for each timestep
pd.concat(final_sequences, ignore_index=True).to_csv(
f"final_{conditional_numeric_to_tag[group_number]}.txt",
header=True,
sep="\t",
index=False,
)
return

elif save_dataframe:
seqs_list = []
if save_dataframe:
# Only using the last timestep
for en, step in enumerate(sampled_images[-1]):
seqs_list.append(convert_to_seq(step, nucleotides))

# saving list of sequences to txt file
with open(f"final_{conditional_numeric_to_tag[group_number]}.txt", "w") as f:
f.write("\n".join(seqs_list))
return

final_sequences.append(convert_to_seq(step, nucleotides))
else:
for n_b, x in enumerate(sampled_images[-1]):
seq_final = f">seq_test_{n_a}_{n_b}\n" + "".join(
[nucleotides[s] for s in np.argmax(x.reshape(4, 200), axis=0)]
)
final_sequences.append(seq_final)

if save_timesteps:
# Saving dataframe containing sequences for each timestep
pd.concat(final_sequences, ignore_index=True).to_csv(
f"final_{conditional_numeric_to_tag[group_number]}.txt",
header=True,
sep="\t",
index=False,
)
return

if save_dataframe:
# Saving list of sequences to txt file
with open(f"final_{conditional_numeric_to_tag[group_number]}.txt", "w") as f:
f.write("\n".join(final_sequences))
return

save_motifs_syn = open("synthetic_motifs.fasta", "w")
save_motifs_syn.write("\n".join(final_sequences))
save_motifs_syn.close()
Expand Down

0 comments on commit d7cb150

Please sign in to comment.