diff --git a/poetry.lock b/poetry.lock index b68a58e..bab04ff 100644 --- a/poetry.lock +++ b/poetry.lock @@ -602,6 +602,23 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} +[[package]] +name = "click-default-group" +version = "1.2.4" +description = "click_default_group" +optional = true +python-versions = ">=2.7" +files = [ + {file = "click_default_group-1.2.4-py2.py3-none-any.whl", hash = "sha256:9b60486923720e7fc61731bdb32b617039aba820e22e1c88766b1125592eaa5f"}, + {file = "click_default_group-1.2.4.tar.gz", hash = "sha256:eb3f3c99ec0d456ca6cd2a7f08f7d4e91771bef51b01bdd9580cc6450fe1251e"}, +] + +[package.dependencies] +click = "*" + +[package.extras] +test = ["pytest"] + [[package]] name = "colorama" version = "0.4.6" @@ -2032,6 +2049,34 @@ files = [ httpx = ">=0.20.0" pydantic = ">=1.10" +[[package]] +name = "llm" +version = "0.14" +description = "A CLI utility and Python library for interacting with Large Language Models, including OpenAI, PaLM and local models installed on your own machine." 
+optional = true +python-versions = ">=3.8" +files = [ + {file = "llm-0.14-py3-none-any.whl", hash = "sha256:bb3bd50a1ff47d5acc43709b0d01394e2384cd91e291a68500b00ad154831943"}, + {file = "llm-0.14.tar.gz", hash = "sha256:c4150eb47246342846cf57497304ed9afb918a94c5a39049e7422f1961371c09"}, +] + +[package.dependencies] +click = "*" +click-default-group = ">=1.2.3" +openai = ">=1.0" +pip = "*" +pluggy = "*" +pydantic = ">=1.10.2" +pyreadline3 = {version = "*", markers = "sys_platform == \"win32\""} +python-ulid = "*" +PyYAML = "*" +setuptools = "*" +sqlite-migrate = ">=0.1a2" +sqlite-utils = ">=3.35.0" + +[package.extras] +test = ["black (>=24.1.0)", "cogapp", "mypy (>=1.10.0)", "numpy", "pytest", "pytest-httpx", "ruff", "types-PyYAML", "types-click", "types-setuptools"] + [[package]] name = "markdown-it-py" version = "3.0.0" @@ -2947,6 +2992,17 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa typing = ["typing-extensions"] xmp = ["defusedxml"] +[[package]] +name = "pip" +version = "24.0" +description = "The PyPA recommended tool for installing Python packages." +optional = true +python-versions = ">=3.7" +files = [ + {file = "pip-24.0-py3-none-any.whl", hash = "sha256:ba0d021a166865d2265246961bec0152ff124de910c5cc39f1156ce3fa7c69dc"}, + {file = "pip-24.0.tar.gz", hash = "sha256:ea9bd1a847e8c5774a5777bb398c19e80bcd4e2aa16a4b301b718fe6f593aba2"}, +] + [[package]] name = "pkginfo" version = "1.11.1" @@ -3395,6 +3451,17 @@ files = [ {file = "pyproject_hooks-1.1.0.tar.gz", hash = "sha256:4b37730834edbd6bd37f26ece6b44802fb1c1ee2ece0e54ddff8bfc06db86965"}, ] +[[package]] +name = "pyreadline3" +version = "3.4.1" +description = "A python implementation of GNU readline." 
+optional = true +python-versions = "*" +files = [ + {file = "pyreadline3-3.4.1-py3-none-any.whl", hash = "sha256:b0efb6516fd4fb07b45949053826a62fa4cb353db5be2bbb4a7aa1fdd1e345fb"}, + {file = "pyreadline3-3.4.1.tar.gz", hash = "sha256:6f3d1f7b8a31ba32b73917cefc1f28cc660562f39aea8646d30bd6eff21f7bae"}, +] + [[package]] name = "pytest" version = "8.2.2" @@ -3447,6 +3514,20 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "python-ulid" +version = "2.6.0" +description = "Universally unique lexicographically sortable identifier" +optional = true +python-versions = ">=3.9" +files = [ + {file = "python_ulid-2.6.0-py3-none-any.whl", hash = "sha256:b47cc7a427b82f7526af96385d7702685df808e9b4922523dd5988a3ba98a89d"}, + {file = "python_ulid-2.6.0.tar.gz", hash = "sha256:904e19093dd6578a5ce01a8274e3e228d556d47be3bda328da2d3601c5240c4f"}, +] + +[package.extras] +pydantic = ["pydantic (>=2.0)"] + [[package]] name = "pytz" version = "2024.1" @@ -4057,6 +4138,21 @@ transformers = ">=4.34.0,<5.0.0" [package.extras] dev = ["pre-commit", "pytest", "ruff (>=0.3.0)"] +[[package]] +name = "setuptools" +version = "70.0.0" +description = "Easily download, build, install, upgrade, and uninstall Python packages" +optional = true +python-versions = ">=3.8" +files = [ + {file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"}, + {file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"}, +] + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +testing = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop 
(>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] + [[package]] name = "shellingham" version = "1.5.4" @@ -4213,6 +4309,63 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] pymysql = ["pymysql"] sqlcipher = ["sqlcipher3_binary"] +[[package]] +name = "sqlite-fts4" +version = "1.0.3" +description = "Python functions for working with SQLite FTS4 search" +optional = true +python-versions = "*" +files = [ + {file = "sqlite-fts4-1.0.3.tar.gz", hash = "sha256:78b05eeaf6680e9dbed8986bde011e9c086a06cb0c931b3cf7da94c214e8930c"}, + {file = "sqlite_fts4-1.0.3-py3-none-any.whl", hash = "sha256:0359edd8dea6fd73c848989e1e2b1f31a50fe5f9d7272299ff0e8dbaa62d035f"}, +] + +[package.extras] +test = ["pytest"] + +[[package]] +name = "sqlite-migrate" +version = "0.1b0" +description = "A simple database migration system for SQLite, based on sqlite-utils" +optional = true +python-versions = "*" +files = [ + {file = "sqlite-migrate-0.1b0.tar.gz", hash = "sha256:8d502b3ca4b9c45e56012bd35c03d23235f0823c976d4ce940cbb40e33087ded"}, + {file = "sqlite_migrate-0.1b0-py3-none-any.whl", hash = "sha256:a4125e35e1de3dc56b6b6ec60e9833ce0ce20192b929ddcb2d4246c5098859c6"}, +] + +[package.dependencies] +sqlite-utils = "*" + +[package.extras] +test = ["black", "mypy", "pytest", "ruff"] + +[[package]] +name = "sqlite-utils" +version = "3.36" +description = "CLI tool and Python library for manipulating SQLite databases" +optional = true +python-versions = ">=3.7" +files = [ + {file = "sqlite-utils-3.36.tar.gz", hash = "sha256:dcc311394fe86dc16f65037b0075e238efcfd2e12e65d53ed196954502996f3c"}, + {file = 
"sqlite_utils-3.36-py3-none-any.whl", hash = "sha256:b71e829755c2efbdcd6931a31968dee4e8bd71b3c14f0fe648b22377027c5bec"}, +] + +[package.dependencies] +click = "*" +click-default-group = ">=1.2.3" +pluggy = "*" +python-dateutil = "*" +sqlite-fts4 = "*" +tabulate = "*" + +[package.extras] +docs = ["beanbag-docutils (>=2.0)", "codespell", "furo", "pygments-csv-lexer", "sphinx-autobuild", "sphinx-copybutton"] +flake8 = ["flake8"] +mypy = ["data-science-types", "mypy", "types-click", "types-pluggy", "types-python-dateutil", "types-tabulate"] +test = ["black", "cogapp", "hypothesis", "pytest"] +tui = ["trogon"] + [[package]] name = "stack-data" version = "0.6.3" @@ -4274,6 +4427,20 @@ files = [ [package.dependencies] mpmath = ">=1.1.0,<1.4.0" +[[package]] +name = "tabulate" +version = "0.9.0" +description = "Pretty-print tabular data" +optional = true +python-versions = ">=3.7" +files = [ + {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"}, + {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"}, +] + +[package.extras] +widechars = ["wcwidth"] + [[package]] name = "tbb" version = "2021.12.0" @@ -5093,8 +5260,9 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [extras] api-bot = ["fastapi", "pydantic", "requests", "uvicorn"] azure = ["azure-storage-file-share", "pulumi", "pulumi-azure-native"] +llm-plugin = ["llm"] [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.12" -content-hash = "4d7330aee8ed44e7b403e2a21fc304c310e2ae5f9fba6a72079257869506a085" +content-hash = "681b846e7164e161de73ba842b75f1e2b56b2de708698e0c52e236f0af5c0c9c" diff --git a/pyproject.toml b/pyproject.toml index ff73adc..a3348bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ llama-index-readers-github = "^0.1.9" tiktoken = "^0.7.0" typer = "^0.12.3" rich = "^13.7.1" - +llm = { version="^0.14", optional=true 
} [tool.poetry.group.dev.dependencies] black = "^24.3.0" @@ -74,6 +74,9 @@ azure = [ "pulumi-azure-native", "azure-storage-file-share", ] +llm_plugin = [ + "llm", +] [tool.poetry.scripts] reginald = "reginald.cli:cli" @@ -82,6 +85,9 @@ reginald = "reginald.cli:cli" name = "PyPI" priority = "primary" +[tool.poetry.plugins."llm"] +reginald = "reginald.llm_reginald.llm_reginald" + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/reginald/llm_reginald/README.md b/reginald/llm_reginald/README.md new file mode 100644 index 0000000..127f7b9 --- /dev/null +++ b/reginald/llm_reginald/README.md @@ -0,0 +1,19 @@ +# llm-reginald + +This module contains an [LLM](https://llm.datasette.io/) plugin for interacting with [Reginald](https://github.com/alan-turing-institute/reginald). + +The plugin is available to use immediately after reginald is installed, as long as the `llm_plugin` extra dependencies are selected when installing the package (either with `pip install .[llm_plugin]` or `poetry install --extras llm_plugin` or `poetry install --all-extras`). + +The plugin requires a running Reginald server to work. By default this is assumed to be at `http://localhost:8000`. Refer to the package README for instructions. + +Assuming [LLM](http://llm.datasette.io/) is installed, you can interact with Reginald like this: + +``` +llm -m reginald 'Who are you?' +``` + +Using a different server URL: + +``` +llm -m reginald -o server_url ... 
"""LLM plugin that forwards prompts to a running Reginald server."""

import uuid
from typing import Optional
from urllib.parse import urljoin

import httpx
import llm
from pydantic import Field


@llm.hookimpl
def register_models(register):
    """Register the Reginald model with LLM.

    Args:
        register: Callback supplied by LLM; called once with a model
            instance.
    """
    register(Reginald())


class Reginald(llm.Model):
    """LLM model that sends each prompt to a Reginald server over HTTP."""

    model_id = "reginald"
    # The whole reply is read from a single HTTP response, so there is
    # nothing to stream.
    can_stream = False

    class Options(llm.Options):
        """Options for the model (e.g. ``-o server_url ...`` on the CLI)."""

        server_url: str = Field(
            default="http://localhost:8000",
            title="Server URL",
            description="The base URL of the Reginald server",
        )

        def direct_message_endpoint(self) -> str:
            """Return the URL of the server's ``direct_message`` endpoint.

            The base URL is normalised to end in exactly one ``/`` before
            joining, so a base such as ``http://host/reginald`` keeps its
            path instead of having the last segment replaced by
            ``urljoin``.
            """
            base_url = self.server_url.rstrip("/") + "/"
            return urljoin(base_url, "direct_message")

    def execute(self, prompt, stream, response, conversation):
        """Send the prompt to Reginald and yield its reply.

        Args:
            prompt: The LLM prompt; ``prompt.prompt`` is the message text
                and ``prompt.options`` the ``Options`` instance.
            stream: Unused; this model does not stream.
            response: The LLM response object; its ``response_json`` is
                set to ``{"user_id": ...}`` so that a later turn can
                continue the same Reginald conversation.
            conversation: The ongoing LLM conversation, or ``None``.

        Yields:
            The ``"message"`` field of the server's JSON reply.

        Raises:
            llm.ModelError: If the request fails, the server returns an
                error status, or the reply is not JSON containing a
                ``"message"`` key.
        """
        message = prompt.prompt
        endpoint = prompt.options.direct_message_endpoint()

        # Reginald keeps a separate conversation history for each
        # "user_id". If 'conversation' is None (or doesn't have the
        # user_id logged) start a new conversation by minting a new
        # user_id. Otherwise, reuse the logged user_id to continue that
        # conversation.
        try:
            user_id = conversation.responses[0].response_json["user_id"]
        except (TypeError, AttributeError, KeyError, IndexError):
            user_id = str(uuid.uuid4().int)

        try:
            with httpx.Client() as client:
                reginald_reply = client.post(
                    endpoint,
                    json={"message": message, "user_id": user_id},
                    # No timeout -- presumably because the server can take
                    # a long time to generate a reply (TODO confirm).
                    timeout=None,
                )
                reginald_reply.raise_for_status()
        except httpx.HTTPError as e:
            # Re-raise as an llm.ModelError for llm to report, keeping the
            # original exception as the cause.
            raise llm.ModelError(
                f"Could not connect to Reginald at {endpoint}."
                f"\n\nThe error was:\n {e}.\n\nIs the model server running?"
            ) from e

        try:
            reply_text = reginald_reply.json()["message"]
        except (ValueError, KeyError, TypeError) as e:
            # Body was not JSON, or not an object with a "message" key.
            raise llm.ModelError(
                f"Unexpected reply from Reginald at {endpoint}: {e!r}"
            ) from e

        yield reply_text

        response.response_json = {"user_id": user_id}