Skip to content

Commit

Permalink
adding cross attention
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasSilvaFerreira committed Sep 8, 2023
1 parent d7cb150 commit 01487b0
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,4 @@ dmypy.json

# Pyre type checker
.pyre/
src/dnadiffusion/models/unet_attention_induction.py
Binary file added src/dnadiffusion/.DS_Store
Binary file not shown.
10 changes: 10 additions & 0 deletions src/dnadiffusion/models/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ def sample(self, classes, shape, cond_weight):
image_size=shape,
cond_weight=cond_weight,
)


@torch.no_grad()
def sample_cross(self, classes, shape, cond_weight):
return self.p_sample_loop(
classes=classes,
image_size=shape,
cond_weight=cond_weight,
get_cross_map=True,
)

@torch.no_grad()
def p_sample_loop(self, classes, image_size, cond_weight, get_cross_map=False):
Expand Down
20 changes: 19 additions & 1 deletion src/dnadiffusion/utils/sample_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,14 @@ def create_sample(
cond_weight_to_metric: int = 0,
save_timesteps: bool = False,
save_dataframe: bool = False,
generate_attention_maps: bool = False,
):


print ('sample_util')
nucleotides = ["A", "C", "G", "T"]
final_sequences = []
cross_maps = []
for n_a in range(number_of_samples):
print(n_a)
sample_bs = 10
Expand All @@ -30,8 +35,21 @@ def create_sample(
sampled = torch.from_numpy(np.random.choice(cell_types, sample_bs))

classes = sampled.float().to(diffusion_model.device)
sampled_images = diffusion_model.sample(classes, (sample_bs, 1, 4, 200), cond_weight_to_metric)

if generate_attention_maps:
sampled_images,cross_att_values = diffusion_model.sample_cross(classes, (sample_bs, 1, 4, 200), cond_weight_to_metric)
#save cross attention maps in a numpy array
np.save(f"cross_att_values_{conditional_numeric_to_tag[group_number]}.npy",cross_att_values)


else:
sampled_images = diffusion_model.sample(classes, (sample_bs, 1, 4, 200), cond_weight_to_metric)






if save_timesteps:
seqs_to_df = {}
for en, step in enumerate(sampled_images):
Expand Down

0 comments on commit 01487b0

Please sign in to comment.