diff --git a/poetry.lock b/poetry.lock
index 3177b1ff..982350f1 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -6732,4 +6732,4 @@ llama-index-notebooks = ["bitsandbytes", "gradio", "ipykernel", "nbconvert"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.11,<3.12"
-content-hash = "9c31a7068b0c587ac336dcd9fbccd44865ddd836af349d81e725507819b7b844"
+content-hash = "1fc58571fc197416364d44dd56dfbb448bf212fa93668cb8ce1555abec625b16"
diff --git a/pyproject.toml b/pyproject.toml
index b7a9b80a..6ca2093f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,6 +56,7 @@ typer = {extras = ["all"], version = "^0.12.3"}
langchain-community = "^0.2.4"
tiktoken = "^0.7.0"
llama-index-embeddings-huggingface = "^0.2.1"
+rich = "^13.7.1"
[tool.poetry.group.dev.dependencies]
@@ -105,10 +106,13 @@ build-backend = "poetry.core.masonry.api"
minversion = "6.0"
testpaths = [
"tests",
+ "reginald",
]
addopts = """
- --cov=estios
+ --cov=reginald
--cov-report=term:skip-covered
--cov-append
--pdbcls=IPython.terminal.debugger:TerminalPdb
+ --doctest-modules
"""
+doctest_optionflags = ["NORMALIZE_WHITESPACE", "ELLIPSIS"]
diff --git a/reginald/cli.py b/reginald/cli.py
index 85a78de8..ec86d970 100644
--- a/reginald/cli.py
+++ b/reginald/cli.py
@@ -25,6 +25,7 @@
"device": "Device to use (ignored if not using llama-index).",
"api_url": "API URL for the Reginald app.",
"emoji": "Emoji to use for the bot.",
+ "streaming": "Whether to use streaming for the chat interaction.",
}
cli = typer.Typer()
@@ -102,6 +103,11 @@ def run_all(
str, typer.Option(envvar="LLAMA_INDEX_DEVICE", help=HELP_TEXT["device"])
] = DEFAULT_ARGS["device"],
) -> None:
+ """
+    Run all the components of the Reginald Slack bot.
+    This establishes the connection to the Slack API, sets up the bot,
+    and creates a Reginald model to query.
+ """
set_up_logging_config(level=20)
main(
cli="run_all",
@@ -135,7 +141,7 @@ def bot(
] = EMOJI_DEFAULT,
) -> None:
"""
- Main function to run the Slack bot which sets up the bot
+    Run the Slack bot. This sets up the bot
(which uses an API for responding to messages) and
then establishes a WebSocket connection to the
Socket Mode servers and listens for events.
@@ -213,8 +219,8 @@ def app(
] = DEFAULT_ARGS["device"],
) -> None:
"""
- Main function to run the app which sets up the response model
- and then creates a FastAPI app to serve the model.
+    Run the app. This sets up the response model and then
+    creates a FastAPI app to serve the model.
The app listens on port 8000 and has two endpoints:
- /direct_message: for obtaining responses from direct messages
@@ -262,6 +268,9 @@ def create_index(
int, typer.Option(envvar="LLAMA_INDEX_NUM_OUTPUT")
] = DEFAULT_ARGS["num_output"],
) -> None:
+ """
+ Create an index for the Reginald model.
+ """
set_up_logging_config(level=20)
main(
cli="create_index",
@@ -288,6 +297,12 @@ def chat(
Optional[str],
typer.Option(envvar="REGINALD_MODEL_NAME", help=HELP_TEXT["model_name"]),
] = None,
+ streaming: Annotated[
+ bool,
+ typer.Option(
+ help=HELP_TEXT["streaming"],
+ ),
+ ] = True,
mode: Annotated[
str, typer.Option(envvar="LLAMA_INDEX_MODE", help=HELP_TEXT["mode"])
] = DEFAULT_ARGS["mode"],
@@ -339,9 +354,13 @@ def chat(
str, typer.Option(envvar="LLAMA_INDEX_DEVICE", help=HELP_TEXT["device"])
] = DEFAULT_ARGS["device"],
) -> None:
+ """
+ Run the chat interaction with the Reginald model.
+ """
set_up_logging_config(level=40)
main(
cli="chat",
+ streaming=streaming,
model=model,
model_name=model_name,
mode=mode,
diff --git a/reginald/models/models/__init__.py b/reginald/models/models/__init__.py
index 561b750c..d018a5c4 100644
--- a/reginald/models/models/__init__.py
+++ b/reginald/models/models/__init__.py
@@ -25,13 +25,13 @@
}
DEFAULTS = {
- "chat-completion-azure": "reginald-curie",
+ "chat-completion-azure": "reginald-gpt4",
"chat-completion-openai": "gpt-3.5-turbo",
"hello": None,
"llama-index-ollama": "llama3",
"llama-index-llama-cpp": "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGUF/resolve/main/llama-2-13b-chat.Q6_K.gguf",
"llama-index-hf": "microsoft/phi-1_5",
- "llama-index-gpt-azure": "reginald-gpt35-turbo",
+ "llama-index-gpt-azure": "reginald-gpt4",
"llama-index-gpt-openai": "gpt-3.5-turbo",
}
diff --git a/reginald/models/models/base.py b/reginald/models/models/base.py
index e93377d9..786fd57b 100644
--- a/reginald/models/models/base.py
+++ b/reginald/models/models/base.py
@@ -28,9 +28,13 @@ def __init__(self, emoji: Optional[str], *args: Any, **kwargs: Any):
Emoji to use for the bot's response
"""
self.emoji = emoji
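+        # default mode used when a subclass does not set a chat/query mode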
+ self.mode = "NA"
def direct_message(self, message: str, user_id: str) -> MessageResponse:
raise NotImplementedError
def channel_mention(self, message: str, user_id: str) -> MessageResponse:
raise NotImplementedError
+
+ def stream_message(self, message: str, user_id: str) -> None:
+ raise NotImplementedError
diff --git a/reginald/models/models/chat_completion.py b/reginald/models/models/chat_completion.py
index 9112b5c4..fb0c0bcc 100644
--- a/reginald/models/models/chat_completion.py
+++ b/reginald/models/models/chat_completion.py
@@ -1,5 +1,4 @@
import logging
-import os
import sys
from typing import Any
@@ -7,7 +6,7 @@
from openai import AzureOpenAI, OpenAI
from reginald.models.models.base import MessageResponse, ResponseModel
-from reginald.utils import get_env_var
+from reginald.utils import get_env_var, stream_iter_progress_wrapper
class ChatCompletionBase(ResponseModel):
@@ -155,6 +154,35 @@ def channel_mention(self, message: str, user_id: str) -> MessageResponse:
"""
return self._respond(message=message, user_id=user_id)
+ def stream_message(self, message: str, user_id: str) -> None:
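+        # request a streamed response from the chat or completions endpoint, depending on the mode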
+ if self.mode == "chat":
+ response = self.client.chat.completions.create(
+ model=self.engine,
+ messages=[{"role": "user", "content": message}],
+ frequency_penalty=self.frequency_penalty,
+ max_tokens=self.max_tokens,
+ presence_penalty=self.presence_penalty,
+ stop=None,
+ temperature=self.temperature,
+ top_p=self.top_p,
+ stream=True,
+ )
+ elif self.mode == "query":
+ response = self.client.completions.create(
+ model=self.engine,
+ frequency_penalty=self.frequency_penalty,
+ max_tokens=self.max_tokens,
+ presence_penalty=self.presence_penalty,
+ prompt=message,
+ stop=None,
+ temperature=self.temperature,
+ top_p=self.top_p,
+ stream=True,
+ )
+
+        for chunk in stream_iter_progress_wrapper(response):
+            if chunk.choices and chunk.choices[0].delta.content is not None:
+                print(chunk.choices[0].delta.content, end="", flush=True)
+
class ChatCompletionOpenAI(ChatCompletionBase):
def __init__(
@@ -233,3 +261,13 @@ def channel_mention(self, message: str, user_id: str) -> MessageResponse:
Response from the query engine.
"""
return self._respond(message=message, user_id=user_id)
+
+ def stream_message(self, message: str, user_id: str) -> None:
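+        # stream a chat completion and print tokens as they arrive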
+ response = self.client.chat.completions.create(
+ model=self.model_name,
+ messages=[{"role": "user", "content": message}],
+ stream=True,
+ )
+
+        for chunk in stream_iter_progress_wrapper(response):
+            if chunk.choices and chunk.choices[0].delta.content is not None:
+                print(chunk.choices[0].delta.content, end="", flush=True)
diff --git a/reginald/models/models/hello.py b/reginald/models/models/hello.py
index 2560613e..01ba40a3 100644
--- a/reginald/models/models/hello.py
+++ b/reginald/models/models/hello.py
@@ -1,6 +1,7 @@
from typing import Any
from reginald.models.models.base import MessageResponse, ResponseModel
+from reginald.utils import stream_iter_progress_wrapper
class Hello(ResponseModel):
@@ -16,3 +17,9 @@ def direct_message(self, message: str, user_id: str) -> MessageResponse:
def channel_mention(self, message: str, user_id: str) -> MessageResponse:
return MessageResponse(f"Hello <@{user_id}>")
+
+ def stream_message(self, message: str, user_id: str) -> None:
+ # print("\nReginald: ", end="")
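+        # emit a fixed sequence of tokens to mimic a streamed response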
+ token_list: tuple[str, ...] = ("Hello", "!", " How", " are", " you", "?")
+ for token in stream_iter_progress_wrapper(token_list):
+ print(token, end="", flush=True)
diff --git a/reginald/models/models/llama_index.py b/reginald/models/models/llama_index.py
index 94da9bc4..3d08355b 100644
--- a/reginald/models/models/llama_index.py
+++ b/reginald/models/models/llama_index.py
@@ -44,7 +44,11 @@
from reginald.models.models.base import MessageResponse, ResponseModel
from reginald.models.models.llama_utils import completion_to_prompt, messages_to_prompt
-from reginald.utils import get_env_var
+from reginald.utils import (
+ get_env_var,
+ stream_iter_progress_wrapper,
+ stream_progress_wrapper,
+)
nest_asyncio.apply()
@@ -632,28 +636,39 @@ def __init__(
data_dir=self.data_dir,
settings=settings,
)
- self.index = data_creator.create_index()
- data_creator.save_index()
+ self.index = stream_progress_wrapper(
+ data_creator.create_index,
+ task_str="Generating the index from scratch...",
+ )
+ stream_progress_wrapper(
+ data_creator.save_index,
+ task_str="Saving the index...",
+ )
else:
logging.info("Loading the storage context")
- storage_context = StorageContext.from_defaults(
- persist_dir=self.data_dir / LLAMA_INDEX_DIR / self.which_index
+ storage_context = stream_progress_wrapper(
+ StorageContext.from_defaults,
+ task_str="Loading the storage context...",
+ persist_dir=self.data_dir / LLAMA_INDEX_DIR / self.which_index,
)
logging.info("Loading the pre-processed index")
- self.index = load_index_from_storage(
+ self.index = stream_progress_wrapper(
+ load_index_from_storage,
+ task_str="Loading the pre-processed index...",
storage_context=storage_context,
settings=settings,
)
- response_mode = "simple_summarize"
+ self.response_mode = "simple_summarize"
if self.mode == "chat":
self.chat_engine = {}
logging.info("Done setting up Huggingface backend for chat engine.")
elif self.mode == "query":
self.query_engine = self.index.as_query_engine(
- response_mode=response_mode, similarity_top_k=k
+ response_mode=self.response_mode,
+ similarity_top_k=k,
)
logging.info("Done setting up Huggingface backend for query engine.")
@@ -693,12 +708,48 @@ def _format_sources(response: RESPONSE_TYPE) -> str:
result = "I read the following documents to compose this answer:\n"
result += "\n\n".join(texts)
+
return result
- def _get_response(self, msg_in: str, user_id: str) -> str:
+ def _prep_llm(self) -> BaseLLM:
"""
- Method to obtain a response from the query/chat engine given
- a message and a user id.
+ Method to prepare the LLM to be used.
+
+ Returns
+ -------
+ BaseLLM
+ LLM to be used.
+
+ Raises
+ ------
+ NotImplemented
+ This must be implemented by a subclass of LlamaIndex.
+ """
+ raise NotImplementedError(
+ "_prep_llm needs to be implemented by a subclass of LlamaIndex."
+ )
+
+ def _prep_tokenizer(self) -> callable[str] | None:
+ """
+ Method to prepare the Tokenizer to be used.
+
+ Returns
+ -------
+ callable[str] | None
+ Tokenizer to use. A callable function on a string.
+ Can also be None if using the default set by LlamaIndex.
+
+ Raises
+ ------
+ NotImplemented
+ """
+ raise NotImplementedError(
+ "_prep_tokenizer needs to be implemented by a subclass of LlamaIndex."
+ )
+
+ def _get_response(self, message: str, user_id: str) -> MessageResponse:
+ """
+        Method to obtain a response from the query/chat engine for a given message.
Parameters
----------
@@ -709,25 +760,25 @@ def _get_response(self, msg_in: str, user_id: str) -> str:
Returns
-------
- str
- String containing the response from the query engine.
+ MessageResponse
+ Response from the query engine.
"""
- response_mode = "simple_summarize"
try:
if self.mode == "chat":
# create chat engine for user if does not exist
if self.chat_engine.get(user_id) is None:
self.chat_engine[user_id] = self.index.as_chat_engine(
- chat_mode="context",
- response_mode=response_mode,
+ chat_mode="condense_plus_context",
+ response_mode=self.response_mode,
similarity_top_k=self.k,
)
# obtain chat engine for particular user
chat_engine = self.chat_engine[user_id]
- response = chat_engine.chat(msg_in)
+ response = chat_engine.chat(message)
elif self.mode == "query":
- response = self.query_engine.query(msg_in)
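+            # the query engine is shared with stream_message, so switch streaming off for a one-shot response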
+ self.query_engine._response_synthesizer._streaming = False
+ response = self.query_engine.query(message)
# concatenate the response with the resources that it used
formatted_response = (
@@ -735,14 +786,16 @@ def _get_response(self, msg_in: str, user_id: str) -> str:
)
except Exception as e: # ignore: broad-except
formatted_response = self.error_response_template.format(repr(e))
+
pattern = (
r"(?s)^Context information is"
r".*"
r"Given the context information and not prior knowledge, answer the question: "
- rf"{msg_in}"
+ rf"{message}"
r"\n(.*)"
)
m = re.search(pattern, formatted_response)
+
if m:
answer = m.group(1)
else:
@@ -750,47 +803,12 @@ def _get_response(self, msg_in: str, user_id: str) -> str:
"Was expecting a backend response with a regular expression but couldn't find a match."
)
answer = formatted_response
- return answer
-
- def _prep_llm(self) -> BaseLLM:
- """
- Method to prepare the LLM to be used.
-
- Returns
- -------
- BaseLLM
- LLM to be used.
-
- Raises
- ------
- NotImplemented
- This must be implemented by a subclass of LlamaIndex.
- """
- raise NotImplementedError(
- "_prep_llm needs to be implemented by a subclass of LlamaIndex."
- )
- def _prep_tokenizer(self) -> callable[str] | None:
- """
- Method to prepare the Tokenizer to be used.
+ return MessageResponse(answer)
- Returns
- -------
- callable[str] | None
- Tokenizer to use. A callable function on a string.
- Can also be None if using the default set by LlamaIndex.
-
- Raises
- ------
- NotImplemented
- """
- raise NotImplementedError(
- "_prep_tokenizer needs to be implemented by a subclass of LlamaIndex."
- )
-
- def _respond(self, message: str, user_id: str) -> MessageResponse:
+ def direct_message(self, message: str, user_id: str) -> MessageResponse:
"""
- Method to respond to a message in Slack.
+ Method to respond to a direct message in Slack.
Parameters
----------
@@ -804,13 +822,11 @@ def _respond(self, message: str, user_id: str) -> MessageResponse:
MessageResponse
Response from the query engine.
"""
- backend_response = self._get_response(message, user_id)
+ return self._get_response(message=message, user_id=user_id)
- return MessageResponse(backend_response)
-
- def direct_message(self, message: str, user_id: str) -> MessageResponse:
+ def channel_mention(self, message: str, user_id: str) -> MessageResponse:
"""
- Method to respond to a direct message in Slack.
+ Method to respond to a channel mention in Slack.
Parameters
----------
@@ -824,11 +840,11 @@ def direct_message(self, message: str, user_id: str) -> MessageResponse:
MessageResponse
Response from the query engine.
"""
- return self._respond(message=message, user_id=user_id)
+ return self._get_response(message=message, user_id=user_id)
- def channel_mention(self, message: str, user_id: str) -> MessageResponse:
+ def stream_message(self, message: str, user_id: str) -> None:
"""
- Method to respond to a channel mention in Slack.
+        Method to stream a response to a message, printing tokens as they are generated.
Parameters
----------
@@ -842,7 +858,36 @@ def channel_mention(self, message: str, user_id: str) -> MessageResponse:
MessageResponse
Response from the query engine.
"""
- return self._respond(message=message, user_id=user_id)
+ try:
+ if self.mode == "chat":
+ # create chat engine for user if does not exist
+ if self.chat_engine.get(user_id) is None:
+ self.chat_engine[user_id] = self.index.as_chat_engine(
+ chat_mode="condense_plus_context",
+ response_mode=self.response_mode,
+ similarity_top_k=self.k,
+ streaming=True,
+ )
+
+ # obtain chat engine for particular user
+ chat_engine = self.chat_engine[user_id]
+ response_stream = chat_engine.stream_chat(message)
+ elif self.mode == "query":
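+            # switch the shared query engine's synthesizer into streaming mode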
+ self.query_engine._response_synthesizer._streaming = True
+ response_stream = self.query_engine.query(message)
+
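+        # print tokens as they are generated, then append the formatted source list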
+ for token in stream_iter_progress_wrapper(response_stream.response_gen):
+ print(token, end="", flush=True)
+
+ formatted_response = "\n\n\n" + self._format_sources(response_stream)
+
+ for token in re.split(r"(\s+)", formatted_response):
+ print(token, end="", flush=True)
+ except Exception as e: # ignore: broad-except
+ for token in re.split(
+ r"(\s+)", self.error_response_template.format(repr(e))
+ ):
+ print(token, end="", flush=True)
class LlamaIndexOllama(LlamaIndex):
diff --git a/reginald/models/models/llama_utils.py b/reginald/models/models/llama_utils.py
index 21a7181a..372f8856 100644
--- a/reginald/models/models/llama_utils.py
+++ b/reginald/models/models/llama_utils.py
@@ -7,11 +7,12 @@
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"  # use for Llama2
# B_SYS, E_SYS = "", "\n\n" # use for Mistral
DEFAULT_SYSTEM_PROMPT = """\
-You are a helpful, respectful and honest assistant. \
+You are a helpful, respectful and honest assistant named Reginald. \
Always answer as helpfully as possible and follow ALL given instructions. \
Do not speculate or make up information. \
Do not reference any given instructions or context. \
-If the content is not relevant, just ignore it and provide a helpful response. \
+If the content is not relevant, just ignore it and provide a helpful \
+response without mentioning the context. \
"""
diff --git a/reginald/run.py b/reginald/run.py
index 956257d0..99b617a9 100644
--- a/reginald/run.py
+++ b/reginald/run.py
@@ -21,7 +21,7 @@
LISTENING_MSG: Final[str] = "Listening for requests..."
-async def run_bot(api_url: str | None, emoji: str):
+async def run_bot(api_url: str | None, emoji: str) -> None:
if api_url is None:
logging.error(
"API URL is not set. Please set the REGINALD_API_URL "
@@ -44,7 +44,7 @@ async def run_reginald_app(**kwargs) -> None:
uvicorn.run(app, host="0.0.0.0", port=8000)
-async def run_full_pipeline(**kwargs):
+async def run_full_pipeline(**kwargs) -> None:
# set up response model
response_model = setup_llm(**kwargs)
bot = setup_slack_bot(response_model)
@@ -53,19 +53,37 @@ async def run_full_pipeline(**kwargs):
await connect_client(client)
-def run_chat_interact(**kwargs) -> ResponseModel:
+def run_chat_interact(streaming: bool = False, **kwargs) -> ResponseModel:
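+    """Run an interactive command-line chat session, returning the response model on exit."""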
# set up response model
response_model = setup_llm(**kwargs)
+ user_id = "command_line_chat"
+
while True:
message = input(">>> ")
- if message == "exit":
+ if message in ["exit", "exit()", "quit()", "bye Reginald"]:
return response_model
-
- response = response_model.direct_message(message=message, user_id="chat")
- print(f"\nReginald: {response.message}")
-
-
-async def connect_client(client: SocketModeClient):
+ if message == "":
+ continue
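+        # reset the stored chat history for this user, if any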
+ if message in ["clear_history", "\clear_history"]:
+ if (
+ response_model.mode == "chat"
+ and response_model.chat_engine.get(user_id) is not None
+ ):
+ response_model.chat_engine[user_id].reset()
+ print("\nReginald: History cleared.")
+ else:
+ print("\nReginald: No history to clear.")
+ continue
+
+ if streaming:
+            response_model.stream_message(message=message, user_id=user_id)
+ print("")
+ else:
+ response = response_model.direct_message(message=message, user_id=user_id)
+ print(f"\nReginald: {response.message}")
+
+
+async def connect_client(client: SocketModeClient) -> None:
await client.connect()
# listen for events
logging.info(LISTENING_MSG)
@@ -73,7 +91,13 @@ async def connect_client(client: SocketModeClient):
await asyncio.sleep(float("inf"))
-def main(cli: str, api_url: str | None = None, emoji: str = EMOJI_DEFAULT, **kwargs):
+def main(
+ cli: str,
+ api_url: str | None = None,
+ emoji: str = EMOJI_DEFAULT,
+ streaming: bool = False,
+ **kwargs,
+):
# initialise logging
if cli == "run_all":
asyncio.run(run_full_pipeline(**kwargs))
@@ -82,7 +106,7 @@ def main(cli: str, api_url: str | None = None, emoji: str = EMOJI_DEFAULT, **kwa
elif cli == "app":
asyncio.run(run_reginald_app(**kwargs))
elif cli == "chat":
- run_chat_interact(**kwargs)
+ run_chat_interact(streaming=streaming, **kwargs)
elif cli == "create_index":
create_index(**kwargs)
else:
diff --git a/reginald/utils.py b/reginald/utils.py
index ec20fb30..ff54dadd 100644
--- a/reginald/utils.py
+++ b/reginald/utils.py
@@ -1,5 +1,90 @@
import logging
import os
+from itertools import chain
+from time import sleep
+from typing import Any, Callable, Final, Generator, Iterable
+
+from rich.progress import Progress, SpinnerColumn, TextColumn
+
+REGINAL_PROMPT: Final[str] = "Reginald: "
+
+
+def stream_iter_progress_wrapper(
+ streamer: Iterable | Callable | chain,
+ task_str: str = REGINAL_PROMPT,
+ progress_bar: bool = True,
+ end: str = "",
+ *args,
+ **kwargs,
+) -> Iterable:
+ """Add a progress bar for iteration.
+
+ Examples
+ --------
+ >>> from time import sleep
+ >>> def sleeper(naps: int = 3) -> Generator[str, None, None]:
+ ... for nap in range(naps):
+ ... sleep(1)
+ ... yield f'nap: {nap}'
+ >>> tuple(stream_iter_progress_wrapper(streamer=sleeper))
+
+ Reginald: ('nap: 0', 'nap: 1', 'nap: 2')
+ >>> tuple(stream_iter_progress_wrapper(
+ ... streamer=sleeper, progress_bar=False))
+ Reginald: ('nap: 0', 'nap: 1', 'nap: 2')
+ """
+ if isinstance(streamer, Callable):
+ streamer = streamer(*args, **kwargs)
+ if progress_bar:
+ with Progress(
+ TextColumn("{task.description}[progress.description]"),
+ SpinnerColumn(),
+ transient=True,
+ ) as progress:
+ if isinstance(streamer, list | tuple):
+ streamer = (item for item in streamer)
+ assert isinstance(streamer, Generator)
+ progress.add_task(task_str)
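+        # wait for the first item under the spinner, then stitch it back onto the stream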
+ first_item = next(streamer)
+ streamer = chain((first_item,), streamer)
+ print(task_str, end=end)
+ return streamer
+
+
+def stream_progress_wrapper(
+ streamer: Callable,
+ task_str: str = REGINAL_PROMPT,
+ progress_bar: bool = True,
+ end: str = "\n",
+ *args,
+ **kwargs,
+) -> Any:
+ """Add a progress bar for iteration.
+
+ Examples
+ --------
+ >>> from time import sleep
+ >>> def sleeper(seconds: int = 3) -> str:
+ ... sleep(seconds)
+ ... return f'{seconds} seconds nap'
+ >>> stream_progress_wrapper(sleeper)
+
+ Reginald:
+ '3 seconds nap'
+ """
+ if progress_bar:
+ with Progress(
+ TextColumn("{task.description}[progress.description]"),
+ SpinnerColumn(),
+ transient=True,
+ ) as progress:
+ progress.add_task(task_str)
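+        # run the wrapped callable while the spinner is displayed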
+ results: Any = streamer(*args, **kwargs)
+ print(task_str, end=end)
+ return results
+ else:
+ print(task_str, end=end)
+ return streamer(*args, **kwargs)
def get_env_var(
diff --git a/tests/test_chat_interact.py b/tests/test_chat_interact.py
index c7f935e4..2379372c 100644
--- a/tests/test_chat_interact.py
+++ b/tests/test_chat_interact.py
@@ -16,6 +16,17 @@ def test_chat_cli():
result = runner.invoke(cli, ["chat"], input="What's up dock?\nexit\n")
term_stdout_lines: list[str] = result.stdout.split("\n")
assert term_stdout_lines[0] == ">>> "
+ assert term_stdout_lines[1] == "Reginald: Hello! How are you?"
+ assert term_stdout_lines[2] == ">>> "
+
+
+def test_chat_cli_no_stream():
+ """Test sending an input `str` via `cli` and then exiting."""
+ result = runner.invoke(
+ cli, ["chat", "--no-streaming"], input="What's up dock?\nexit\n"
+ )
+ term_stdout_lines: list[str] = result.stdout.split("\n")
+ assert term_stdout_lines[0] == ">>> "
assert term_stdout_lines[1] == "Reginald: Let's discuss this in a channel!"
assert term_stdout_lines[2] == ">>> "
@@ -24,3 +35,37 @@ def test_chat_interact_exit():
with mock.patch.object(builtins, "input", lambda _: "exit"):
interaction = run_chat_interact(model="hello")
assert isinstance(interaction, Hello)
+
+
+def test_chat_interact_exit_with_bracket():
+ with mock.patch.object(builtins, "input", lambda _: "exit()"):
+ interaction = run_chat_interact(model="hello")
+ assert isinstance(interaction, Hello)
+
+
+def test_chat_interact_quit_with_bracket():
+ with mock.patch.object(builtins, "input", lambda _: "quit()"):
+ interaction = run_chat_interact(model="hello")
+ assert isinstance(interaction, Hello)
+
+
+def test_chat_interact_bye():
+ with mock.patch.object(builtins, "input", lambda _: "bye Reginald"):
+ interaction = run_chat_interact(model="hello")
+ assert isinstance(interaction, Hello)
+
+
+def test_chat_interact_clear_history():
+ result = runner.invoke(cli, ["chat"], input="clear_history\n")
+ term_stdout_lines: list[str] = result.stdout.split("\n")
+ assert term_stdout_lines[0] == ">>> "
+ assert term_stdout_lines[1] == "Reginald: No history to clear."
+ assert term_stdout_lines[2] == ">>> "
+
+
+def test_chat_interact_slash_clear_history():
+ result = runner.invoke(cli, ["chat"], input="\clear_history\n")
+ term_stdout_lines: list[str] = result.stdout.split("\n")
+ assert term_stdout_lines[0] == ">>> "
+ assert term_stdout_lines[1] == "Reginald: No history to clear."
+ assert term_stdout_lines[2] == ">>> "