Skip to content

Commit

Permalink
hotfix multiple prompts
Browse files Browse the repository at this point in the history
  • Loading branch information
shyamsn97 committed Feb 16, 2023
1 parent d7c1ab7 commit 70bbdb8
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion mario_gpt/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,11 @@ def sample(
self.lm.eval()
with torch.no_grad():
if seed is None:
seed = self.tokenizer("X", return_tensors="pt").input_ids.view(1, 1)
seed = (
self.tokenizer("X", return_tensors="pt")
.input_ids.view(1, 1)
.repeat(len(prompts), 1)
)
out = seed.to(self.device)
if encoder_hidden_states is None:
if prompts is not None:
Expand Down

0 comments on commit 70bbdb8

Please sign in to comment.