diff --git a/.gitignore b/.gitignore
index 0bb21117..57d8d33c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -144,3 +144,4 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+src/dnadiffusion/models/unet_attention_induction.py
diff --git a/src/dnadiffusion/.DS_Store b/src/dnadiffusion/.DS_Store
new file mode 100644
index 00000000..26c154e2
Binary files /dev/null and b/src/dnadiffusion/.DS_Store differ
diff --git a/src/dnadiffusion/models/diffusion.py b/src/dnadiffusion/models/diffusion.py
index 9be47ff3..a8786a81 100644
--- a/src/dnadiffusion/models/diffusion.py
+++ b/src/dnadiffusion/models/diffusion.py
@@ -46,6 +46,16 @@ def sample(self, classes, shape, cond_weight):
             image_size=shape,
             cond_weight=cond_weight,
         )
+
+
+    @torch.no_grad()
+    def sample_cross(self, classes, shape, cond_weight):
+        return self.p_sample_loop(
+            classes=classes,
+            image_size=shape,
+            cond_weight=cond_weight,
+            get_cross_map=True,
+        )
 
     @torch.no_grad()
     def p_sample_loop(self, classes, image_size, cond_weight, get_cross_map=False):
diff --git a/src/dnadiffusion/utils/sample_util.py b/src/dnadiffusion/utils/sample_util.py
index 4a1718ca..9832c19c 100644
--- a/src/dnadiffusion/utils/sample_util.py
+++ b/src/dnadiffusion/utils/sample_util.py
@@ -18,9 +18,14 @@ def create_sample(
     cond_weight_to_metric: int = 0,
     save_timesteps: bool = False,
     save_dataframe: bool = False,
+    generate_attention_maps: bool = False,
 ):
+
+
+    print("sample_util")
     nucleotides = ["A", "C", "G", "T"]
     final_sequences = []
+    cross_maps = []
     for n_a in range(number_of_samples):
         print(n_a)
         sample_bs = 10
@@ -30,8 +35,21 @@
         sampled = torch.from_numpy(np.random.choice(cell_types, sample_bs))
         classes = sampled.float().to(diffusion_model.device)
-        sampled_images = diffusion_model.sample(classes, (sample_bs, 1, 4, 200), cond_weight_to_metric)
+
+        if generate_attention_maps:
+            sampled_images, cross_att_values = diffusion_model.sample_cross(classes, (sample_bs, 1, 4, 200), cond_weight_to_metric)
+            # save cross attention maps in a
+            # numpy array
+            np.save(f"cross_att_values_{conditional_numeric_to_tag[group_number]}.npy", cross_att_values)
+
+        else:
+            sampled_images = diffusion_model.sample(classes, (sample_bs, 1, 4, 200), cond_weight_to_metric)
+
+
+
+
+
         if save_timesteps:
             seqs_to_df = {}
             for en, step in enumerate(sampled_images):