refactor so models and modules loaded only when necessary
rchan26 committed Jun 12, 2024
1 parent dc9a7fe commit 0d2737b
Showing 33 changed files with 1,939 additions and 1,627 deletions.
52 changes: 48 additions & 4 deletions reginald/cli.py
@@ -4,10 +4,14 @@

import typer

from reginald.models.setup_llm import DEFAULT_ARGS
from reginald.run import EMOJI_DEFAULT, main
from reginald.defaults import DEFAULT_ARGS, EMOJI_DEFAULT
from reginald.run import main

API_URL_PROPMPT = "No API URL was provided and REGINALD_API_URL not set. Please provide an API URL for the Reginald app"
PROMPTS = {
"api_url": "No API URL was provided and REGINALD_API_URL not set. Please provide an API URL for the Reginald app",
"slack_app_token": "No Slack app token was provided and SLACK_APP_TOKEN not set. Please provide a Slack app token for the bot",
"slack_bot_token": "No Slack bot token was provided and SLACK_BOT_TOKEN not set. Please provide a Slack bot token for the bot",
}
HELP_TEXT = {
"model": "Select which type of model to use..",
"model_name": "Select which sub-model to use (within the main model selected).",
@@ -26,6 +30,8 @@
"api_url": "API URL for the Reginald app.",
"emoji": "Emoji to use for the bot.",
"streaming": "Whether to use streaming for the chat interaction.",
"slack_app_token": "Slack app token for the bot.",
"slack_bot_token": "Slack bot token for the bot.",
}

cli = typer.Typer()
@@ -41,6 +47,22 @@ def set_up_logging_config(level: int = 20) -> None:

@cli.command()
def run_all(
slack_app_token: Annotated[
str,
typer.Option(
prompt=PROMPTS["slack_app_token"],
envvar="SLACK_APP_TOKEN",
help=HELP_TEXT["slack_app_token"],
),
],
slack_bot_token: Annotated[
str,
typer.Option(
prompt=PROMPTS["slack_bot_token"],
envvar="SLACK_BOT_TOKEN",
help=HELP_TEXT["slack_bot_token"],
),
],
model: Annotated[
str,
typer.Option(
@@ -111,6 +133,8 @@ def run_all(
set_up_logging_config(level=20)
main(
cli="run_all",
slack_app_token=slack_app_token,
slack_bot_token=slack_bot_token,
model=model,
model_name=model_name,
mode=mode,
@@ -130,10 +154,28 @@ def run_all(

@cli.command()
def bot(
slack_app_token: Annotated[
str,
typer.Option(
prompt=PROMPTS["slack_app_token"],
envvar="SLACK_APP_TOKEN",
help=HELP_TEXT["slack_app_token"],
),
],
slack_bot_token: Annotated[
str,
typer.Option(
prompt=PROMPTS["slack_bot_token"],
envvar="SLACK_BOT_TOKEN",
help=HELP_TEXT["slack_bot_token"],
),
],
api_url: Annotated[
str,
typer.Option(
prompt=API_URL_PROPMPT, envvar="REGINALD_API_URL", help=HELP_TEXT["api_url"]
prompt=PROMPTS["api_url"],
envvar="REGINALD_API_URL",
help=HELP_TEXT["api_url"],
),
],
emoji: Annotated[
@@ -149,6 +191,8 @@ def bot(
set_up_logging_config(level=20)
main(
cli="bot",
slack_app_token=slack_app_token,
slack_bot_token=slack_bot_token,
api_url=api_url,
emoji=emoji,
)
22 changes: 22 additions & 0 deletions reginald/defaults.py
@@ -0,0 +1,22 @@
import pathlib
from typing import Final

EMOJI_DEFAULT: Final[str] = "rocket"

LLAMA_INDEX_DIR: Final[str] = "llama_index_indices"

DEFAULT_ARGS = {
"model": "hello",
"mode": "chat",
"data_dir": pathlib.Path(__file__).parent.parent / "data",
"which_index": "reg",
"force_new_index": False,
"max_input_size": 4096,
"k": 3,
"chunk_size": 512,
"chunk_overlap_ratio": 0.1,
"num_output": 512,
"is_path": False,
"n_gpu_layers": 0,
"device": "auto",
}
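
Not part of the commit, but for context: a minimal sketch of how a defaults dict like this is typically merged with caller-supplied options. The user_args dict below is hypothetical.

from reginald.defaults import DEFAULT_ARGS

# hypothetical caller-supplied options; anything unset falls back to the defaults
user_args = {"model": "llama-index-ollama", "mode": "query"}
resolved = {**DEFAULT_ARGS, **user_args}

print(resolved["model"])   # "llama-index-ollama" (overridden by the caller)
print(resolved["device"])  # "auto" (falls back to DEFAULT_ARGS)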
75 changes: 75 additions & 0 deletions reginald/models/__init__.py
@@ -0,0 +1,75 @@
class ModelMapper:
@staticmethod
def available_models():
return [
"chat-completion-azure",
"chat-completion-openai",
"hello",
"llama-index-ollama",
"llama-index-llama-cpp",
"llama-index-hf",
"llama-index-gpt-azure",
"llama-index-gpt-openai",
]

@staticmethod
def get_model(model_name: str):
match model_name:
case "chat-completion-azure":
from reginald.models.simple.chat_completion import ChatCompletionAzure

return ChatCompletionAzure
case "chat-completion-openai":
from reginald.models.simple.chat_completion import ChatCompletionOpenAI

return ChatCompletionOpenAI
case "hello":
from reginald.models.simple.hello import Hello

return Hello
case "llama-index-ollama":
from reginald.models.llama_index.llama_index_ollama import (
LlamaIndexOllama,
)

return LlamaIndexOllama
case "llama-index-llama-cpp":
from reginald.models.llama_index.llama_index_llama_cpp import (
LlamaIndexLlamaCPP,
)

return LlamaIndexLlamaCPP
case "llama-index-hf":
from reginald.models.llama_index.llama_index_hf import LlamaIndexHF

return LlamaIndexHF
case "llama-index-gpt-azure":
from reginald.models.llama_index.llama_index_openai import (
LlamaIndexGPTOpenAI,
)

return LlamaIndexGPTOpenAI
case "llama-index-gpt-openai":
from reginald.models.llama_index.llama_index_openai import (
LlamaIndexGPTOpenAI,
)

return LlamaIndexGPTOpenAI
case _:
raise ValueError(
f"Model {model_name} not found. Available models: {ModelMapper.available_models()}"
)


DEFAULTS = {
"chat-completion-azure": "reginald-gpt4",
"chat-completion-openai": "gpt-3.5-turbo",
"hello": None,
"llama-index-ollama": "llama3",
"llama-index-llama-cpp": "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGUF/resolve/main/llama-2-13b-chat.Q6_K.gguf",
"llama-index-hf": "microsoft/phi-1_5",
"llama-index-gpt-azure": "reginald-gpt4",
"llama-index-gpt-openai": "gpt-3.5-turbo",
}

__all__ = ["MODELS", "DEFAULTS"]
10 changes: 10 additions & 0 deletions reginald/models/app.py
@@ -1,6 +1,9 @@
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel

from reginald.models.setup_llm import setup_llm


class Query(BaseModel):
message: str
@@ -38,3 +41,10 @@ async def channel_mention(query: Query):
return response

return app


async def run_reginald_app(**kwargs) -> None:
# set up response model
response_model = setup_llm(**kwargs)
app: FastAPI = create_reginald_app(response_model)
uvicorn.run(app, host="0.0.0.0", port=8000)
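
Not part of the diff — a minimal usage sketch; the keyword arguments are forwarded to setup_llm, and the values here are illustrative only:

import asyncio

from reginald.models.app import run_reginald_app

# starts a FastAPI server on 0.0.0.0:8000 backed by the chosen response model
asyncio.run(run_reginald_app(model="hello"))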
File renamed without changes.
44 changes: 44 additions & 0 deletions reginald/models/chat_interact.py
@@ -0,0 +1,44 @@
import readline

from reginald.models.base import ResponseModel
from reginald.models.setup_llm import setup_llm

art = """
(`-') (`-') _ _ <-. (`-')_(`-') _ _(`-')
<-.(OO ) ( OO).-/ .-> (_) \( OO) (OO ).-/ <-. ( (OO ).->
,------,(,------.,---(`-'),-(`-',--./ ,--// ,---. ,--. ) \ .'_
| /`. '| .---' .-(OO )| ( OO| \ | || \ /`.\ | (`-''`'-..__)
| |_.' (| '--.| | .-, \| | | . '| |'-'|_.' || |OO | | ' |
| . .'| .--'| | '.(_(| |_/| |\ (| .-. (| '__ | | / :
| |\ \ | `---| '-' | | |'-| | \ || | | || || '-' /
`--' '--'`------'`-----' `--' `--' `--'`--' `--'`-----'`------'
"""


def run_chat_interact(streaming: bool = False, **kwargs) -> ResponseModel:
# set up response model
response_model = setup_llm(**kwargs)
user_id = "command_line_chat"
print(art)

while True:
message = input(">>> ")
if message in ["exit", "exit()", "quit()", "bye Reginald"]:
return response_model
if message in ["clear_history", "\clear_history"]:
if (
response_model.mode == "chat"
and response_model.chat_engine.get(user_id) is not None
):
response_model.chat_engine[user_id].reset()
print("\nReginald: History cleared.")
else:
print("\nReginald: No history to clear.")
continue

if streaming:
response = response_model.stream_message(message=message, user_id=user_id)
print("")
else:
response = response_model.direct_message(message=message, user_id=user_id)
print(f"\nReginald: {response.message}")
6 changes: 3 additions & 3 deletions reginald/models/create_index.py
@@ -10,12 +10,12 @@
from llama_index.core.llms.callbacks import llm_completion_callback
from llama_index.core.llms.custom import CustomLLM

from reginald.models.models.llama_index import (
from reginald.defaults import DEFAULT_ARGS
from reginald.models.llama_index.base import (
DataIndexCreator,
compute_default_chunk_size,
setup_settings,
)
from reginald.models.setup_llm import DEFAULT_ARGS
from reginald.models.llama_index.llama_utils import setup_settings


class DummyLLM(CustomLLM):
65 changes: 65 additions & 0 deletions reginald/models/download_from_fileshare.py
@@ -0,0 +1,65 @@
import logging
import pathlib
import sys

from reginald.utils import create_folder


def download_from_fileshare(
data_dir: pathlib.Path | str,
which_index: str,
azure_storage_key: str | None,
connection_str: str | None,
) -> None:
from azure.storage.fileshare import ShareClient
from tqdm import tqdm

if azure_storage_key is None:
logging.error("azure_storage_key is not set.")
sys.exit(1)
if connection_str is None:
logging.error("connection_str is not set.")
sys.exit(1)

# set the file share name and directory
file_share_name = "llama-data"
file_share_directory = f"llama_index_indices/{which_index}"

# create a ShareClient object
share_client = ShareClient.from_connection_string(
conn_str=connection_str,
share_name=file_share_name,
credential=azure_storage_key,
)

# get a reference to the file share directory
file_share_directory_client = share_client.get_directory_client(
file_share_directory
)

# set the local download directory
local_download_directory = (
pathlib.Path(data_dir) / "llama_index_indices" / which_index
)

    # create folder if it does not exist
create_folder(local_download_directory)

# list all the files in the directory
files_list = file_share_directory_client.list_directories_and_files()

# check if the index exists
try:
files_list = list(files_list)
    except Exception:
logging.error(f"Index {which_index} does not exist in the file share")
sys.exit(1)

# iterate through each file in the list and download it
for file in tqdm(files_list):
if not file.is_directory:
file_client = file_share_directory_client.get_file_client(file.name)
download_path = local_download_directory / file.name
with open(download_path, "wb") as file_handle:
data = file_client.download_file()
data.readinto(file_handle)
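
Not part of the diff — a minimal usage sketch; the environment variable names below are assumptions for illustration, not names the commit defines:

import os

from reginald.models.download_from_fileshare import download_from_fileshare

# both credentials must be set, otherwise the function logs an error and exits
download_from_fileshare(
    data_dir="data",
    which_index="reg",
    azure_storage_key=os.getenv("AZURE_STORAGE_KEY"),  # assumed variable name
    connection_str=os.getenv("AZURE_STORAGE_CONNECTION_STR"),  # assumed variable name
)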
Empty file.
(diff truncated: remaining changed files not shown)
