refactor so models and modules are loaded only when necessary
Showing 33 changed files with 1,939 additions and 1,627 deletions.
@@ -0,0 +1,22 @@
import pathlib
from typing import Final

EMOJI_DEFAULT: Final[str] = "rocket"

LLAMA_INDEX_DIR: Final[str] = "llama_index_indices"

DEFAULT_ARGS = {
    "model": "hello",
    "mode": "chat",
    "data_dir": pathlib.Path(__file__).parent.parent / "data",
    "which_index": "reg",
    "force_new_index": False,
    "max_input_size": 4096,
    "k": 3,
    "chunk_size": 512,
    "chunk_overlap_ratio": 0.1,
    "num_output": 512,
    "is_path": False,
    "n_gpu_layers": 0,
    "device": "auto",
}
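These defaults are plain module-level constants, so callers can overlay their own options on top of them. A minimal sketch of such a merge, using a hypothetical resolve_args helper that is not part of this commit:

# Illustrative only: overlay caller-supplied options onto DEFAULT_ARGS.
def resolve_args(**overrides) -> dict:
    args = dict(DEFAULT_ARGS)
    args.update({k: v for k, v in overrides.items() if v is not None})
    return args

# e.g. resolve_args(model="llama-index-ollama", k=5) keeps every other default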
@@ -0,0 +1,75 @@
class ModelMapper:
    @staticmethod
    def available_models():
        return [
            "chat-completion-azure",
            "chat-completion-openai",
            "hello",
            "llama-index-ollama",
            "llama-index-llama-cpp",
            "llama-index-hf",
            "llama-index-gpt-azure",
            "llama-index-gpt-openai",
        ]

    @staticmethod
    def get_model(model_name: str):
        match model_name:
            case "chat-completion-azure":
                from reginald.models.simple.chat_completion import ChatCompletionAzure

                return ChatCompletionAzure
            case "chat-completion-openai":
                from reginald.models.simple.chat_completion import ChatCompletionOpenAI

                return ChatCompletionOpenAI
            case "hello":
                from reginald.models.simple.hello import Hello

                return Hello
            case "llama-index-ollama":
                from reginald.models.llama_index.llama_index_ollama import (
                    LlamaIndexOllama,
                )

                return LlamaIndexOllama
            case "llama-index-llama-cpp":
                from reginald.models.llama_index.llama_index_llama_cpp import (
                    LlamaIndexLlamaCPP,
                )

                return LlamaIndexLlamaCPP
            case "llama-index-hf":
                from reginald.models.llama_index.llama_index_hf import LlamaIndexHF

                return LlamaIndexHF
            case "llama-index-gpt-azure":
                from reginald.models.llama_index.llama_index_openai import (
                    LlamaIndexGPTOpenAI,
                )

                return LlamaIndexGPTOpenAI
            case "llama-index-gpt-openai":
                from reginald.models.llama_index.llama_index_openai import (
                    LlamaIndexGPTOpenAI,
                )

                return LlamaIndexGPTOpenAI
            case _:
                raise ValueError(
                    f"Model {model_name} not found. Available models: {ModelMapper.available_models()}"
                )


DEFAULTS = {
    "chat-completion-azure": "reginald-gpt4",
    "chat-completion-openai": "gpt-3.5-turbo",
    "hello": None,
    "llama-index-ollama": "llama3",
    "llama-index-llama-cpp": "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGUF/resolve/main/llama-2-13b-chat.Q6_K.gguf",
    "llama-index-hf": "microsoft/phi-1_5",
    "llama-index-gpt-azure": "reginald-gpt4",
    "llama-index-gpt-openai": "gpt-3.5-turbo",
}

__all__ = ["ModelMapper", "DEFAULTS"]
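Keeping each import inside its own case arm means only the selected model's dependencies are loaded, which is what the commit message refers to. A minimal usage sketch; the commented-out constructor call is an assumption, since its arguments are not defined in this diff:

# Resolving "hello" imports only reginald.models.simple.hello; the heavy
# llama-index / llama-cpp / transformers modules stay untouched.
model_cls = ModelMapper.get_model("hello")
default_model_name = DEFAULTS["hello"]  # None for the toy Hello model
# model = model_cls(model_name=default_model_name, mode="chat")  # illustrative arguments only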
File renamed without changes.
@@ -0,0 +1,44 @@
import readline  # enables line editing and history for input()

from reginald.models.base import ResponseModel
from reginald.models.setup_llm import setup_llm

art = """
 (`-') (`-') _ _ <-. (`-')_(`-') _ _(`-')
<-.(OO ) ( OO).-/ .-> (_) \( OO) (OO ).-/ <-. ( (OO ).->
,------,(,------.,---(`-'),-(`-',--./ ,--// ,---. ,--. ) \ .'_
| /`. '| .---' .-(OO )| ( OO| \ | || \ /`.\ | (`-''`'-..__)
| |_.' (| '--.| | .-, \| | | . '| |'-'|_.' || |OO | | ' |
| . .'| .--'| | '.(_(| |_/| |\ (| .-. (| '__ | | / :
| |\ \ | `---| '-' | | |'-| | \ || | | || || '-' /
`--' '--'`------'`-----' `--' `--' `--'`--' `--'`-----'`------'
"""


def run_chat_interact(streaming: bool = False, **kwargs) -> ResponseModel:
    # set up response model
    response_model = setup_llm(**kwargs)
    user_id = "command_line_chat"
    print(art)

    while True:
        message = input(">>> ")
        if message in ["exit", "exit()", "quit()", "bye Reginald"]:
            return response_model
        if message in ["clear_history", "\\clear_history"]:
            if (
                response_model.mode == "chat"
                and response_model.chat_engine.get(user_id) is not None
            ):
                response_model.chat_engine[user_id].reset()
                print("\nReginald: History cleared.")
            else:
                print("\nReginald: No history to clear.")
            continue

        if streaming:
            response = response_model.stream_message(message=message, user_id=user_id)
            print("")
        else:
            response = response_model.direct_message(message=message, user_id=user_id)
            print(f"\nReginald: {response.message}")
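A hypothetical way to launch the loop from a script; the keyword arguments mirror keys in DEFAULT_ARGS, though whether the real CLI forwards exactly these is an assumption:

# Illustrative only: start an interactive session with the lightweight "hello"
# model so no LLM backend needs to be downloaded or imported.
if __name__ == "__main__":
    run_chat_interact(streaming=False, model="hello", mode="chat")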
@@ -0,0 +1,65 @@
import logging
import pathlib
import sys

from reginald.utils import create_folder


def download_from_fileshare(
    data_dir: pathlib.Path | str,
    which_index: str,
    azure_storage_key: str | None,
    connection_str: str | None,
) -> None:
    from azure.storage.fileshare import ShareClient
    from tqdm import tqdm

    if azure_storage_key is None:
        logging.error("azure_storage_key is not set.")
        sys.exit(1)
    if connection_str is None:
        logging.error("connection_str is not set.")
        sys.exit(1)

    # set the file share name and directory
    file_share_name = "llama-data"
    file_share_directory = f"llama_index_indices/{which_index}"

    # create a ShareClient object
    share_client = ShareClient.from_connection_string(
        conn_str=connection_str,
        share_name=file_share_name,
        credential=azure_storage_key,
    )

    # get a reference to the file share directory
    file_share_directory_client = share_client.get_directory_client(
        file_share_directory
    )

    # set the local download directory
    local_download_directory = (
        pathlib.Path(data_dir) / "llama_index_indices" / which_index
    )

    # create folder if does not exist
    create_folder(local_download_directory)

    # list all the files in the directory
    files_list = file_share_directory_client.list_directories_and_files()

    # check if the index exists
    try:
        files_list = list(files_list)
    except:
        logging.error(f"Index {which_index} does not exist in the file share")
        sys.exit(1)

    # iterate through each file in the list and download it
    for file in tqdm(files_list):
        if not file.is_directory:
            file_client = file_share_directory_client.get_file_client(file.name)
            download_path = local_download_directory / file.name
            with open(download_path, "wb") as file_handle:
                data = file_client.download_file()
                data.readinto(file_handle)
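A sketch of how the downloader might be called; the environment-variable names below are assumptions chosen for illustration, not defined anywhere in this commit:

import os
import pathlib

# Illustrative only: pull the "reg" index from the Azure file share into ./data.
download_from_fileshare(
    data_dir=pathlib.Path("data"),
    which_index="reg",
    azure_storage_key=os.environ.get("AZURE_STORAGE_KEY"),
    connection_str=os.environ.get("AZURE_STORAGE_CONNECTION_STR"),
)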
Empty file.