From 91b8f625aa2f344fa10e61b1cf66e323b86bf55b Mon Sep 17 00:00:00 2001
From: Yash Bonde
Date: Wed, 3 Mar 2021 17:52:54 +0530
Subject: [PATCH] Fix RuntimeError when initialising from scratch

When initialising from scratch, `requires_grad` is passed and then
`normal_` is called as below:

```
w = torch.empty( ... , requires_grad=self.requires_grad)
w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2))
```

This causes the following issue:

```
/usr/local/lib/python3.7/dist-packages/dall_e/utils.py in __attrs_post_init__(self)
     22     size = (self.n_out, self.n_in, self.kw, self.kw)
     23     w = torch.empty(size=size, dtype=torch.float32, device=self.device, requires_grad = self.requires_grad)
---> 24     w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2))
     25

RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
```

Setting `requires_grad` only after the in-place `normal_` call, as done in
the change below, fixes it.
---
 dall_e/utils.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/dall_e/utils.py b/dall_e/utils.py
index cdb1cad..0f6b2fc 100644
--- a/dall_e/utils.py
+++ b/dall_e/utils.py
@@ -19,10 +19,13 @@ class Conv2d(nn.Module):
 
 	def __attrs_post_init__(self) -> None:
 		super().__init__()
-
-		w = torch.empty((self.n_out, self.n_in, self.kw, self.kw), dtype=torch.float32,
-			device=self.device, requires_grad=self.requires_grad)
+		size = (self.n_out, self.n_in, self.kw, self.kw)
+		w = torch.empty(size=size, dtype=torch.float32, device=self.device)
 		w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2))
+
+		# move requires_grad after filling values using normal_
+		# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
+		w.requires_grad = self.requires_grad
 
 		b = torch.zeros((self.n_out,), dtype=torch.float32, device=self.device,
 			requires_grad=self.requires_grad)