From 87fd97ae0ba16a3990fb5b0d80e8f91c7600cfdc Mon Sep 17 00:00:00 2001
From: rasbt
Date: Fri, 17 May 2024 02:42:34 +0000
Subject: [PATCH] refinements and text fixes

---
 litgpt/generate/base.py |  7 ++-----
 tests/test_generate.py  | 33 +++++++++++++++++----------------
 2 files changed, 19 insertions(+), 21 deletions(-)

diff --git a/litgpt/generate/base.py b/litgpt/generate/base.py
index 384734109f..063e73ddda 100644
--- a/litgpt/generate/base.py
+++ b/litgpt/generate/base.py
@@ -3,7 +3,7 @@
 import sys
 import time
 from pathlib import Path
-from typing import Any, Literal, Optional
+from typing import Any, Literal, Optional, Generator
 
 import lightning as L
 import torch
@@ -79,7 +79,7 @@ def generate(
     top_p: float = 1.0,
     eos_id: Optional[int] = None,
     stream: bool = False,
-) -> torch.Tensor:
+) -> Generator[torch.Tensor, None, None]:
     """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
     The implementation of this function is modified from A. Karpathy's nanoGPT.
 
@@ -124,9 +124,6 @@
     ).clone()
     tokens.append(token)
 
-    if stream:
-        yield token
-
     for _ in range(2, max_returned_tokens - T + 1):
         token = next_token(
             model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p
diff --git a/tests/test_generate.py b/tests/test_generate.py
index 430cd6ada0..9c2eaf97f8 100644
--- a/tests/test_generate.py
+++ b/tests/test_generate.py
@@ -38,7 +38,7 @@ def multinomial(*args, **kwargs):
         return out
 
     with mock.patch("litgpt.generate.base.multinomial_num_samples_1", multinomial):
-        out = generate.generate(model, input_idx, T + max_new_tokens, top_k=4)
+        out = next(generate.generate(model, input_idx, T + max_new_tokens, top_k=4))
 
     assert out.size(0) == T + max_new_tokens
     multinomial_results = torch.hstack(multinomial_results)
@@ -54,16 +54,25 @@ def test_main(fake_checkpoint_dir, monkeypatch, tensor_like):
 
     module_mock = Mock()
     module_mock.config.block_size = 128
+    module_mock.max_seq_length = 150
     load_mock = Mock()
-    load_mock.return_value = load_mock
+    load_mock.return_value = module_mock
     monkeypatch.setattr(generate, "load_checkpoint", load_mock)
     tokenizer_mock = Mock()
    tokenizer_mock.return_value.encode.return_value = torch.tensor([1, 2, 3])
     tokenizer_mock.return_value.decode.return_value = "foo bar baz"
     monkeypatch.setattr(generate, "Tokenizer", tokenizer_mock)
-    generate_mock = Mock()
-    generate_mock.return_value = torch.tensor([3, 2, 1])
-    monkeypatch.setattr(generate, "generate", generate_mock)
+
+    def generate_mock(model, prompt, max_returned_tokens, *, temperature, top_k, top_p, eos_id, stream):
+        if stream:
+            for i in range(max_returned_tokens - prompt.size(0)):
+                yield torch.tensor([3, 2, 1][i % 3])
+        else:
+            yield torch.cat([prompt] + [torch.tensor([3, 2, 1])] * (max_returned_tokens - prompt.size(0)))
+
+    generate_function_mock = Mock()
+    generate_function_mock.side_effect = generate_mock
+    monkeypatch.setattr(generate, "generate", generate_function_mock)
 
     num_samples = 2
     out, err = StringIO(), StringIO()
@@ -71,15 +80,7 @@ def test_main(fake_checkpoint_dir, monkeypatch, tensor_like):
         generate.main(temperature=2.0, top_k=2, top_p=0.9, num_samples=num_samples, checkpoint_dir=fake_checkpoint_dir)
 
     assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples
-    assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value)
-    assert (
-        generate_mock.mock_calls
-        == [call(ANY, tensor_like, 53, temperature=2.0, top_k=2, top_p=0.9, eos_id=tokenizer_mock.return_value.eos_id)]
-        * num_samples
-    )
-    # only the generated result is printed to stdout
-    assert out.getvalue() == "foo bar baz\n" * num_samples
-
+    assert out.getvalue().strip().split('\n') == ["foo bar baz"] * num_samples
     assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4" in err.getvalue()
 
 
@@ -117,8 +118,8 @@ def test_generate_different_results_with_different_top_p():
     input_idx = torch.randint(10, size=(1,))
 
     torch.manual_seed(123)
-    output1 = generate.generate(model, input_idx, 20, top_p=1.0)
+    output1 = next(generate.generate(model, input_idx, 20, top_p=1.0))
     torch.manual_seed(123)
-    output2 = generate.generate(model, input_idx, 20, top_p=0.1)
+    output2 = next(generate.generate(model, input_idx, 20, top_p=0.1))
 
     assert not torch.equal(output1, output2)
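
Note on usage: after this patch, `generate` is a generator in both modes, which is why the updated tests wrap every call in `next(...)`. Below is a minimal, self-contained sketch of the consumption pattern; `fake_generate` is a hypothetical stand-in that mirrors the mocked behavior exercised in `test_main` above, not litgpt's real implementation:

    import torch

    def fake_generate(model, prompt, max_returned_tokens, *, temperature=1.0,
                      top_k=None, top_p=1.0, eos_id=None, stream=False):
        # Hypothetical stand-in with the same call shape as the patched generate():
        # stream=True yields one token tensor at a time; stream=False yields a
        # single tensor holding the prompt followed by the generated tokens.
        new_tokens = [torch.tensor([i % 3 + 1]) for i in range(max_returned_tokens - prompt.size(0))]
        if stream:
            yield from new_tokens
        else:
            yield torch.cat([prompt] + new_tokens)

    prompt = torch.tensor([1, 2, 3])

    # Non-streaming: exactly one tensor is yielded, hence next(...) in the tests.
    full = next(fake_generate(None, prompt, 8))
    assert full.size(0) == 8  # prompt (3 tokens) + generated (5 tokens)

    # Streaming: iterate to receive tokens one by one as they are produced.
    for token in fake_generate(None, prompt, 8, stream=True):
        print(token.item(), end=" ")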