From 2c29ffae7c1e51179b62cea768d2d090c7d1c99e Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 2 Aug 2023 15:01:19 -0700
Subject: [PATCH 1/3] adding support for anthropic and cohere

---
 gpyt/__init__.py       |  4 ++
 gpyt/lite_assistant.py | 88 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 92 insertions(+)
 create mode 100644 gpyt/lite_assistant.py

diff --git a/gpyt/__init__.py b/gpyt/__init__.py
index 27ab325..53c93e5 100644
--- a/gpyt/__init__.py
+++ b/gpyt/__init__.py
@@ -13,12 +13,16 @@
 # check for environment variable first
 API_KEY = os.getenv("OPENAI_API_KEY")
 PALM_API_KEY = os.getenv("PALM_API_KEY")
+ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
+COHERE_API_KEY = os.getenv("COHERE_API_KEY")
 DOTENV_PATH = os.path.expanduser("~/.env")
 
 if API_KEY is None or (isinstance(API_KEY, str) and not len(API_KEY)):
     result = dotenv_values(DOTENV_PATH)
     API_KEY = result.get("OPENAI_API_KEY", None)
     PALM_API_KEY = result.get("PALM_API_KEY", None)
+    ANTHROPIC_API_KEY = result.get("ANTHROPIC_API_KEY", None)
+    COHERE_API_KEY = result.get("COHERE_API_KEY", None)
 
 
 assert (
diff --git a/gpyt/lite_assistant.py b/gpyt/lite_assistant.py
new file mode 100644
index 0000000..2abd908
--- /dev/null
+++ b/gpyt/lite_assistant.py
@@ -0,0 +1,88 @@
+from typing import Generator
+import litellm
+from litellm import completion
+import google.generativeai as palm
+import tiktoken
+import os
+from .assistant import Assistant
+from .config import API_ERROR_FALLBACK, PROMPT
+import dotenv
+dotenv.load_dotenv()
+from .config import (
+    API_ERROR_FALLBACK,
+    APPROX_PROMPT_TOKEN_USAGE,
+)
+
+
+class LiteAssistant(Assistant):
+    API_ERROR_MESSAGE = """There was an ERROR with the model API\n## Diagnostics\n* Make sure you have a PALM_API_KEY set at `~/.env`\n* Make sure you aren't being rate-limited."""
+
+    def __init__(self, api_key: str, model: str, memory: bool = True):
+        self.error_fallback_message = API_ERROR_FALLBACK
+        self.chat: list[dict[str, str]] = []
+        self.model = model
+        self.api_key = api_key
+        self.memory = memory
+        self.messages = []
+        self.error_fallback_message = LiteAssistant.API_ERROR_MESSAGE
+        self.input_tokens_this_convo = APPROX_PROMPT_TOKEN_USAGE
+        self.output_tokens_this_convo = 10
+        self.prompt, self.summary_prompt = ("", "")
+
+        self._encoding_engine = tiktoken.get_encoding(
+            "cl100k_base"
+        )
+        self.price_of_this_convo = self.get_default_price_of_prompt()
+
+    def set_history(self, new_history: list[dict[str, str]]):
+        self.clear_history()
+        for message in new_history:
+            self.messages.append(message["content"])
+
+    def clear_history(self) -> None:
+        """reset internal message queue"""
+        self.messages.clear()
+
+    def get_response_stream(self, user_input: str) -> Generator:
+        """Fake a stream output with PaLM 2"""
+        response = self.get_response(user_input)
+        for i in range(0, len(response), 8):
+            yield response[i : i + 8]
+
+    def log_assistant_response(self, final_response: str) -> None:  # pyright: ignore
+        ...
+ + def get_response(self, user_input: str) -> str: + if not self.memory: + self.clear_history() + self.messages.append({"role": "user", "content": user_input}) + + response = completion( # type: ignore + model=self.model, + messages=self.messages, + stream=True, + api_key=self.api_key # optional param, litellm will read it from the os.getenv by default + ) + + return response # type: ignore + + def _bad_key(self) -> bool: + return not self.api_key or (type(self.api_key) is str and len(self.api_key) < 1) + + def get_conversation_summary(self, initial_message: str) -> str: + """Generate a short 6 word or less summary of the user\'s first message""" + try: + summary = self.get_response(initial_message) + except: + return Assistant.kDEFAULT_SUMMARY_FALLTHROUGH + + return summary + +if __name__ == "__main__": + ... + agent = LiteAssistant(model="claude-instant-1") + print(agent.get_conversation_summary("What is the third closest moon from Saturn")) + while 1: + user_inp = input("> ") + msg = agent.get_response(user_input=user_inp) + print("\n", msg) From 70a021452372887bc8b35307d85bea27c8a2dbca Mon Sep 17 00:00:00 2001 From: Justin Stitt Date: Sun, 6 Aug 2023 09:18:42 -0700 Subject: [PATCH 2/3] needs review did some patchwork to get litellm hooked up --- .python-version | 1 + gpyt/__init__.py | 4 +- gpyt/args.py | 10 + gpyt/components/assistant_app.py | 23 ++- gpyt/lite_assistant.py | 25 +-- gpyt/styles.cssx | 10 +- litellm_uuid.txt | 1 + poetry.lock | 324 ++++++++++++++++++++++++++----- pyproject.toml | 1 + 9 files changed, 327 insertions(+), 72 deletions(-) create mode 100644 .python-version create mode 100644 litellm_uuid.txt diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..afad818 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.11.0 diff --git a/gpyt/__init__.py b/gpyt/__init__.py index 53c93e5..3aba86c 100644 --- a/gpyt/__init__.py +++ b/gpyt/__init__.py @@ -4,6 +4,7 @@ from gpyt.free_assistant import FreeAssistant from gpyt.palm_assistant import PalmAssistant +from gpyt.lite_assistant import LiteAssistant from .app import gpyt from .args import USE_EXPERIMENTAL_FREE_MODEL @@ -56,4 +57,5 @@ gpt4 = Assistant(api_key=API_KEY or "", model="gpt-4", prompt=PROMPT) free_gpt = FreeAssistant() palm = PalmAssistant(api_key=PALM_API_KEY) -app = gpyt(assistant=gpt, free_assistant=free_gpt, palm=palm, gpt4=gpt4) +litellm = LiteAssistant(api_key='', model=MODEL) +app = gpyt(assistant=gpt, free_assistant=free_gpt, palm=palm, gpt4=gpt4, lite_assistant = litellm) diff --git a/gpyt/args.py b/gpyt/args.py index edb01a7..19322d9 100644 --- a/gpyt/args.py +++ b/gpyt/args.py @@ -1,9 +1,11 @@ import argparse parser = argparse.ArgumentParser() + parser.add_argument( "--free", help="Use the gpt4free model. (experimental)", action="store_true" ) + parser.add_argument( "--palm", help="Use PaLM 2, Google's Premier LLM (new)", action="store_true" ) @@ -14,9 +16,17 @@ action="store_true", ) +parser.add_argument( + "--litellm", + help="--litellm . 
See https://github.com/BerriAI/litellm", +) + args = parser.parse_args() USE_EXPERIMENTAL_FREE_MODEL = args.free USE_PALM_MODEL = args.palm USE_GPT4 = args.gpt4 +USE_LITELLM = len(args.litellm) > 0 +LITELLM_MODEL = args.litellm if USE_LITELLM else "" +print(f"{USE_LITELLM=}, {LITELLM_MODEL=}") diff --git a/gpyt/components/assistant_app.py b/gpyt/components/assistant_app.py index 847d27c..c318f5d 100644 --- a/gpyt/components/assistant_app.py +++ b/gpyt/components/assistant_app.py @@ -8,9 +8,16 @@ from textual.widgets import Footer, Header, LoadingIndicator from gpyt.free_assistant import FreeAssistant +from gpyt.lite_assistant import LiteAssistant from gpyt.palm_assistant import PalmAssistant -from ..args import USE_EXPERIMENTAL_FREE_MODEL, USE_GPT4, USE_PALM_MODEL +from ..args import ( + USE_EXPERIMENTAL_FREE_MODEL, + USE_GPT4, + USE_PALM_MODEL, + USE_LITELLM, + LITELLM_MODEL, +) from ..assistant import Assistant from ..conversation import Conversation, Message from ..id import get_id @@ -49,30 +56,38 @@ def __init__( free_assistant: FreeAssistant, palm: PalmAssistant, gpt4: Assistant, + lite_assistant: LiteAssistant, ): super().__init__() self.assistant = assistant self._free_assistant = free_assistant self._palm = palm self._gpt4 = gpt4 + self._lite_assistant = lite_assistant self.conversations: list[Conversation] = [] self.active_conversation: Conversation | None = None self._convo_ids_added: set[str] = set() self.use_free_gpt = USE_EXPERIMENTAL_FREE_MODEL self.use_palm = USE_PALM_MODEL self.use_gpt4 = USE_GPT4 + self.use_litellm = USE_LITELLM self.use_default_model = not ( - self.use_free_gpt or self.use_palm or self.use_gpt4 + self.use_free_gpt or self.use_palm or self.use_gpt4 or self.use_litellm ) self.scrolled_during_response_stream = False - def _get_assistant(self) -> Assistant | FreeAssistant | PalmAssistant: + def _get_assistant( + self, + ) -> Assistant | FreeAssistant | PalmAssistant | LiteAssistant: if self.use_palm: return self._palm if self.use_free_gpt: return self._free_assistant if self.use_gpt4: return self._gpt4 + if self.use_litellm: + self._lite_assistant.model = LITELLM_MODEL + return self._lite_assistant return self.assistant @@ -85,6 +100,8 @@ def adjust_model_border_title(self) -> None: model = "PaLM 2 🌴" elif self.use_gpt4: model = "GPT 4" + elif self.use_litellm: + model = f"LITELLM w/ {LITELLM_MODEL}" self.user_input.border_title = f"Model: {model}" diff --git a/gpyt/lite_assistant.py b/gpyt/lite_assistant.py index 2abd908..15fedd2 100644 --- a/gpyt/lite_assistant.py +++ b/gpyt/lite_assistant.py @@ -1,21 +1,18 @@ from typing import Generator -import litellm + +import dotenv from litellm import completion -import google.generativeai as palm import tiktoken -import os + from .assistant import Assistant -from .config import API_ERROR_FALLBACK, PROMPT -import dotenv +from .config import API_ERROR_FALLBACK +from .config import API_ERROR_FALLBACK, APPROX_PROMPT_TOKEN_USAGE + dotenv.load_dotenv() -from .config import ( - API_ERROR_FALLBACK, - APPROX_PROMPT_TOKEN_USAGE, -) class LiteAssistant(Assistant): - API_ERROR_MESSAGE = """There was an ERROR with the model API\n## Diagnostics\n* Make sure you have a PALM_API_KEY set at `~/.env`\n* Make sure you aren't being rate-limited.""" + API_ERROR_MESSAGE = """There was an ERROR with the model API\n## Diagnostics\n* Make sure you have an API_KEY set at `~/.env`\n* Make sure you aren't being rate-limited.""" def __init__(self, api_key: str, model: str, memory: bool = True): self.error_fallback_message = API_ERROR_FALLBACK 
@@ -29,9 +26,7 @@ def __init__(self, api_key: str, model: str, memory: bool = True): self.output_tokens_this_convo = 10 self.prompt, self.summary_prompt = ("", "") - self._encoding_engine = tiktoken.get_encoding( - "cl100k_base" - ) + self._encoding_engine = tiktoken.get_encoding("cl100k_base") self.price_of_this_convo = self.get_default_price_of_prompt() def set_history(self, new_history: list[dict[str, str]]): @@ -61,7 +56,6 @@ def get_response(self, user_input: str) -> str: model=self.model, messages=self.messages, stream=True, - api_key=self.api_key # optional param, litellm will read it from the os.getenv by default ) return response # type: ignore @@ -75,9 +69,10 @@ def get_conversation_summary(self, initial_message: str) -> str: summary = self.get_response(initial_message) except: return Assistant.kDEFAULT_SUMMARY_FALLTHROUGH - + return summary + if __name__ == "__main__": ... agent = LiteAssistant(model="claude-instant-1") diff --git a/gpyt/styles.cssx b/gpyt/styles.cssx index f74dad3..e587415 100644 --- a/gpyt/styles.cssx +++ b/gpyt/styles.cssx @@ -109,10 +109,8 @@ Options.hidden { offset-x: -100%; } -OptionCheckbox { - height: 1; -} - -OptionCheckbox:focus-within { - background: $secondary-background-lighten-2; +RadioButton { + padding: 0; + height: 3; + margin: 0; } diff --git a/litellm_uuid.txt b/litellm_uuid.txt new file mode 100644 index 0000000..8ba5819 --- /dev/null +++ b/litellm_uuid.txt @@ -0,0 +1 @@ +720218e6-e9f5-4c0d-9024-57ed6266b938 \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 95d794d..02bb29b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -147,6 +147,26 @@ toolz = "*" [package.extras] dev = ["black", "docutils", "flake8", "ipython", "m2r", "mistune (<2.0.0)", "pytest", "recommonmark", "sphinx", "vega-datasets"] +[[package]] +name = "anthropic" +version = "0.3.8" +description = "Client library for the anthropic API" +category = "main" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "anthropic-0.3.8-py3-none-any.whl", hash = "sha256:97ffe1bacc4214dc89b19f496cf2769746971e86f7c835a05aa21b76f260d279"}, + {file = "anthropic-0.3.8.tar.gz", hash = "sha256:6651099807456c3b95b3879f5ad7d00f7e7e4f7649a2394d18032ab8be54ef16"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<4" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<2.0.0" +tokenizers = ">=0.13.0" +typing-extensions = ">=4.1.1,<5" + [[package]] name = "anyio" version = "3.7.1" @@ -199,6 +219,18 @@ docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib- tests = ["attrs[tests-no-zope]", "zope-interface"] tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +[[package]] +name = "backoff" +version = "2.2.1" +description = "Function decoration for backoff and retry" +category = "main" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8"}, + {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, +] + [[package]] name = "beautifulsoup4" version = "4.12.2" @@ -431,6 +463,26 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} +[[package]] +name = "cohere" +version = "4.19.2" +description = "" +category = "main" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = 
"cohere-4.19.2-py3-none-any.whl", hash = "sha256:0b6a4fe04380a481a8e975ebcc9bb6433febe4d3eb583b6d6e04342a5e998345"}, + {file = "cohere-4.19.2.tar.gz", hash = "sha256:a0b0fa698b3d3983fb328bb90d68fcf08faaa2268f3772ebc6bfea6ba55acf27"}, +] + +[package.dependencies] +aiohttp = ">=3.0,<4.0" +backoff = ">=2.0,<3.0" +fastavro = {version = "1.8.2", markers = "python_version >= \"3.8\""} +importlib_metadata = ">=6.0,<7.0" +requests = ">=2.25.0,<3.0.0" +urllib3 = ">=1.26,<3" + [[package]] name = "colorama" version = "0.4.6" @@ -479,6 +531,18 @@ files = [ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, ] +[[package]] +name = "distro" +version = "1.8.0" +description = "Distro - an OS platform information API" +category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "distro-1.8.0-py3-none-any.whl", hash = "sha256:99522ca3e365cac527b44bde033f64c6945d90eb9f769703caaec52b09bbd3ff"}, + {file = "distro-1.8.0.tar.gz", hash = "sha256:02e111d1dc6a50abb8eed6bf31c3e48ed8b0830d1ea2a1b78c61765c2513fdd8"}, +] + [[package]] name = "entrypoints" version = "0.4" @@ -491,6 +555,18 @@ files = [ {file = "entrypoints-0.4.tar.gz", hash = "sha256:b706eddaa9218a19ebcd67b56818f05bb27589b1ca9e8d797b74affad4ccacd4"}, ] +[[package]] +name = "et-xmlfile" +version = "1.1.0" +description = "An implementation of lxml.xmlfile for the standard library" +category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "et_xmlfile-1.1.0-py3-none-any.whl", hash = "sha256:a2ba85d1d6a74ef63837eed693bcb89c3f752169b0e3e7ae5b16ca5e1b3deada"}, + {file = "et_xmlfile-1.1.0.tar.gz", hash = "sha256:8eb9e2bc2f8c97e37a2dc85a09ecdcdec9d8a396530a6d5a33b30b9a92da0c5c"}, +] + [[package]] name = "exceptiongroup" version = "1.1.2" @@ -518,6 +594,47 @@ files = [ {file = "fake_useragent-1.1.3-py3-none-any.whl", hash = "sha256:695d3b1bf7d11d04ab0f971fb73b0ca8de98b78bbadfbc8bacbc9a48423f7531"}, ] +[[package]] +name = "fastavro" +version = "1.8.2" +description = "Fast read/write of AVRO files" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fastavro-1.8.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:0e08964b2e9a455d831f2557402a683d4c4d45206f2ab9ade7c69d3dc14e0e58"}, + {file = "fastavro-1.8.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:401a70b1e5c7161420c6019e0c8afa88f7c8a373468591f5ec37639a903c2509"}, + {file = "fastavro-1.8.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eef1ed3eaa4240c05698d02d8d0c010b9a03780eda37b492da6cd4c9d37e04ec"}, + {file = "fastavro-1.8.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:543185a672ff6306beb329b57a7b8a3a2dd1eb21a5ccc530150623d58d48bb98"}, + {file = "fastavro-1.8.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ffbf8bae1edb50fe7beeffc3afa8e684686550c2e5d31bf01c25cfa213f581e1"}, + {file = "fastavro-1.8.2-cp310-cp310-win_amd64.whl", hash = "sha256:bb545eb9d876bc7b785e27e98e7720ada7eee7d7a1729798d2ed51517f13500a"}, + {file = "fastavro-1.8.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:2b837d3038c651046252bc92c1b9899bf21c7927a148a1ff89599c36c2a331ca"}, + {file = "fastavro-1.8.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3510e96c0a47e4e914bd1a29c954eb662bfa24849ad92e597cb97cc79f21af7"}, + {file = "fastavro-1.8.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:ccc0e74f2c2ab357f39bb73d67fcdb6dc10e23fdbbd399326139f72ec0fb99a3"}, + {file = "fastavro-1.8.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:add51c70d0ab1175601c75cd687bbe9d16ae312cd8899b907aafe0d79ee2bc1d"}, + {file = "fastavro-1.8.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d9e2662f57e6453e9a2c9fb4f54b2a9e62e3e46f5a412ac00558112336d23883"}, + {file = "fastavro-1.8.2-cp311-cp311-win_amd64.whl", hash = "sha256:fea75cf53a93c56dd56e68abce8d314ef877b27451c870cd7ede7582d34c08a7"}, + {file = "fastavro-1.8.2-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:f489020bb8664c2737c03457ad5dbd490579ddab6f0a7b5c17fecfe982715a89"}, + {file = "fastavro-1.8.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a547625c138efd5e61300119241041906ee8cb426fc7aa789900f87af7ed330d"}, + {file = "fastavro-1.8.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53beb458f30c9ad4aa7bff4a42243ff990ffb713b6ce0cd9b360cbc3d648fe52"}, + {file = "fastavro-1.8.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:7b1b2cbd2dd851452306beed0ab9bdaeeab1cc8ad46f84b47cd81eeaff6dd6b8"}, + {file = "fastavro-1.8.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d29e9baee0b2f37ecd09bde3b487cf900431fd548c85be3e4fe1b9a0b2a917f1"}, + {file = "fastavro-1.8.2-cp38-cp38-win_amd64.whl", hash = "sha256:66e132c710663230292bc63e2cb79cf95b16ccb94a5fc99bb63694b24e312fc5"}, + {file = "fastavro-1.8.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:38aca63ce604039bcdf2edd14912d00287bdbf8b76f9aa42b28e6ca0bf950092"}, + {file = "fastavro-1.8.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9787835f6449ee94713e7993a700432fce3763024791ffa8a58dc91ef9d1f950"}, + {file = "fastavro-1.8.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:536cb448bc83811056be02749fd9df37a69621678f02597d272970a769e9b40c"}, + {file = "fastavro-1.8.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:e9d5027cf7d9968f8f819958b41bfedb933323ea6d6a0485eefacaa1afd91f54"}, + {file = "fastavro-1.8.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:792adfc0c80c7f1109e0ab4b0decef20691fdf0a45091d397a0563872eb56d42"}, + {file = "fastavro-1.8.2-cp39-cp39-win_amd64.whl", hash = "sha256:650b22766259f7dd7519dfa4e4658f0e233c319efa130b9cf0c36a500e09cc57"}, + {file = "fastavro-1.8.2.tar.gz", hash = "sha256:ab9d9226d4b66b6b3d0661a57cd45259b0868fed1c0cd4fac95249b9e0973320"}, +] + +[package.extras] +codecs = ["lz4", "python-snappy", "zstandard"] +lz4 = ["lz4"] +snappy = ["python-snappy"] +zstandard = ["zstandard"] + [[package]] name = "frozenlist" version = "1.4.0" @@ -742,32 +859,6 @@ pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"] reauth = ["pyu2f (>=0.1.5)"] requests = ["requests (>=2.20.0,<3.0.0dev)"] -[[package]] -name = "google-auth" -version = "2.22.0" -description = "Google Authentication Library" -category = "main" -optional = false -python-versions = ">=3.6" -files = [ - {file = "google-auth-2.22.0.tar.gz", hash = "sha256:164cba9af4e6e4e40c3a4f90a1a6c12ee56f14c0b4868d1ca91b32826ab334ce"}, - {file = "google_auth-2.22.0-py2.py3-none-any.whl", hash = "sha256:d61d1b40897407b574da67da1a833bdc10d5a11642566e506565d1b1a46ba873"}, -] - -[package.dependencies] -cachetools = ">=2.0.0,<6.0" -pyasn1-modules = ">=0.2.1" -rsa = ">=3.1.4,<5" -six = ">=1.9.0" -urllib3 = "<2.0" - -[package.extras] -aiohttp = ["aiohttp (>=3.6.2,<4.0.0.dev0)", "requests (>=2.20.0,<3.0.0.dev0)"] -enterprise-cert = ["cryptography (==36.0.2)", "pyopenssl (==22.0.0)"] 
-pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"] -reauth = ["pyu2f (>=0.1.5)"] -requests = ["requests (>=2.20.0,<3.0.0.dev0)"] - [[package]] name = "google-generativeai" version = "0.1.0" @@ -1084,6 +1175,28 @@ dev = ["black", "flake8", "isort", "pre-commit", "pyproject-flake8"] doc = ["myst-parser", "sphinx", "sphinx-book-theme"] test = ["coverage", "pytest", "pytest-cov"] +[[package]] +name = "litellm" +version = "0.1.351" +description = "Library to easily interface with LLM API providers" +category = "main" +optional = false +python-versions = ">=3.8,<4.0" +files = [ + {file = "litellm-0.1.351-py3-none-any.whl", hash = "sha256:92669c9851765d13e898b2f6286eb40e04fde3d485bd403d8a766ec310926b2b"}, + {file = "litellm-0.1.351.tar.gz", hash = "sha256:637cc69aae5de8cb626c52150d49f5424612146d4a817a3d5013721fbf3bf24d"}, +] + +[package.dependencies] +anthropic = ">=0.3.7,<0.4.0" +cohere = ">=4.18.0,<5.0.0" +openai = {version = ">=0.27.8,<0.28.0", extras = ["datalib"]} +pytest = ">=7.4.0,<8.0.0" +python-dotenv = ">=1.0.0,<2.0.0" +replicate = ">=0.10.0,<0.11.0" +tenacity = ">=8.0.1,<9.0.0" +tiktoken = ">=0.4.0,<0.5.0" + [[package]] name = "markdown-it-py" version = "3.0.0" @@ -1347,6 +1460,10 @@ files = [ [package.dependencies] aiohttp = "*" +numpy = {version = "*", optional = true, markers = "extra == \"datalib\""} +openpyxl = {version = ">=3.0.7", optional = true, markers = "extra == \"datalib\""} +pandas = {version = ">=1.2.3", optional = true, markers = "extra == \"datalib\""} +pandas-stubs = {version = ">=1.1.0.11", optional = true, markers = "extra == \"datalib\""} requests = ">=2.20" tqdm = "*" @@ -1372,6 +1489,21 @@ files = [ httpx = "*" pytest = "*" +[[package]] +name = "openpyxl" +version = "3.1.2" +description = "A Python library to read/write Excel 2010 xlsx/xlsm files" +category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "openpyxl-3.1.2-py2.py3-none-any.whl", hash = "sha256:f91456ead12ab3c6c2e9491cf33ba6d08357d802192379bb482f1033ade496f5"}, + {file = "openpyxl-3.1.2.tar.gz", hash = "sha256:a6f5977418eff3b2d5500d54d9db50c8277a368436f4e4f8ddb1be3422870184"}, +] + +[package.dependencies] +et-xmlfile = "*" + [[package]] name = "outcome" version = "1.2.0" @@ -1447,6 +1579,22 @@ pytz = ">=2020.1" [package.extras] test = ["hypothesis (>=5.5.3)", "pytest (>=6.0)", "pytest-xdist (>=1.31)"] +[[package]] +name = "pandas-stubs" +version = "2.0.2.230605" +description = "Type annotations for pandas" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pandas_stubs-2.0.2.230605-py3-none-any.whl", hash = "sha256:39106b602f3cb6dc5f728b84e1b32bde6ecf41ee34ee714c66228009609fbada"}, + {file = "pandas_stubs-2.0.2.230605.tar.gz", hash = "sha256:624c7bb06d38145a44b61be459ccd19b038e0bf20364a025ecaab78fea65e858"}, +] + +[package.dependencies] +numpy = ">=1.24.3" +types-pytz = ">=2022.1.1" + [[package]] name = "pillow" version = "10.0.0" @@ -2020,6 +2168,26 @@ files = [ {file = "regex-2023.6.3.tar.gz", hash = "sha256:72d1a25bf36d2050ceb35b517afe13864865268dfb45910e2e17a84be6cbfeb0"}, ] +[[package]] +name = "replicate" +version = "0.10.0" +description = "Python client for Replicate" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "replicate-0.10.0-py3-none-any.whl", hash = "sha256:71464d62259b22a17f5383e83ebaf158cace69ba0153a9de21061283ac2c3a4b"}, + {file = "replicate-0.10.0.tar.gz", hash = "sha256:f714944be0c65eef7b3ebf740eb132d573a52e811a02f0031b9f1ae077a50696"}, +] + +[package.dependencies] 
+packaging = "*" +pydantic = ">1" +requests = ">2" + +[package.extras] +dev = ["black", "mypy", "pytest", "responses", "ruff"] + [[package]] name = "requests" version = "2.31.0" @@ -2312,6 +2480,21 @@ watchdog = {version = "*", markers = "platform_system != \"Darwin\""} [package.extras] snowflake = ["snowflake-snowpark-python"] +[[package]] +name = "tenacity" +version = "8.2.2" +description = "Retry code until it succeeds" +category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "tenacity-8.2.2-py3-none-any.whl", hash = "sha256:2f277afb21b851637e8f52e6a613ff08734c347dc19ade928e519d7d2d8569b0"}, + {file = "tenacity-8.2.2.tar.gz", hash = "sha256:43af037822bd0029025877f3b2d97cc4d7bb0c2991000a3d59d71517c5c969e0"}, +] + +[package.extras] +doc = ["reno", "sphinx", "tornado (>=4.5)"] + [[package]] name = "textual" version = "0.29.0" @@ -2388,6 +2571,61 @@ files = [ {file = "tls_client-0.2.1.tar.gz", hash = "sha256:473fb4c671d9d4ca6b818548ab6e955640dd589767bfce520830c5618c2f2e2b"}, ] +[[package]] +name = "tokenizers" +version = "0.13.3" +description = "Fast and Customizable Tokenizers" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "tokenizers-0.13.3-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:f3835c5be51de8c0a092058a4d4380cb9244fb34681fd0a295fbf0a52a5fdf33"}, + {file = "tokenizers-0.13.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4ef4c3e821730f2692489e926b184321e887f34fb8a6b80b8096b966ba663d07"}, + {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5fd1a6a25353e9aa762e2aae5a1e63883cad9f4e997c447ec39d071020459bc"}, + {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee0b1b311d65beab83d7a41c56a1e46ab732a9eed4460648e8eb0bd69fc2d059"}, + {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ef4215284df1277dadbcc5e17d4882bda19f770d02348e73523f7e7d8b8d396"}, + {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4d53976079cff8a033f778fb9adca2d9d69d009c02fa2d71a878b5f3963ed30"}, + {file = "tokenizers-0.13.3-cp310-cp310-win32.whl", hash = "sha256:1f0e3b4c2ea2cd13238ce43548959c118069db7579e5d40ec270ad77da5833ce"}, + {file = "tokenizers-0.13.3-cp310-cp310-win_amd64.whl", hash = "sha256:89649c00d0d7211e8186f7a75dfa1db6996f65edce4b84821817eadcc2d3c79e"}, + {file = "tokenizers-0.13.3-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:56b726e0d2bbc9243872b0144515ba684af5b8d8cd112fb83ee1365e26ec74c8"}, + {file = "tokenizers-0.13.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:cc5c022ce692e1f499d745af293ab9ee6f5d92538ed2faf73f9708c89ee59ce6"}, + {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f55c981ac44ba87c93e847c333e58c12abcbb377a0c2f2ef96e1a266e4184ff2"}, + {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f247eae99800ef821a91f47c5280e9e9afaeed9980fc444208d5aa6ba69ff148"}, + {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b3e3215d048e94f40f1c95802e45dcc37c5b05eb46280fc2ccc8cd351bff839"}, + {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ba2b0bf01777c9b9bc94b53764d6684554ce98551fec496f71bc5be3a03e98b"}, + {file = "tokenizers-0.13.3-cp311-cp311-win32.whl", hash = 
"sha256:cc78d77f597d1c458bf0ea7c2a64b6aa06941c7a99cb135b5969b0278824d808"}, + {file = "tokenizers-0.13.3-cp311-cp311-win_amd64.whl", hash = "sha256:ecf182bf59bd541a8876deccf0360f5ae60496fd50b58510048020751cf1724c"}, + {file = "tokenizers-0.13.3-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:0527dc5436a1f6bf2c0327da3145687d3bcfbeab91fed8458920093de3901b44"}, + {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07cbb2c307627dc99b44b22ef05ff4473aa7c7cc1fec8f0a8b37d8a64b1a16d2"}, + {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4560dbdeaae5b7ee0d4e493027e3de6d53c991b5002d7ff95083c99e11dd5ac0"}, + {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64064bd0322405c9374305ab9b4c07152a1474370327499911937fd4a76d004b"}, + {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8c6e2ab0f2e3d939ca66aa1d596602105fe33b505cd2854a4c1717f704c51de"}, + {file = "tokenizers-0.13.3-cp37-cp37m-win32.whl", hash = "sha256:6cc29d410768f960db8677221e497226e545eaaea01aa3613fa0fdf2cc96cff4"}, + {file = "tokenizers-0.13.3-cp37-cp37m-win_amd64.whl", hash = "sha256:fc2a7fdf864554a0dacf09d32e17c0caa9afe72baf9dd7ddedc61973bae352d8"}, + {file = "tokenizers-0.13.3-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:8791dedba834c1fc55e5f1521be325ea3dafb381964be20684b92fdac95d79b7"}, + {file = "tokenizers-0.13.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:d607a6a13718aeb20507bdf2b96162ead5145bbbfa26788d6b833f98b31b26e1"}, + {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3791338f809cd1bf8e4fee6b540b36822434d0c6c6bc47162448deee3f77d425"}, + {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c2f35f30e39e6aab8716f07790f646bdc6e4a853816cc49a95ef2a9016bf9ce6"}, + {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:310204dfed5aa797128b65d63538a9837cbdd15da2a29a77d67eefa489edda26"}, + {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0f9b92ea052305166559f38498b3b0cae159caea712646648aaa272f7160963"}, + {file = "tokenizers-0.13.3-cp38-cp38-win32.whl", hash = "sha256:9a3fa134896c3c1f0da6e762d15141fbff30d094067c8f1157b9fdca593b5806"}, + {file = "tokenizers-0.13.3-cp38-cp38-win_amd64.whl", hash = "sha256:8e7b0cdeace87fa9e760e6a605e0ae8fc14b7d72e9fc19c578116f7287bb873d"}, + {file = "tokenizers-0.13.3-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:00cee1e0859d55507e693a48fa4aef07060c4bb6bd93d80120e18fea9371c66d"}, + {file = "tokenizers-0.13.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:a23ff602d0797cea1d0506ce69b27523b07e70f6dda982ab8cf82402de839088"}, + {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70ce07445050b537d2696022dafb115307abdffd2a5c106f029490f84501ef97"}, + {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:280ffe95f50eaaf655b3a1dc7ff1d9cf4777029dbbc3e63a74e65a056594abc3"}, + {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97acfcec592f7e9de8cadcdcda50a7134423ac8455c0166b28c9ff04d227b371"}, + {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:dd7730c98a3010cd4f523465867ff95cd9d6430db46676ce79358f65ae39797b"}, + {file = "tokenizers-0.13.3-cp39-cp39-win32.whl", hash = "sha256:48625a108029cb1ddf42e17a81b5a3230ba6888a70c9dc14e81bc319e812652d"}, + {file = "tokenizers-0.13.3-cp39-cp39-win_amd64.whl", hash = "sha256:bc0a6f1ba036e482db6453571c9e3e60ecd5489980ffd95d11dc9f960483d783"}, + {file = "tokenizers-0.13.3.tar.gz", hash = "sha256:2e546dbb68b623008a5442353137fbb0123d311a6d7ba52f2667c8862a75af2e"}, +] + +[package.extras] +dev = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] +docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] +testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] + [[package]] name = "toml" version = "0.10.2" @@ -2505,6 +2743,18 @@ files = [ [package.dependencies] requests = "*" +[[package]] +name = "types-pytz" +version = "2023.3.0.0" +description = "Typing stubs for pytz" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "types-pytz-2023.3.0.0.tar.gz", hash = "sha256:ecdc70d543aaf3616a7e48631543a884f74205f284cefd6649ddf44c6a820aac"}, + {file = "types_pytz-2023.3.0.0-py3-none-any.whl", hash = "sha256:4fc2a7fbbc315f0b6630e0b899fd6c743705abe1094d007b0e612d10da15e0f3"}, +] + [[package]] name = "typing-extensions" version = "4.7.1" @@ -2562,26 +2812,6 @@ files = [ [package.extras] test = ["coverage", "pytest", "pytest-cov"] -[[package]] -name = "urllib3" -version = "1.26.16" -description = "HTTP library with thread-safe connection pooling, file post, and more." -category = "main" -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" -files = [ - {file = "urllib3-1.26.16-py2.py3-none-any.whl", hash = "sha256:8d36afa7616d8ab714608411b4a3b13e58f463aee519024578e062e141dce20f"}, - {file = "urllib3-1.26.16.tar.gz", hash = "sha256:8f135f6502756bde6b2a9b28989df5fbe87c9970cecaa69041edcce7f0589b14"}, -] - -[package.dependencies] -PySocks = {version = ">=1.5.6,<1.5.7 || >1.5.7,<2.0", optional = true, markers = "extra == \"socks\""} - -[package.extras] -brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] -secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] -socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] - [[package]] name = "urllib3" version = "2.0.4" @@ -2799,4 +3029,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "f64abce5b5cfa8c020495f2df91185982463f0eeb6d0a07b51f816f33b42bd62" +content-hash = "5e69e43bd7a29a0f50223a6e03ef65e4406a452d4cf43fde3e0192d8d517b6f1" diff --git a/pyproject.toml b/pyproject.toml index 07d4a6b..0642a4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ pydantic = "^1.10.7" gpt4free = "^1.0.2" google-generativeai = "^0.1.0rc3" tiktoken = "^0.4.0" +litellm = "^0.1.351" [build-system] From e1a24a813aed214d39202b15ec59fcc372fb66f3 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 7 Aug 2023 13:50:24 -0700 Subject: [PATCH 3/3] adding support for llama2, anthropic, and cohere --- gpyt/__init__.py | 13 ++- gpyt/args.py | 2 +- gpyt/lite_assistant.py | 196 +++++++++++++++++++++++++++++++---------- 3 files changed, 163 insertions(+), 48 deletions(-) diff --git a/gpyt/__init__.py b/gpyt/__init__.py index 3aba86c..52393a8 100644 --- a/gpyt/__init__.py +++ b/gpyt/__init__.py @@ -11,11 +11,13 @@ from .assistant import Assistant from .config import MODEL, PROMPT + # check 
for environment variable first
 API_KEY = os.getenv("OPENAI_API_KEY")
 PALM_API_KEY = os.getenv("PALM_API_KEY")
 ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
 COHERE_API_KEY = os.getenv("COHERE_API_KEY")
+REPLICATE_API_KEY = os.getenv("REPLICATE_API_KEY")
 DOTENV_PATH = os.path.expanduser("~/.env")
 
 if API_KEY is None or (isinstance(API_KEY, str) and not len(API_KEY)):
@@ -24,6 +26,7 @@
     PALM_API_KEY = result.get("PALM_API_KEY", None)
     ANTHROPIC_API_KEY = result.get("ANTHROPIC_API_KEY", None)
     COHERE_API_KEY = result.get("COHERE_API_KEY", None)
+    REPLICATE_API_KEY = result.get("REPLICATE_API_KEY", None)
 
 
 assert (
@@ -51,11 +54,17 @@ generate your own OpenAI API key (see above)
 """
 
-
+model_keys = {}
+if ANTHROPIC_API_KEY:
+    model_keys["ANTHROPIC_API_KEY"] = ANTHROPIC_API_KEY
+if COHERE_API_KEY:
+    model_keys["COHERE_API_KEY"] = COHERE_API_KEY
+if REPLICATE_API_KEY:
+    model_keys["REPLICATE_API_KEY"] = REPLICATE_API_KEY
 gpt = Assistant(api_key=API_KEY or "", model=MODEL, prompt=PROMPT)
 gpt4 = Assistant(api_key=API_KEY or "", model="gpt-4", prompt=PROMPT)
 free_gpt = FreeAssistant()
 palm = PalmAssistant(api_key=PALM_API_KEY)
-litellm = LiteAssistant(api_key='', model=MODEL)
+litellm = LiteAssistant(model_keys=model_keys, model=MODEL, prompt=PROMPT)
 app = gpyt(assistant=gpt, free_assistant=free_gpt, palm=palm, gpt4=gpt4, lite_assistant = litellm)
diff --git a/gpyt/args.py b/gpyt/args.py
index 19322d9..c2485cf 100644
--- a/gpyt/args.py
+++ b/gpyt/args.py
@@ -27,6 +27,6 @@
 USE_EXPERIMENTAL_FREE_MODEL = args.free
 USE_PALM_MODEL = args.palm
 USE_GPT4 = args.gpt4
-USE_LITELLM = len(args.litellm) > 0
+USE_LITELLM = len(args.litellm) > 0 if hasattr(args, "litellm") and args.litellm is not None else False
 LITELLM_MODEL = args.litellm if USE_LITELLM else ""
 print(f"{USE_LITELLM=}, {LITELLM_MODEL=}")
diff --git a/gpyt/lite_assistant.py b/gpyt/lite_assistant.py
index 15fedd2..2d9d0f1 100644
--- a/gpyt/lite_assistant.py
+++ b/gpyt/lite_assistant.py
@@ -1,53 +1,123 @@
 from typing import Generator
 
-import dotenv
+import litellm
 from litellm import completion
+import traceback
 import tiktoken
+import sys, os
+
+from .config import (
+    API_ERROR_FALLBACK,
+    SUMMARY_PROMPT,
+    APPROX_PROMPT_TOKEN_USAGE,
+    PRICING_LOOKUP,
+    MODEL_MAX_CONTEXT,
+)
+
+
+class LiteAssistant:
+    """
+    Represents an OpenAI GPT assistant.
+
+    It can generate responses based on user input and conversation history.
+    """
+
+    kDEFAULT_SUMMARY_FALLTHROUGH = "User Question"
+
+    def __init__(
+        self,
+        *,
+        model_keys: dict,
+        model: str,
+        prompt: str,
+        memory: bool = True,
+    ):
+        for key in model_keys:
+            if key == "ANTHROPIC_API_KEY":
+                litellm.anthropic_key = model_keys[key]
+            if key == "COHERE_API_KEY":
+                litellm.cohere_key = model_keys[key]
+            if key == "REPLICATE_API_KEY":
+                litellm.replicate_key = model_keys[key]
+        self.model = model
+        self.prompt = prompt
+        self.summary_prompt = SUMMARY_PROMPT
+        self.memory = memory
+        self.messages = [
+            {"role": "system", "content": self.prompt},
+        ]
+        self.error_fallback_message = API_ERROR_FALLBACK
+        self._encoding_engine = tiktoken.get_encoding(
+            "cl100k_base"
+        )  # gpt3.5/gpt4 encoding engine
+        self.input_tokens_this_convo = APPROX_PROMPT_TOKEN_USAGE
+        self.output_tokens_this_convo = 10
+        self.price_of_this_convo = self.get_default_price_of_prompt()
+
+    def set_history(self, new_history: list):
+        """Reassign all message history except for system message"""
+        sys_prompt = self.messages[0]
+        self.messages = new_history[::]
+        self.messages.insert(0, sys_prompt)
+
+    def get_tokens_used(self, message: str) -> int:
+        return len(self._encoding_engine.encode(message))
+
+    def get_default_price_of_prompt(self) -> float:
+        return self.get_approximate_price(
+            self.get_tokens_used(self.prompt), 0
+        ) + self.get_approximate_price(self.get_tokens_used(self.summary_prompt), 10)
+
+    def tokens_used_this_convo(self) -> int:
+        num = self.input_tokens_this_convo + self.output_tokens_this_convo
+        has_limit = MODEL_MAX_CONTEXT.get(self.model, None)
+        if not has_limit:
+            return 0
+
+        return num
+
+    def update_token_usage_for_input(self, in_message: str, out_message: str) -> None:
+        in_tokens_used = self.get_tokens_used(in_message)
+        out_tokens_used = self.get_tokens_used(out_message)
+        self.input_tokens_this_convo += self.input_tokens_this_convo + in_tokens_used
+        self.output_tokens_this_convo += self.output_tokens_this_convo + out_tokens_used
+        self.price_of_this_convo += self.get_approximate_price(
+            self.input_tokens_this_convo, self.output_tokens_this_convo
+        )
 
-from .assistant import Assistant
-from .config import API_ERROR_FALLBACK
-from .config import API_ERROR_FALLBACK, APPROX_PROMPT_TOKEN_USAGE
+    def get_approximate_price(self, _in: int | float, _out: int | float) -> float:
+        known_pricing = PRICING_LOOKUP.get(self.model, None)
+        if not known_pricing:
+            return 0.0
+        in_price_ratio = PRICING_LOOKUP[self.model][0]
+        out_price_ratio = PRICING_LOOKUP[self.model][1]
 
-dotenv.load_dotenv()
+        limit = MODEL_MAX_CONTEXT.get(self.model, None)
+        if not limit:
+            return 0
 
-class LiteAssistant(Assistant):
-    API_ERROR_MESSAGE = """There was an ERROR with the model API\n## Diagnostics\n* Make sure you have an API_KEY set at `~/.env`\n* Make sure you aren't being rate-limited."""
+        _in = min(_in, limit / 2)
+        _out = min(_out, limit / 2)
 
-    def __init__(self, api_key: str, model: str, memory: bool = True):
-        self.error_fallback_message = API_ERROR_FALLBACK
-        self.chat: list[dict[str, str]] = []
-        self.model = model
-        self.api_key = api_key
-        self.memory = memory
-        self.messages = []
-        self.error_fallback_message = LiteAssistant.API_ERROR_MESSAGE
+        return (_in / 1000) * in_price_ratio + (_out / 1000) * out_price_ratio
 
+    def clear_history(self):
+        """
+        Wipe all message history during this conversation, except for the
+        first system message
+        """
+        self.messages = [self.messages[0]]
         self.input_tokens_this_convo = APPROX_PROMPT_TOKEN_USAGE
         self.output_tokens_this_convo = 10
-        self.prompt, self.summary_prompt = ("", "")
-
-        self._encoding_engine = tiktoken.get_encoding("cl100k_base")
         self.price_of_this_convo = self.get_default_price_of_prompt()
 
-    def set_history(self, new_history: list[dict[str, str]]):
-        self.clear_history()
-        for message in new_history:
-            self.messages.append(message["content"])
-
-    def clear_history(self) -> None:
-        """reset internal message queue"""
-        self.messages.clear()
-
-    def get_response_stream(self, user_input: str) -> Generator:
-        """Fake a stream output with PaLM 2"""
-        response = self.get_response(user_input)
-        for i in range(0, len(response), 8):
-            yield response[i : i + 8]
+    def get_response_stream(self, user_input: str) -> Generator:
+        """
+        Use OpenAI API to retrieve a ChatCompletion response from a GPT model.
-    def log_assistant_response(self, final_response: str) -> None:  # pyright: ignore
-        ...
-
-    def get_response(self, user_input: str) -> str:
+        Memory can be configured so that the assistant forgets previous messages
+        you or it has sent. (saves tokens ($$$) as well)
+        """
         if not self.memory:
             self.clear_history()
         self.messages.append({"role": "user", "content": user_input})
@@ -57,11 +127,30 @@ def get_response(self, user_input: str) -> str:
             messages=self.messages,
             stream=True,
         )
+        # print(type(response))
+        if isinstance(response, dict):
+            # fake stream output if model doesn't support it
+            tmp_response = response
+            text_response = response['choices'][0]['message']['content']
+            tmp_response['choices'][0]['delta'] = {"content": ""}
+            for i in range(0, len(text_response), 8):
+                tmp_response['choices'][0]['delta']['content'] = text_response[i : i + 8]
+                yield tmp_response
-        return response  # type: ignore
+        # print(response["usage"])
+        return None
 
-    def _bad_key(self) -> bool:
-        return not self.api_key or (type(self.api_key) is str and len(self.api_key) < 1)
+        # return response  # type: ignore
+
+    def get_response(self, user_input: str) -> str:
+        """Get an entire string back from the assistant"""
+        messages = [
+            {"role": "system", "content": self.summary_prompt},
+            {"role": "user", "content": user_input.rstrip()},
+        ]
+        response = completion(model=self.model, messages=messages)
+
+        return response["choices"][0]["message"]["content"]  # type: ignore
 
     def get_conversation_summary(self, initial_message: str) -> str:
         """Generate a short 6 word or less summary of the user\'s first message"""
@@ -72,12 +161,29 @@ def get_conversation_summary(self, initial_message: str) -> str:
 
         return summary
 
+    def log_assistant_response(self, final_response: str) -> None:
+        if not self.memory:
+            return
+
+        self.messages.append({"role": "assistant", "content": final_response})
+
+
+def _test() -> None:
+    from config import PROMPT
+    model_keys={"OPENAI_API_KEY": os.getenv("OPENAI_API_KEY")}
+    gpt = LiteAssistant(model_keys=model_keys, model="claude-instant-1", prompt=PROMPT)
+
+    # response = gpt.get_response("How do I make carrot cake?")
+    # summary = gpt.get_conversation_summary("How do I get into Dirt Biking?")
+    # print(f"{summary = }")
+    response_iterator = gpt.get_response_stream("hi how are you!")
+    for response in response_iterator:
+        try:
+            print(response["choices"][0]["delta"]["content"], end="", flush=True)
+        except:
+            traceback.print_exc()
+            continue
+
 if __name__ == "__main__":
-    ...
-    agent = LiteAssistant(model="claude-instant-1")
-    print(agent.get_conversation_summary("What is the third closest moon from Saturn"))
-    while 1:
-        user_inp = input("> ")
-        msg = agent.get_response(user_input=user_inp)
-        print("\n", msg)
+    _test()
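For reference, the provider-key pattern these patches wire into gpyt can be exercised on its own. A minimal sketch, assuming a litellm 0.1.x install and an ANTHROPIC_API_KEY available in the environment; the model name and message below are only illustrative:

    import os

    import litellm
    from litellm import completion

    # litellm picks up provider keys either from module-level attributes (the way
    # LiteAssistant.__init__ sets them from model_keys) or from environment variables.
    litellm.anthropic_key = os.getenv("ANTHROPIC_API_KEY")

    # A non-streaming call returns an OpenAI-style response dict, which is what
    # LiteAssistant.get_response() indexes into.
    response = completion(
        model="claude-instant-1",
        messages=[{"role": "user", "content": "Summarize gpyt in six words."}],
    )
    print(response["choices"][0]["message"]["content"])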