
Commit

Use stream_iter_progress_wrapper for token streams and wrap index creation/loading with stream_progress_wrapper
rchan26 committed Jun 12, 2024
1 parent 8430c66 commit 72d92bd
Showing 5 changed files with 29 additions and 16 deletions.
6 changes: 3 additions & 3 deletions reginald/models/models/chat_completion.py
@@ -6,7 +6,7 @@
 from openai import AzureOpenAI, OpenAI
 
 from reginald.models.models.base import MessageResponse, ResponseModel
-from reginald.utils import get_env_var, stream_progress_wrapper
+from reginald.utils import get_env_var, stream_iter_progress_wrapper
 
 
 class ChatCompletionBase(ResponseModel):
@@ -180,7 +180,7 @@ def stream_message(self, message: str, user_id: str) -> None:
             stream=True,
         )
 
-        for chunk in stream_progress_wrapper(response):
+        for chunk in stream_iter_progress_wrapper(response):
             print(chunk.choices[0].delta.content, end="", flush=True)


@@ -269,5 +269,5 @@ def stream_message(self, message: str, user_id: str) -> None:
             stream=True,
         )
 
-        for chunk in stream_progress_wrapper(response):
+        for chunk in stream_iter_progress_wrapper(response):
             print(chunk.choices[0].delta.content, end="", flush=True)
4 changes: 2 additions & 2 deletions reginald/models/models/hello.py
@@ -1,7 +1,7 @@
 from typing import Any
 
 from reginald.models.models.base import MessageResponse, ResponseModel
-from reginald.utils import stream_progress_wrapper
+from reginald.utils import stream_iter_progress_wrapper
 
 
 class Hello(ResponseModel):
@@ -21,5 +21,5 @@ def channel_mention(self, message: str, user_id: str) -> MessageResponse:
     def stream_message(self, message: str, user_id: str) -> None:
         # print("\nReginald: ", end="")
         token_list: tuple[str, ...] = ("Hello", "!", " How", " are", " you", "?")
-        for token in stream_progress_wrapper(token_list):
+        for token in stream_iter_progress_wrapper(token_list):
            print(token, end="", flush=True)
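As a reading aid, here is a minimal sketch of what the renamed stream_iter_progress_wrapper might do for an iterable such as token_list. The real implementation in reginald/utils.py is not shown in this commit, so the prefix text and default arguments below are assumptions (the CLI test further down suggests a "Reginald: " prefix is printed before the streamed tokens):

```python
from collections.abc import Iterable, Iterator


def stream_iter_progress_wrapper(
    iterable: Iterable,
    task_str: str = "Reginald: ",
    end: str = "",
) -> Iterator:
    """Print a prefix, then yield each item of ``iterable`` unchanged.

    Hypothetical sketch only: the actual helper in reginald/utils.py may
    render a spinner or progress indicator and use different defaults.
    """
    print(task_str, end=end, flush=True)
    yield from iterable


if __name__ == "__main__":
    # Mirrors Hello.stream_message above.
    token_list: tuple[str, ...] = ("Hello", "!", " How", " are", " you", "?")
    for token in stream_iter_progress_wrapper(token_list):
        print(token, end="", flush=True)
    print()
```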
28 changes: 21 additions & 7 deletions reginald/models/models/llama_index.py
@@ -44,7 +44,11 @@
 
 from reginald.models.models.base import MessageResponse, ResponseModel
 from reginald.models.models.llama_utils import completion_to_prompt, messages_to_prompt
-from reginald.utils import get_env_var, stream_progress_wrapper
+from reginald.utils import (
+    get_env_var,
+    stream_iter_progress_wrapper,
+    stream_progress_wrapper,
+)
 
 nest_asyncio.apply()
 
@@ -632,17 +636,27 @@ def __init__(
                 data_dir=self.data_dir,
                 settings=settings,
             )
-            self.index = data_creator.create_index()
-            data_creator.save_index()
+            self.index = stream_progress_wrapper(
+                data_creator.create_index,
+                task_str="Generating the index from scratch...",
+            )
+            stream_progress_wrapper(
+                data_creator.save_index,
+                task_str="Saving the index...",
+            )
 
         else:
             logging.info("Loading the storage context")
-            storage_context = StorageContext.from_defaults(
-                persist_dir=self.data_dir / LLAMA_INDEX_DIR / self.which_index
+            storage_context = stream_progress_wrapper(
+                StorageContext.from_defaults,
+                task_str="Loading the storage context...",
+                persist_dir=self.data_dir / LLAMA_INDEX_DIR / self.which_index,
             )
 
             logging.info("Loading the pre-processed index")
-            self.index = load_index_from_storage(
+            self.index = stream_progress_wrapper(
+                load_index_from_storage,
+                task_str="Loading the pre-processed index...",
                 storage_context=storage_context,
                 settings=settings,
             )
@@ -862,7 +876,7 @@ def stream_message(self, message: str, user_id: str) -> None:
         self.query_engine._response_synthesizer._streaming = True
         response_stream = self.query_engine.query(message)
 
-        for token in stream_progress_wrapper(response_stream.response_gen):
+        for token in stream_iter_progress_wrapper(response_stream.response_gen):
             print(token, end="", flush=True)
 
         formatted_response = "\n\n\n" + self._format_sources(response_stream)
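The __init__ changes above swap direct calls to data_creator.create_index(), StorageContext.from_defaults(...) and load_index_from_storage(...) for stream_progress_wrapper(...) with a task_str message. A minimal sketch of that call pattern, assuming (as the call sites suggest, and as the reginald/utils.py hunk below partly confirms) that the helper announces the task, invokes the callable with any forwarded arguments, and returns the result:

```python
from typing import Any, Callable


def stream_progress_wrapper(
    func: Callable,
    task_str: str = "Loading...",
    end: str = "\n",
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Announce ``task_str``, run ``func`` and pass its result through.

    Sketch inferred from the call sites above; the real helper in
    reginald/utils.py may show a richer progress display.
    """
    print(task_str, end=end, flush=True)
    return func(*args, **kwargs)


# Usage mirroring the new index-loading code (the path is illustrative and
# dict is only a stand-in for StorageContext.from_defaults):
storage_context = stream_progress_wrapper(
    dict,
    task_str="Loading the storage context...",
    persist_dir="path/to/index",
)
```

The stand-in call simply shows that keyword arguments such as persist_dir reach the wrapped callable and that its return value is handed back to the caller.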
2 changes: 1 addition & 1 deletion reginald/utils.py
@@ -58,7 +58,7 @@ def stream_progress_wrapper(
     end: str = "\n",
     *args,
     **kwargs,
-) -> chain | Generator | list | tuple | Callable:
+) -> Any:
     """Add a progress bar for iteration.
 
     Examples
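The return annotation is widened from chain | Generator | list | tuple | Callable to Any because, after this commit, the wrapper also passes through whatever a wrapped callable returns (a storage context, an index, and so on), not just iterables. A small illustration using the repository's own helper, assuming as above that keyword arguments are forwarded to the wrapped callable:

```python
from reginald.utils import stream_progress_wrapper

# The wrapped callable can return any object, hence the ``Any`` annotation.
result = stream_progress_wrapper(
    dict,
    task_str="Building a small mapping...",
    name="reginald",
)
print(type(result))  # <class 'dict'>, not an iterator or itertools.chain
```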
5 changes: 2 additions & 3 deletions tests/test_chat_interact.py
@@ -16,9 +16,8 @@ def test_chat_cli():
     result = runner.invoke(cli, ["chat"], input="What's up dock?\nexit\n")
     term_stdout_lines: list[str] = result.stdout.split("\n")
     assert term_stdout_lines[0] == ">>> "
-    assert term_stdout_lines[1] == "Reginald: "
-    assert term_stdout_lines[2] == "Hello! How are you?"
-    assert term_stdout_lines[3] == ">>> "
+    assert term_stdout_lines[1] == "Reginald: Hello! How are you?"
+    assert term_stdout_lines[2] == ">>> "
 
 
 def test_chat_cli_no_stream():
