diff --git a/.gitignore b/.gitignore index 31dead6f..2bc75593 100644 --- a/.gitignore +++ b/.gitignore @@ -148,3 +148,4 @@ dmypy.json # Pyre type checker .pyre/ +src/dnadiffusion/models/unet_attention_induction.py diff --git a/src/dnadiffusion/models/diffusion.py b/src/dnadiffusion/models/diffusion.py index 9be47ff3..a4d0f4ae 100644 --- a/src/dnadiffusion/models/diffusion.py +++ b/src/dnadiffusion/models/diffusion.py @@ -47,6 +47,15 @@ def sample(self, classes, shape, cond_weight): cond_weight=cond_weight, ) + @torch.no_grad() + def sample_cross(self, classes, shape, cond_weight): + return self.p_sample_loop( + classes=classes, + image_size=shape, + cond_weight=cond_weight, + get_cross_map=True, + ) + @torch.no_grad() def p_sample_loop(self, classes, image_size, cond_weight, get_cross_map=False): b = image_size[0] diff --git a/src/dnadiffusion/utils/sample_util.py b/src/dnadiffusion/utils/sample_util.py index 3a858161..c13a5325 100644 --- a/src/dnadiffusion/utils/sample_util.py +++ b/src/dnadiffusion/utils/sample_util.py @@ -19,7 +19,9 @@ def create_sample( cond_weight_to_metric: int = 0, save_timesteps: bool = False, save_dataframe: bool = False, + generate_attention_maps: bool = False, ): + print("sample_util") nucleotides = ["A", "C", "G", "T"] final_sequences = [] for n_a in tqdm(range(number_of_samples)): @@ -30,7 +32,16 @@ def create_sample( sampled = torch.from_numpy(np.random.choice(cell_types, sample_bs)) classes = sampled.float().to(diffusion_model.device) - sampled_images = diffusion_model.sample(classes, (sample_bs, 1, 4, 200), cond_weight_to_metric) + + if generate_attention_maps: + sampled_images, cross_att_values = diffusion_model.sample_cross( + classes, (sample_bs, 1, 4, 200), cond_weight_to_metric + ) + # save cross attention maps in a numpy array + np.save(f"cross_att_values_{conditional_numeric_to_tag[group_number]}.npy", cross_att_values) + + else: + sampled_images = diffusion_model.sample(classes, (sample_bs, 1, 4, 200), cond_weight_to_metric) if save_timesteps: seqs_to_df = {}