Skip to content

Commit

Permalink
add missing tokens, add test
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed May 17, 2024
1 parent ea895cc commit e952f80
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
3 changes: 3 additions & 0 deletions litgpt/generate/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ def generate(
).clone()
tokens.append(token)

if stream: # Otherwise 1 token is missing (see tests)
yield token

for _ in range(2, max_returned_tokens - T + 1):
token = next_token(
model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p
Expand Down
29 changes: 29 additions & 0 deletions tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,35 @@ def multinomial(*args, **kwargs):
torch.testing.assert_close(out, expected)


@pytest.mark.parametrize(
    "max_returned_tokens",
    [15, 25, 20]
)
def test_generate_stream(max_returned_tokens):
    """Streaming generation must yield every sampled token (see the
    `if stream:` early-yield fix in litgpt/generate/base.py)."""
    prompt_len = 5
    prompt = torch.randint(10, size=(prompt_len,))
    config = Config(block_size=128, vocab_size=16, n_layer=1, n_head=4, n_embd=8)
    model = GPT(config)
    model.max_seq_length = 30
    max_new_tokens = max_returned_tokens - prompt_len

    model.set_kv_cache(batch_size=1)

    sampled = []

    def tracking_multinomial(*args, **kwargs):
        # Delegate to torch.multinomial while recording each draw,
        # so we know exactly how many tokens were actually sampled.
        result = torch.multinomial(*args, **kwargs, num_samples=1)
        sampled.append(result)
        return result

    with mock.patch("litgpt.generate.base.multinomial_num_samples_1", tracking_multinomial):
        stream = generate.generate(model, prompt, max_returned_tokens, stream=True)
        streamed_tokens = list(stream)

    # Every sampled token, up to the new-token budget, must be streamed out.
    assert len(streamed_tokens) == min(max_new_tokens, len(sampled))


def test_main(fake_checkpoint_dir, monkeypatch, tensor_like):
config_path = fake_checkpoint_dir / "model_config.yaml"
config = {"block_size": 128, "vocab_size": 50, "n_layer": 2, "n_head": 4, "n_embd": 8, "rotary_percentage": 1}
Expand Down

0 comments on commit e952f80

Please sign in to comment.