From 72d92bdf64407cb05e4c8e71c0383685e02bb941 Mon Sep 17 00:00:00 2001
From: rchan
Date: Wed, 12 Jun 2024 17:25:48 +0100
Subject: [PATCH] Split progress wrapper into iterable and callable variants

---
 reginald/models/models/chat_completion.py |  6 +++---
 reginald/models/models/hello.py           |  4 ++--
 reginald/models/models/llama_index.py     | 28 +++++++++++++++++++++++------
 reginald/utils.py                         |  2 +-
 tests/test_chat_interact.py               |  5 ++---
 5 files changed, 29 insertions(+), 16 deletions(-)

diff --git a/reginald/models/models/chat_completion.py b/reginald/models/models/chat_completion.py
index efb80080..fb0c0bcc 100644
--- a/reginald/models/models/chat_completion.py
+++ b/reginald/models/models/chat_completion.py
@@ -6,7 +6,7 @@
 from openai import AzureOpenAI, OpenAI
 
 from reginald.models.models.base import MessageResponse, ResponseModel
-from reginald.utils import get_env_var, stream_progress_wrapper
+from reginald.utils import get_env_var, stream_iter_progress_wrapper
 
 
 class ChatCompletionBase(ResponseModel):
@@ -180,7 +180,7 @@ def stream_message(self, message: str, user_id: str) -> None:
             stream=True,
         )
 
-        for chunk in stream_progress_wrapper(response):
+        for chunk in stream_iter_progress_wrapper(response):
             print(chunk.choices[0].delta.content, end="", flush=True)
 
 
@@ -269,5 +269,5 @@ def stream_message(self, message: str, user_id: str) -> None:
             stream=True,
         )
 
-        for chunk in stream_progress_wrapper(response):
+        for chunk in stream_iter_progress_wrapper(response):
             print(chunk.choices[0].delta.content, end="", flush=True)
diff --git a/reginald/models/models/hello.py b/reginald/models/models/hello.py
index 6030529b..01ba40a3 100644
--- a/reginald/models/models/hello.py
+++ b/reginald/models/models/hello.py
@@ -1,7 +1,7 @@
 from typing import Any
 
 from reginald.models.models.base import MessageResponse, ResponseModel
-from reginald.utils import stream_progress_wrapper
+from reginald.utils import stream_iter_progress_wrapper
 
 
 class Hello(ResponseModel):
@@ -21,5 +21,5 @@ def channel_mention(self, message: str, user_id: str) -> MessageResponse:
     def stream_message(self, message: str, user_id: str) -> None:
         # print("\nReginald: ", end="")
         token_list: tuple[str, ...] = ("Hello", "!", " How", " are", " you", "?")
-        for token in stream_progress_wrapper(token_list):
+        for token in stream_iter_progress_wrapper(token_list):
             print(token, end="", flush=True)
diff --git a/reginald/models/models/llama_index.py b/reginald/models/models/llama_index.py
index 77eedb91..3d08355b 100644
--- a/reginald/models/models/llama_index.py
+++ b/reginald/models/models/llama_index.py
@@ -44,7 +44,11 @@
 from reginald.models.models.base import MessageResponse, ResponseModel
 from reginald.models.models.llama_utils import completion_to_prompt, messages_to_prompt
-from reginald.utils import get_env_var, stream_progress_wrapper
+from reginald.utils import (
+    get_env_var,
+    stream_iter_progress_wrapper,
+    stream_progress_wrapper,
+)
 
 nest_asyncio.apply()
 
 
@@ -632,17 +636,27 @@ def __init__(
                 data_dir=self.data_dir,
                 settings=settings,
             )
 
-            self.index = data_creator.create_index()
-            data_creator.save_index()
+            self.index = stream_progress_wrapper(
+                data_creator.create_index,
+                task_str="Generating the index from scratch...",
+            )
+            stream_progress_wrapper(
+                data_creator.save_index,
+                task_str="Saving the index...",
+            )
         else:
             logging.info("Loading the storage context")
-            storage_context = StorageContext.from_defaults(
-                persist_dir=self.data_dir / LLAMA_INDEX_DIR / self.which_index
+            storage_context = stream_progress_wrapper(
+                StorageContext.from_defaults,
+                task_str="Loading the storage context...",
+                persist_dir=self.data_dir / LLAMA_INDEX_DIR / self.which_index,
             )
 
             logging.info("Loading the pre-processed index")
-            self.index = load_index_from_storage(
+            self.index = stream_progress_wrapper(
+                load_index_from_storage,
+                task_str="Loading the pre-processed index...",
                 storage_context=storage_context,
                 settings=settings,
             )
@@ -862,7 +876,7 @@ def stream_message(self, message: str, user_id: str) -> None:
         self.query_engine._response_synthesizer._streaming = True
         response_stream = self.query_engine.query(message)
 
-        for token in stream_progress_wrapper(response_stream.response_gen):
+        for token in stream_iter_progress_wrapper(response_stream.response_gen):
             print(token, end="", flush=True)
 
         formatted_response = "\n\n\n" + self._format_sources(response_stream)
diff --git a/reginald/utils.py b/reginald/utils.py
index 00e9ad1e..ff54dadd 100644
--- a/reginald/utils.py
+++ b/reginald/utils.py
@@ -58,7 +58,7 @@ def stream_progress_wrapper(
     end: str = "\n",
     *args,
     **kwargs,
-) -> chain | Generator | list | tuple | Callable:
+) -> Any:
     """Add a progress bar for iteration.
 
     Examples
diff --git a/tests/test_chat_interact.py b/tests/test_chat_interact.py
index da28c1c8..2379372c 100644
--- a/tests/test_chat_interact.py
+++ b/tests/test_chat_interact.py
@@ -16,9 +16,8 @@ def test_chat_cli():
     result = runner.invoke(cli, ["chat"], input="What's up dock?\nexit\n")
     term_stdout_lines: list[str] = result.stdout.split("\n")
     assert term_stdout_lines[0] == ">>> "
-    assert term_stdout_lines[1] == "Reginald: "
-    assert term_stdout_lines[2] == "Hello! How are you?"
-    assert term_stdout_lines[3] == ">>> "
+    assert term_stdout_lines[1] == "Reginald: Hello! How are you?"
+    assert term_stdout_lines[2] == ">>> "
 
 
 def test_chat_cli_no_stream():