From d045e720b1cc6610490a83aac2ff0db5d2b24bb3 Mon Sep 17 00:00:00 2001 From: Stephen Merity Date: Sat, 25 Nov 2017 15:02:12 -0800 Subject: [PATCH] Bugfix: CPUForgetMult did not like batch size 1 due to squeeze --- torchqrnn/forget_mult.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchqrnn/forget_mult.py b/torchqrnn/forget_mult.py index 3633463..6967850 100644 --- a/torchqrnn/forget_mult.py +++ b/torchqrnn/forget_mult.py @@ -82,7 +82,10 @@ def forward(self, f, x, hidden_init=None): prev_h = hidden_init for i, h in enumerate((f * x).split(1, dim=0)): if prev_h is not None: h = h + (1 - forgets[i]) * prev_h - result.append(h.squeeze()) + # h is (1, batch, hidden) when it needs to be (batch, hidden) + # Calling squeeze will result in badness if batch size is 1 + h = h.view(h.size()[1:]) + result.append(h) prev_h = h ### return torch.stack(result) @@ -195,6 +198,7 @@ def forward(self, f, x, hidden_init=None, use_cuda=True): print('=-=-' * 5) resulta = ForgetMult()(forget, a, last_h, use_cuda=True) + print(resulta.size()) loss = resulta.pow(2).sum() loss.backward()