Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix RuntimeError when initialising from scratch
When initialising from scratch, `requires_grad` is passed and then `normal_` is called as below ``` w = torch.empty( ... , requires_grad=self.requires_grad) w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2)) ``` Causing the following issue: ``` /usr/local/lib/python3.7/dist-packages/dall_e/utils.py in __attrs_post_init__(self) 22 size = (self.n_out, self.n_in, self.kw, self.kw) 23 w = torch.empty(size=size, dtype=torch.float32, device=self.device, requires_grad = self.requires_grad) ---> 24 w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2)) 25 RuntimeError: a leaf Variable that requires grad is being used in an in-place operation. ``` The above change fixes it.
- Loading branch information