From 899f29f6236b482190afa5e5ca8443021be0cf8a Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Thu, 14 Mar 2024 22:13:03 +0100
Subject: [PATCH] Make chat output less verbose (#1123)

---
 litgpt/chat/base.py |  4 ++--
 tests/test_chat.py  | 18 ++++++++++++------
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/litgpt/chat/base.py b/litgpt/chat/base.py
index 84023d3a26..2926462305 100644
--- a/litgpt/chat/base.py
+++ b/litgpt/chat/base.py
@@ -142,7 +142,6 @@ def main(
         print("Merging LoRA weights with the base model. This won't take long and is a one-time-only thing.")
         merge_lora(checkpoint_path)

-    fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
     with fabric.init_module(empty_init=True):
         model = GPT(config)
         # enable the kv cache
@@ -163,13 +162,14 @@
     prompt_style = load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)
     stop_tokens = prompt_style.stop_tokens(tokenizer)

+    print(f"Now chatting with {config.name}.\nTo exit, press 'Enter' on an empty prompt.\n")
     L.seed_everything(1234)
     while True:
         try:
             prompt = input(">> Prompt: ")
         except KeyboardInterrupt:
             break
-        if not prompt:
+        if prompt.lower().strip() in ("", "quit", "exit"):
             break
         prompt = prompt_style.apply(prompt=prompt)
         encoded_prompt = tokenizer.encode(prompt, device=fabric.device)
diff --git a/tests/test_chat.py b/tests/test_chat.py
index 96ac0bfe2a..98245f11f9 100644
--- a/tests/test_chat.py
+++ b/tests/test_chat.py
@@ -1,5 +1,5 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
-
+import re
 import subprocess
 import sys
 from contextlib import redirect_stderr, redirect_stdout
@@ -87,7 +87,15 @@ def test_main(mocked_input, stop_iteration, fake_checkpoint_dir, monkeypatch, te
     mocked_input.side_effect = ["Hello", stop_iteration]

     config_path = fake_checkpoint_dir / "model_config.yaml"
-    config = {"block_size": 128, "vocab_size": 50, "n_layer": 2, "n_head": 4, "n_embd": 8, "rotary_percentage": 1}
+    config = {
+        "name": "Llama 3",
+        "block_size": 128,
+        "vocab_size": 50,
+        "n_layer": 2,
+        "n_head": 4,
+        "n_embd": 8,
+        "rotary_percentage": 1,
+    }
     config_path.write_text(yaml.dump(config))

     load_mock = Mock()
@@ -112,10 +120,8 @@ def test_main(mocked_input, stop_iteration, fake_checkpoint_dir, monkeypatch, te
     assert generate_mock.mock_calls == [
         call(ANY, tensor_like, 128, temperature=2.0, top_k=2, stop_tokens=([tokenizer_mock.return_value.eos_id],))
     ]
-    # # only the generated result is printed to stdout
-    assert out.getvalue() == ">> Reply: foo bar baz\n"
-
-    assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4" in err.getvalue()
+    # only the generated result is printed to stdout
+    assert re.match("Now chatting with Llama 3.*>> .*Reply: foo bar baz", out.getvalue(), re.DOTALL)


 @pytest.mark.parametrize("mode", ["file", "entrypoint"])
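
Below is a small, self-contained sketch (not part of the patch above) of the chat-loop exit behaviour this change introduces: the greeting is printed once, and an empty prompt, "quit", or "exit" (case-insensitive, with surrounding whitespace ignored) ends the session. The chat_loop function, the "demo-model" name, and the echoed placeholder reply are illustrative assumptions, not litgpt code.

# Illustrative sketch of the exit handling added to litgpt/chat/base.py above.
# "demo-model" and the reversed-text reply are placeholders for the real model
# name and generation step.
def chat_loop() -> None:
    print("Now chatting with demo-model.\nTo exit, press 'Enter' on an empty prompt.\n")
    while True:
        try:
            prompt = input(">> Prompt: ")
        except KeyboardInterrupt:
            break
        # Empty input, "quit", or "exit" (any case, stripped) all end the chat.
        if prompt.lower().strip() in ("", "quit", "exit"):
            break
        print(f">> Reply: {prompt[::-1]}")


if __name__ == "__main__":
    chat_loop()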