Skip to content

Commit

Permalink
resolve merge conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
ssenan committed Sep 11, 2023
2 parents 079e1d2 + 21187e9 commit 82d8c3b
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,4 @@ dmypy.json

# Pyre type checker
.pyre/
src/dnadiffusion/models/unet_attention_induction.py
9 changes: 9 additions & 0 deletions src/dnadiffusion/models/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ def sample(self, classes, shape, cond_weight):
cond_weight=cond_weight,
)

@torch.no_grad()
def sample_cross(self, classes, shape, cond_weight):
return self.p_sample_loop(
classes=classes,
image_size=shape,
cond_weight=cond_weight,
get_cross_map=True,
)

@torch.no_grad()
def p_sample_loop(self, classes, image_size, cond_weight, get_cross_map=False):
b = image_size[0]
Expand Down
13 changes: 12 additions & 1 deletion src/dnadiffusion/utils/sample_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def create_sample(
cond_weight_to_metric: int = 0,
save_timesteps: bool = False,
save_dataframe: bool = False,
generate_attention_maps: bool = False,
):
print("sample_util")
nucleotides = ["A", "C", "G", "T"]
final_sequences = []
for n_a in tqdm(range(number_of_samples)):
Expand All @@ -30,7 +32,16 @@ def create_sample(
sampled = torch.from_numpy(np.random.choice(cell_types, sample_bs))

classes = sampled.float().to(diffusion_model.device)
sampled_images = diffusion_model.sample(classes, (sample_bs, 1, 4, 200), cond_weight_to_metric)

if generate_attention_maps:
sampled_images, cross_att_values = diffusion_model.sample_cross(
classes, (sample_bs, 1, 4, 200), cond_weight_to_metric
)
# save cross attention maps in a numpy array
np.save(f"cross_att_values_{conditional_numeric_to_tag[group_number]}.npy", cross_att_values)

else:
sampled_images = diffusion_model.sample(classes, (sample_bs, 1, 4, 200), cond_weight_to_metric)

if save_timesteps:
seqs_to_df = {}
Expand Down

0 comments on commit 82d8c3b

Please sign in to comment.