Skip to content

Commit

Permalink
refinements and text fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed May 17, 2024
1 parent f5f6dcc commit 87fd97a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 21 deletions.
7 changes: 2 additions & 5 deletions litgpt/generate/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
import time
from pathlib import Path
from typing import Any, Literal, Optional
from typing import Any, Literal, Optional, Generator

import lightning as L
import torch
Expand Down Expand Up @@ -79,7 +79,7 @@ def generate(
top_p: float = 1.0,
eos_id: Optional[int] = None,
stream: bool = False,
) -> torch.Tensor:
) -> Generator[torch.Tensor, None, None]:
"""Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
The implementation of this function is modified from A. Karpathy's nanoGPT.
Expand Down Expand Up @@ -124,9 +124,6 @@ def generate(
).clone()
tokens.append(token)

if stream:
yield token

for _ in range(2, max_returned_tokens - T + 1):
token = next_token(
model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p
Expand Down
33 changes: 17 additions & 16 deletions tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def multinomial(*args, **kwargs):
return out

with mock.patch("litgpt.generate.base.multinomial_num_samples_1", multinomial):
out = generate.generate(model, input_idx, T + max_new_tokens, top_k=4)
out = next(generate.generate(model, input_idx, T + max_new_tokens, top_k=4))

assert out.size(0) == T + max_new_tokens
multinomial_results = torch.hstack(multinomial_results)
Expand All @@ -54,32 +54,33 @@ def test_main(fake_checkpoint_dir, monkeypatch, tensor_like):

module_mock = Mock()
module_mock.config.block_size = 128
module_mock.max_seq_length = 150
load_mock = Mock()
load_mock.return_value = load_mock
load_mock.return_value = module_mock
monkeypatch.setattr(generate, "load_checkpoint", load_mock)
tokenizer_mock = Mock()
tokenizer_mock.return_value.encode.return_value = torch.tensor([1, 2, 3])
tokenizer_mock.return_value.decode.return_value = "foo bar baz"
monkeypatch.setattr(generate, "Tokenizer", tokenizer_mock)
generate_mock = Mock()
generate_mock.return_value = torch.tensor([3, 2, 1])
monkeypatch.setattr(generate, "generate", generate_mock)

def generate_mock(model, prompt, max_returned_tokens, *, temperature, top_k, top_p, eos_id, stream):
    """Deterministic stand-in for ``generate.generate`` used as a ``Mock`` side effect.

    The new tokens always follow the repeating cycle 3, 2, 1. The function is a
    generator in both modes, mirroring the generator-based ``generate`` API.

    Args:
        model: Ignored; accepted only to match the patched function's signature.
        prompt: 1-D tensor holding the prompt token ids.
        max_returned_tokens: Total length (prompt + new tokens) to produce.
        temperature, top_k, top_p, eos_id: Ignored sampling options (keyword-only).
        stream: If true, yield each new token as a 0-dim tensor; otherwise yield
            one tensor containing the prompt followed by the new tokens.

    Yields:
        Either ``max_returned_tokens - len(prompt)`` scalar token tensors
        (``stream=True``) or a single tensor of length ``max_returned_tokens``
        (``stream=False``).
    """
    cycle = (3, 2, 1)
    # range() of a non-positive count is empty, so no new tokens are produced
    # when the prompt already fills (or exceeds) the requested length.
    num_new_tokens = max_returned_tokens - prompt.size(0)
    new_tokens = [cycle[i % len(cycle)] for i in range(num_new_tokens)]

    if stream:
        for token in new_tokens:
            yield torch.tensor(token)
    else:
        # Append exactly one token per requested position so the result has
        # length ``max_returned_tokens`` and agrees with the streamed tokens
        # (previously the whole 3-token cycle was appended per position).
        suffix = torch.tensor(new_tokens, dtype=prompt.dtype, device=prompt.device)
        yield torch.cat([prompt, suffix])

generate_function_mock = Mock()
generate_function_mock.side_effect = generate_mock
monkeypatch.setattr(generate, "generate", generate_function_mock)

num_samples = 2
out, err = StringIO(), StringIO()
with redirect_stdout(out), redirect_stderr(err):
generate.main(temperature=2.0, top_k=2, top_p=0.9, num_samples=num_samples, checkpoint_dir=fake_checkpoint_dir)

assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples
assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value)
assert (
generate_mock.mock_calls
== [call(ANY, tensor_like, 53, temperature=2.0, top_k=2, top_p=0.9, eos_id=tokenizer_mock.return_value.eos_id)]
* num_samples
)
# only the generated result is printed to stdout
assert out.getvalue() == "foo bar baz\n" * num_samples

assert out.getvalue().strip().split('\n') == ["foo bar baz"] * num_samples
assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4" in err.getvalue()


Expand Down Expand Up @@ -117,8 +118,8 @@ def test_generate_different_results_with_different_top_p():
input_idx = torch.randint(10, size=(1,))

torch.manual_seed(123)
output1 = generate.generate(model, input_idx, 20, top_p=1.0)
output1 = next(generate.generate(model, input_idx, 20, top_p=1.0))
torch.manual_seed(123)
output2 = generate.generate(model, input_idx, 20, top_p=0.1)
output2 = next(generate.generate(model, input_idx, 20, top_p=0.1))

assert not torch.equal(output1, output2)

0 comments on commit 87fd97a

Please sign in to comment.