From 01487b0c1fe498c10a248c3ef416904d24a1ad1e Mon Sep 17 00:00:00 2001
From: LucasSilvaFerreira
Date: Fri, 8 Sep 2023 13:52:17 -0400
Subject: [PATCH 1/2] adding cross attention

---
 .gitignore                            |  1 +
 src/dnadiffusion/.DS_Store            | Bin 0 -> 6148 bytes
 src/dnadiffusion/models/diffusion.py  | 10 ++++++++++
 src/dnadiffusion/utils/sample_util.py | 20 +++++++++++++++++++-
 4 files changed, 30 insertions(+), 1 deletion(-)
 create mode 100644 src/dnadiffusion/.DS_Store

diff --git a/.gitignore b/.gitignore
index 0bb21117..57d8d33c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -144,3 +144,4 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+src/dnadiffusion/models/unet_attention_induction.py
diff --git a/src/dnadiffusion/.DS_Store b/src/dnadiffusion/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..26c154e2727ff8c63d16b48d252a6395506fc4d3
GIT binary patch
literal 6148
[base85-encoded .DS_Store binary data omitted]

literal 0
HcmV?d00001

diff --git a/src/dnadiffusion/models/diffusion.py b/src/dnadiffusion/models/diffusion.py
index 9be47ff3..a8786a81 100644
--- a/src/dnadiffusion/models/diffusion.py
+++ b/src/dnadiffusion/models/diffusion.py
@@ -46,6 +46,16 @@ def sample(self, classes, shape, cond_weight):
             image_size=shape,
             cond_weight=cond_weight,
         )
+
+
+    @torch.no_grad()
+    def sample_cross(self, classes, shape, cond_weight):
+        return self.p_sample_loop(
+            classes=classes,
+            image_size=shape,
+            cond_weight=cond_weight,
+            get_cross_map=True,
+        )
 
     @torch.no_grad()
     def p_sample_loop(self, classes, image_size, cond_weight, get_cross_map=False):
diff --git a/src/dnadiffusion/utils/sample_util.py b/src/dnadiffusion/utils/sample_util.py
index 4a1718ca..9832c19c 100644
--- a/src/dnadiffusion/utils/sample_util.py
+++ b/src/dnadiffusion/utils/sample_util.py
@@ -18,9 +18,14 @@ def create_sample(
     cond_weight_to_metric: int = 0,
     save_timesteps: bool = False,
     save_dataframe: bool = False,
+    generate_attention_maps: bool = False,
 ):
+
+
+    print ('sample_util')
     nucleotides = ["A", "C", "G", "T"]
     final_sequences = []
+    cross_maps = []
     for n_a in range(number_of_samples):
         print(n_a)
         sample_bs = 10
@@ -30,8 +35,21 @@ def create_sample(
         sampled = torch.from_numpy(np.random.choice(cell_types, sample_bs))
         classes = sampled.float().to(diffusion_model.device)
 
-        sampled_images = diffusion_model.sample(classes, (sample_bs, 1, 4, 200), cond_weight_to_metric)
+
+        if generate_attention_maps:
+            sampled_images,cross_att_values = diffusion_model.sample_cross(classes, (sample_bs, 1, 4, 200), cond_weight_to_metric)
+            #save cross attention maps in a numpy array
+            np.save(f"cross_att_values_{conditional_numeric_to_tag[group_number]}.npy",cross_att_values)
+
+        else:
+            sampled_images = diffusion_model.sample(classes, (sample_bs, 1, 4, 200), cond_weight_to_metric)
+
+
+
+
+
+
 
         if save_timesteps:
             seqs_to_df = {}
             for en, step in enumerate(sampled_images):

From 21187e9eff8f1f65a2348d0c8fb8236275adc3bc Mon Sep 17 00:00:00 2001
From: Simon
Date: Fri, 8 Sep 2023 16:50:20 -0400
Subject: [PATCH 2/2] fix formatting after latest PR (#164)

---
 src/dnadiffusion/.DS_Store            | Bin 6148 -> 0 bytes
 src/dnadiffusion/models/diffusion.py  |  1 -
 src/dnadiffusion/utils/sample_util.py | 20 +++++++-------------
 3 files changed, 7 insertions(+), 14 deletions(-)
 delete mode 100644 src/dnadiffusion/.DS_Store

diff --git a/src/dnadiffusion/.DS_Store b/src/dnadiffusion/.DS_Store
deleted file mode 100644
index 26c154e2727ff8c63d16b48d252a6395506fc4d3..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 6148
[base85-encoded .DS_Store binary data omitted]

diff --git a/src/dnadiffusion/models/diffusion.py b/src/dnadiffusion/models/diffusion.py
index a8786a81..a4d0f4ae 100644
--- a/src/dnadiffusion/models/diffusion.py
+++ b/src/dnadiffusion/models/diffusion.py
@@ -46,7 +46,6 @@ def sample(self, classes, shape, cond_weight):
             image_size=shape,
             cond_weight=cond_weight,
         )
-
 
     @torch.no_grad()
     def sample_cross(self, classes, shape, cond_weight):
diff --git a/src/dnadiffusion/utils/sample_util.py b/src/dnadiffusion/utils/sample_util.py
index 9832c19c..f5cc2c68 100644
--- a/src/dnadiffusion/utils/sample_util.py
+++ b/src/dnadiffusion/utils/sample_util.py
@@ -20,9 +20,7 @@ def create_sample(
     save_dataframe: bool = False,
     generate_attention_maps: bool = False,
 ):
-
-
-    print ('sample_util')
+    print("sample_util")
     nucleotides = ["A", "C", "G", "T"]
     final_sequences = []
     cross_maps = []
@@ -35,21 +33,17 @@ def create_sample(
         sampled = torch.from_numpy(np.random.choice(cell_types, sample_bs))
         classes = sampled.float().to(diffusion_model.device)
 
-
+
         if generate_attention_maps:
-            sampled_images,cross_att_values = diffusion_model.sample_cross(classes, (sample_bs, 1, 4, 200), cond_weight_to_metric)
-            #save cross attention maps in a numpy array
-            np.save(f"cross_att_values_{conditional_numeric_to_tag[group_number]}.npy",cross_att_values)
+            sampled_images, cross_att_values = diffusion_model.sample_cross(
+                classes, (sample_bs, 1, 4, 200), cond_weight_to_metric
+            )
+            # save cross attention maps in a numpy array
+            np.save(f"cross_att_values_{conditional_numeric_to_tag[group_number]}.npy", cross_att_values)
-
         else:
             sampled_images = diffusion_model.sample(classes, (sample_bs, 1, 4, 200), cond_weight_to_metric)
-
-
-
-
-
 
 
         if save_timesteps:
             seqs_to_df = {}
             for en, step in enumerate(sampled_images):
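A minimal usage sketch for the generate_attention_maps flag introduced by this patch series, assuming create_sample keeps the keyword arguments referenced in the hunks above and that the conditional tag resolves to a cell-type string such as "GM12878" (both are assumptions; the patches do not show the full signature or the tag values):

import numpy as np

# Sketch only: the call below is commented out because every argument except
# generate_attention_maps and cond_weight_to_metric is an assumed name; the
# hunks above do not show create_sample's full signature.
#
# from dnadiffusion.utils.sample_util import create_sample
# create_sample(
#     diffusion_model,               # assumed: a trained Diffusion instance
#     cell_types=cell_types,         # assumed parameter name
#     number_of_samples=1,           # assumed parameter name
#     cond_weight_to_metric=0,
#     generate_attention_maps=True,  # new flag added by this patch series
# )

# With the flag set, sample_util.py calls np.save() once per conditional tag;
# "GM12878" is a hypothetical tag value used only for illustration.
cross_att_values = np.load("cross_att_values_GM12878.npy")
print(cross_att_values.shape)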