Make chat output less verbose (#1123)
awaelchli committed Mar 15, 2024
1 parent b5b344d commit 899f29f
Showing 2 changed files with 14 additions and 8 deletions.
litgpt/chat/base.py (4 changes: 2 additions, 2 deletions)
@@ -142,7 +142,6 @@ def main(
         print("Merging LoRA weights with the base model. This won't take long and is a one-time-only thing.")
         merge_lora(checkpoint_path)

-    fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
     with fabric.init_module(empty_init=True):
         model = GPT(config)
     # enable the kv cache
@@ -163,13 +162,14 @@ def main(
     prompt_style = load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)
     stop_tokens = prompt_style.stop_tokens(tokenizer)

+    print(f"Now chatting with {config.name}.\nTo exit, press 'Enter' on an empty prompt.\n")
     L.seed_everything(1234)
     while True:
         try:
             prompt = input(">> Prompt: ")
         except KeyboardInterrupt:
             break
-        if not prompt:
+        if prompt.lower().strip() in ("", "quit", "exit"):
             break
         prompt = prompt_style.apply(prompt=prompt)
         encoded_prompt = tokenizer.encode(prompt, device=fabric.device)
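For context, the new exit condition accepts an empty prompt as well as "quit" or "exit" in any casing. A minimal sketch of that check, factored into a hypothetical helper for illustration (in litgpt the expression lives inline in the chat loop):

```python
def should_exit(prompt: str) -> bool:
    # Empty input, "quit", or "exit" (case-insensitive, surrounding
    # whitespace ignored) ends the conversation.
    return prompt.lower().strip() in ("", "quit", "exit")

assert should_exit("")
assert should_exit("  Quit ")
assert should_exit("EXIT")
assert not should_exit("Hello")
```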
tests/test_chat.py (18 changes: 12 additions, 6 deletions)
@@ -1,5 +1,5 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

+import re
 import subprocess
 import sys
 from contextlib import redirect_stderr, redirect_stdout
@@ -87,7 +87,15 @@ def test_main(mocked_input, stop_iteration, fake_checkpoint_dir, monkeypatch, te
     mocked_input.side_effect = ["Hello", stop_iteration]

     config_path = fake_checkpoint_dir / "model_config.yaml"
-    config = {"block_size": 128, "vocab_size": 50, "n_layer": 2, "n_head": 4, "n_embd": 8, "rotary_percentage": 1}
+    config = {
+        "name": "Llama 3",
+        "block_size": 128,
+        "vocab_size": 50,
+        "n_layer": 2,
+        "n_head": 4,
+        "n_embd": 8,
+        "rotary_percentage": 1,
+    }
     config_path.write_text(yaml.dump(config))

     load_mock = Mock()
@@ -112,10 +120,8 @@ def test_main(mocked_input, stop_iteration, fake_checkpoint_dir, monkeypatch, te
     assert generate_mock.mock_calls == [
         call(ANY, tensor_like, 128, temperature=2.0, top_k=2, stop_tokens=([tokenizer_mock.return_value.eos_id],))
     ]
-    # # only the generated result is printed to stdout
-    assert out.getvalue() == ">> Reply: foo bar baz\n"
-
-    assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4" in err.getvalue()
+    # only the generated result is printed to stdout
+    assert re.match("Now chatting with Llama 3.*>> .*Reply: foo bar baz", out.getvalue(), re.DOTALL)


 @pytest.mark.parametrize("mode", ["file", "entrypoint"])
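The updated assertion uses re.match with re.DOTALL so the pattern, anchored at the start of the captured stdout, can span the new multi-line banner before reaching the generated reply. A minimal sketch with a made-up captured string (the real test captures stdout via redirect_stdout):

```python
import re

# Made-up stdout capture: the "Now chatting with ..." banner spans several
# lines before the reply appears.
captured = (
    "Now chatting with Llama 3.\n"
    "To exit, press 'Enter' on an empty prompt.\n"
    "\n"
    ">> Reply: foo bar baz\n"
)

# re.match anchors at the start of the string; re.DOTALL lets "." (and hence
# ".*") match newlines, so the pattern can skip across the banner lines.
assert re.match("Now chatting with Llama 3.*>> .*Reply: foo bar baz", captured, re.DOTALL)

# Without DOTALL, ".*" stops at the first newline and the match fails.
assert not re.match("Now chatting with Llama 3.*>> .*Reply: foo bar baz", captured)
```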
