Bugfix: CPUForgetMult did not like batch size 1 due to squeeze
Smerity committed Nov 25, 2017
1 parent 2ffbd32 commit d045e72
Showing 1 changed file with 5 additions and 1 deletion.
torchqrnn/forget_mult.py
@@ -82,7 +82,10 @@ def forward(self, f, x, hidden_init=None):
         prev_h = hidden_init
         for i, h in enumerate((f * x).split(1, dim=0)):
             if prev_h is not None: h = h + (1 - forgets[i]) * prev_h
-            result.append(h.squeeze())
+            # h is (1, batch, hidden) when it needs to be (batch, hidden)
+            # Calling squeeze will result in badness if batch size is 1
+            h = h.view(h.size()[1:])
+            result.append(h)
             prev_h = h
         ###
         return torch.stack(result)
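
For context (not part of the commit): squeeze() removes every dimension of size 1, so at batch size 1 it strips the batch dimension along with the leading split dimension, whereas view(h.size()[1:]) drops only the leading one. A minimal sketch, assuming standard PyTorch semantics:

import torch

# After split(1, dim=0), each h has shape (1, batch, hidden)
h = torch.zeros(1, 4, 8)
print(h.squeeze().shape)           # torch.Size([4, 8]) -- fine when batch > 1

h = torch.zeros(1, 1, 8)
print(h.squeeze().shape)           # torch.Size([8])    -- batch dimension lost too
print(h.view(h.size()[1:]).shape)  # torch.Size([1, 8]) -- only the split dim dropped
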
@@ -195,6 +198,7 @@ def forward(self, f, x, hidden_init=None, use_cuda=True):
     print('=-=-' * 5)
 
     resulta = ForgetMult()(forget, a, last_h, use_cuda=True)
+    print(resulta.size())
     loss = resulta.pow(2).sum()
     loss.backward()

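The added print(resulta.size()) is a sanity check that the batch dimension survives the forward pass. Below is a hypothetical batch-size-1 regression check for the CPU path; the variable names and shapes are assumed, not taken from the commit's test setup, and it presumes a PyTorch version where plain Tensors support autograd:

import torch
from torchqrnn.forget_mult import ForgetMult

# Hypothetical batch-size-1 input exercising the fixed code path
seq_len, batch, hidden = 7, 1, 16
f = torch.sigmoid(torch.randn(seq_len, batch, hidden))
x = torch.randn(seq_len, batch, hidden, requires_grad=True)

out = ForgetMult()(f, x, use_cuda=False)  # use_cuda=False selects CPUForgetMult
print(out.size())  # expect torch.Size([7, 1, 16]) -- batch dim preserved by the fix
out.pow(2).sum().backward()  # gradients flow through the view-based path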
