From e952f80c7d06c7b0dd4f37b149cd1a39fff14298 Mon Sep 17 00:00:00 2001
From: rasbt
Date: Fri, 17 May 2024 15:22:51 +0000
Subject: [PATCH] add missing tokens, add test

---
 litgpt/generate/base.py |  3 +++
 tests/test_generate.py  | 29 +++++++++++++++++++++++++++++
 2 files changed, 32 insertions(+)

diff --git a/litgpt/generate/base.py b/litgpt/generate/base.py
index 063e73ddda..e089314eeb 100644
--- a/litgpt/generate/base.py
+++ b/litgpt/generate/base.py
@@ -124,6 +124,9 @@ def generate(
     ).clone()
     tokens.append(token)
 
+    if stream:  # Otherwise 1 token is missing (see tests)
+        yield token
+
     for _ in range(2, max_returned_tokens - T + 1):
         token = next_token(
             model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p
diff --git a/tests/test_generate.py b/tests/test_generate.py
index 7f6a2a0d47..43954db4d0 100644
--- a/tests/test_generate.py
+++ b/tests/test_generate.py
@@ -47,6 +47,35 @@ def multinomial(*args, **kwargs):
     torch.testing.assert_close(out, expected)
 
 
+@pytest.mark.parametrize(
+    "max_returned_tokens",
+    [15, 25, 20]
+)
+def test_generate_stream(max_returned_tokens):
+    T = 5
+    prompt = torch.randint(10, size=(T,))
+    config = Config(block_size=128, vocab_size=16, n_layer=1, n_head=4, n_embd=8)
+    model = GPT(config)
+    model.max_seq_length = 30
+    max_new_tokens = max_returned_tokens - T
+
+    model.set_kv_cache(batch_size=1)
+
+    multinomial_results = []
+
+    def multinomial(*args, **kwargs):
+        out = torch.multinomial(*args, **kwargs, num_samples=1)
+        multinomial_results.append(out)
+        return out
+
+    with mock.patch("litgpt.generate.base.multinomial_num_samples_1", multinomial):
+        token_generator = generate.generate(model, prompt, max_returned_tokens, stream=True)
+        generated_tokens = list(token_generator)
+
+    expected_length = min(max_new_tokens, len(multinomial_results))
+    assert len(generated_tokens) == expected_length
+
+
 def test_main(fake_checkpoint_dir, monkeypatch, tensor_like):
     config_path = fake_checkpoint_dir / "model_config.yaml"
     config = {"block_size": 128, "vocab_size": 50, "n_layer": 2, "n_head": 4, "n_embd": 8, "rotary_percentage": 1}
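
Note: below is a minimal consumption sketch, not part of the patch, showing how the streaming path patched above might be used. The tiny model setup mirrors the new test_generate_stream test; the import paths and the standalone-script layout are assumptions, not the library's documented usage.

    import torch
    from litgpt import GPT, Config
    from litgpt.generate import base as generate

    # Tiny model, mirroring the configuration used in test_generate_stream (assumed setup).
    config = Config(block_size=128, vocab_size=16, n_layer=1, n_head=4, n_embd=8)
    model = GPT(config)
    model.max_seq_length = 30
    model.set_kv_cache(batch_size=1)

    prompt = torch.randint(10, size=(5,))

    # With stream=True, generate() yields tokens one at a time; the patch above
    # ensures the very first sampled token is yielded too, not only the later ones.
    streamed = []
    for token in generate.generate(model, prompt, max_returned_tokens=20, stream=True):
        streamed.append(token)

    print(f"received {len(streamed)} streamed tokens")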