Skip to content

Commit

Permalink
Store prompt attachments in attachments and prompt_attachments tables
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Oct 27, 2024
1 parent a1ee8ac commit 07061de
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 1 deletion.
2 changes: 2 additions & 0 deletions llm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def convert(self, value, param, ctx):
path = pathlib.Path(value)
if not path.exists():
self.fail(f"File {value} does not exist", param, ctx)
path = path.resolve()
# Try to guess type
mimetype = puremagic.from_file(str(path), mime=True)
return Attachment(mimetype, str(path), None, None)
Expand All @@ -94,6 +95,7 @@ def attachment_types_callback(ctx, param, values):
path = pathlib.Path(value)
if not path.exists():
raise click.BadParameter(f"File {value} does not exist")
path = path.resolve()
attachment = Attachment(mimetype, str(path), None, None)
collected.append(attachment)
return collected
Expand Down
26 changes: 26 additions & 0 deletions llm/migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,29 @@ def m010_create_new_log_tables(db):
@migration
def m011_fts_for_responses(db):
db["responses"].enable_fts(["prompt", "response"], create_triggers=True)


@migration
def m012_attachments_tables(db):
db["attachments"].create(
{
"id": str,
"type": str,
"path": str,
"url": str,
"content": bytes,
},
pk="id",
)
db["prompt_attachments"].create(
{
"response_id": str,
"attachment_id": str,
"order": int,
},
foreign_keys=(
("response_id", "responses", "id"),
("attachment_id", "attachments", "id"),
),
pk=("response_id", "attachment_id"),
)
35 changes: 34 additions & 1 deletion llm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import dataclass, field
import datetime
from .errors import NeedsKeyException
import hashlib
import httpx
from itertools import islice
import puremagic
Expand All @@ -23,6 +24,17 @@ class Attachment:
url: Optional[str] = None
content: Optional[bytes] = None

def hash_id(self):
# Hash of the binary content, or of '{"url": "https://..."}' for URL attachments
if self.content:
return hashlib.sha256(self.content).hexdigest()
elif self.path:
return hashlib.sha256(open(self.path, "rb").read()).hexdigest()
else:
return hashlib.sha256(
json.dumps({"url": self.url}).encode("utf-8")
).hexdigest()

def resolve_type(self):
if self.type:
return self.type
Expand Down Expand Up @@ -178,8 +190,9 @@ def log_to_db(self, db):
},
ignore=True,
)
response_id = str(ULID()).lower()
response = {
"id": str(ULID()).lower(),
"id": response_id,
"model": self.model.model_id,
"prompt": self.prompt.prompt,
"system": self.prompt.system,
Expand All @@ -196,6 +209,26 @@ def log_to_db(self, db):
"datetime_utc": self.datetime_utc(),
}
db["responses"].insert(response)
# Persist any attachments - loop through with index
for index, attachment in enumerate(self.prompt.attachments):
attachment_id = attachment.hash_id()
db["attachments"].insert(
{
"id": attachment_id,
"type": attachment.resolve_type(),
"path": attachment.path,
"url": attachment.url,
"content": attachment.content,
},
replace=True,
)
db["prompt_attachments"].insert(
{
"response_id": response_id,
"attachment_id": attachment_id,
"order": index,
},
)

@classmethod
def fake(
Expand Down

0 comments on commit 07061de

Please sign in to comment.