diff --git a/llm/cli.py b/llm/cli.py
index 33e14f09..454e56ca 100644
--- a/llm/cli.py
+++ b/llm/cli.py
@@ -76,6 +76,7 @@ def convert(self, value, param, ctx):
         path = pathlib.Path(value)
         if not path.exists():
             self.fail(f"File {value} does not exist", param, ctx)
+        path = path.resolve()
         # Try to guess type
         mimetype = puremagic.from_file(str(path), mime=True)
         return Attachment(mimetype, str(path), None, None)
@@ -94,6 +95,7 @@ def attachment_types_callback(ctx, param, values):
             path = pathlib.Path(value)
             if not path.exists():
                 raise click.BadParameter(f"File {value} does not exist")
+            path = path.resolve()
             attachment = Attachment(mimetype, str(path), None, None)
         collected.append(attachment)
     return collected
diff --git a/llm/migrations.py b/llm/migrations.py
index 008ae976..91da6429 100644
--- a/llm/migrations.py
+++ b/llm/migrations.py
@@ -201,3 +201,31 @@ def m010_create_new_log_tables(db):
 @migration
 def m011_fts_for_responses(db):
     db["responses"].enable_fts(["prompt", "response"], create_triggers=True)
+
+
+@migration
+def m012_attachments_tables(db):
+    # One row per attachment, keyed by a string id
+    db["attachments"].create(
+        {
+            "id": str,
+            "type": str,
+            "path": str,
+            "url": str,
+            "content": bytes,
+        },
+        pk="id",
+    )
+    # Join table linking responses to attachments, with a position in "order"
+    db["prompt_attachments"].create(
+        {
+            "response_id": str,
+            "attachment_id": str,
+            "order": int,
+        },
+        foreign_keys=(
+            ("response_id", "responses", "id"),
+            ("attachment_id", "attachments", "id"),
+        ),
+        pk=("response_id", "attachment_id"),
+    )
diff --git a/llm/models.py b/llm/models.py
index 77bdb8e9..3b7d4dad 100644
--- a/llm/models.py
+++ b/llm/models.py
@@ -2,6 +2,7 @@
 from dataclasses import dataclass, field
 import datetime
 from .errors import NeedsKeyException
+import hashlib
 import httpx
 from itertools import islice
 import puremagic
@@ -23,6 +24,23 @@ class Attachment:
     url: Optional[str] = None
     content: Optional[bytes] = None
 
+    def hash_id(self):
+        """Return a SHA-256 hex digest identifying this attachment.
+
+        Hashes the non-empty in-memory content if set, else the bytes of the
+        file at path, else the JSON string '{"url": "https://..."}'.
+        """
+        if self.content:
+            return hashlib.sha256(self.content).hexdigest()
+        elif self.path:
+            # Context manager closes the file handle deterministically
+            with open(self.path, "rb") as fp:
+                return hashlib.sha256(fp.read()).hexdigest()
+        else:
+            return hashlib.sha256(
+                json.dumps({"url": self.url}).encode("utf-8")
+            ).hexdigest()
+
     def resolve_type(self):
         if self.type:
             return self.type
@@ -178,8 +196,9 @@ def log_to_db(self, db):
             },
             ignore=True,
         )
+        response_id = str(ULID()).lower()
         response = {
-            "id": str(ULID()).lower(),
+            "id": response_id,
             "model": self.model.model_id,
             "prompt": self.prompt.prompt,
             "system": self.prompt.system,
@@ -196,6 +215,26 @@ def log_to_db(self, db):
             "datetime_utc": self.datetime_utc(),
         }
         db["responses"].insert(response)
+        # Persist any attachments - loop through with index
+        for index, attachment in enumerate(self.prompt.attachments):
+            attachment_id = attachment.hash_id()
+            db["attachments"].insert(
+                {
+                    "id": attachment_id,
+                    "type": attachment.resolve_type(),
+                    "path": attachment.path,
+                    "url": attachment.url,
+                    "content": attachment.content,
+                },
+                replace=True,
+            )
+            db["prompt_attachments"].insert(
+                {
+                    "response_id": response_id,
+                    "attachment_id": attachment_id,
+                    "order": index,
+                },
+            )
 
     @classmethod
     def fake(