fix cross_attention, removed unused_parameters flag (#236)
* fix cross_attention, removed unused_parameters flag

* remove kwargs handler
ssenan authored Oct 13, 2023
1 parent 02e22b4 commit 61fadc1
Showing 2 changed files with 2 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/dnadiffusion/models/unet.py
@@ -152,7 +152,7 @@ def forward(self, x: torch.Tensor, time: torch.Tensor, classes: torch.Tensor):
             self.norm_to_cross(x_reshaped.reshape(-1, 800)).reshape(-1, 4, 200),
             context=t_cross_reshaped,
         )  # (-1, 1, 4, 200)
-        crossattention_out = x.view(-1, 1, 4, 200)
+        crossattention_out = crossattention_out.view(-1, 1, 4, 200)
         x = x + crossattention_out
         if self.output_attention:
             return x, crossattention_out
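The substance of this hunk: the removed line overwrote the cross-attention result with a reshaped copy of x, so the residual addition `x = x + crossattention_out` effectively added x to itself and the attention output was discarded. A minimal runnable sketch of the corrected pattern follows; the attention module, batch size, and context shape are stand-ins, not the project's own cross-attention layer — only the reshape-and-residual step mirrors the hunk above.

    import torch

    # Stand-in attention module; dnadiffusion's actual layer is assumed to
    # return a tensor that can be viewed as (-1, 1, 4, 200), as in the hunk.
    cross_attn = torch.nn.MultiheadAttention(embed_dim=200, num_heads=4, batch_first=True)

    x = torch.randn(8, 1, 4, 200)     # feature map: (batch, channel, 4 bases, 200 positions)
    context = torch.randn(8, 4, 200)  # conditioning tokens (assumed shape)

    q = x.view(-1, 4, 200)            # queries derived from x
    crossattention_out, _ = cross_attn(q, context, context)

    # The fix: reshape the attention *output* before the residual add.
    # Previously x was reshaped here instead, so the next line just doubled x.
    crossattention_out = crossattention_out.view(-1, 1, 4, 200)
    x = x + crossattention_out

With the output wired in this way, the cross-attention parameters contribute to the loss and receive gradients, which is what enables the train.py change below.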
3 changes: 1 addition & 2 deletions train.py
@@ -7,8 +7,7 @@


 def train():
-    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
-    accelerator = Accelerator(kwargs_handlers=[kwargs], split_batches=True, log_with=["wandb"], mixed_precision="bf16")
+    accelerator = Accelerator(split_batches=True, log_with=["wandb"], mixed_precision="bf16")

     data = load_data(
         data_path="src/dnadiffusion/data/K562_hESCT0_HepG2_GM12878_12k_sequences_per_group.txt",
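Why the kwargs handler could go: `find_unused_parameters=True` makes DDP traverse the autograd graph after each backward pass to find parameters that received no gradient, which adds overhead. With the cross-attention output now feeding the residual path, every parameter presumably gets a gradient, so the plain construction from the new train.py suffices. A sketch under that assumption:

    from accelerate import Accelerator

    # Matches the new train.py above: no DistributedDataParallelKwargs handler.
    # If some parameters were still unused in a given configuration, DDP would
    # raise an error and the handler would have to come back.
    accelerator = Accelerator(
        split_batches=True,
        log_with=["wandb"],
        mixed_precision="bf16",
    )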
