Made mypy happy with llm/models.py
simonw committed Nov 7, 2024
1 parent 61dfc1d commit b3a6ec7
Showing 1 changed file with 70 additions and 108 deletions.
178 changes: 70 additions & 108 deletions llm/models.py
@@ -26,6 +26,13 @@
 from pydantic import BaseModel
 from ulid import ULID
 
+ModelT = TypeVar("ModelT", bound=Union["Model", "AsyncModel"])
+ConversationT = TypeVar(
+    "ConversationT", bound=Optional[Union["Conversation", "AsyncConversation"]]
+)
+ResponseT = TypeVar("ResponseT")
+
+
 CONVERSATION_NAME_LENGTH = 32
 
 
@@ -131,7 +138,7 @@ def prompt(
         system: Optional[str] = None,
         stream: bool = True,
         **options
-    ):
+    ) -> "Response":
         return Response(
             Prompt(
                 prompt,
@@ -156,10 +163,44 @@ def from_row(cls, row):
         )
 
 
-ModelT = TypeVar("ModelT", bound=Union["Model", "AsyncModel"])
-ConversationT = TypeVar(
-    "ConversationT", bound=Optional[Union["Conversation", "AsyncConversation"]]
-)
+@dataclass
+class AsyncConversation:
+    model: "AsyncModel"
+    id: str = field(default_factory=lambda: str(ULID()).lower())
+    name: Optional[str] = None
+    responses: List["AsyncResponse"] = field(default_factory=list)
+
+    def prompt(
+        self,
+        prompt: Optional[str],
+        *,
+        attachments: Optional[List[Attachment]] = None,
+        system: Optional[str] = None,
+        stream: bool = True,
+        **options
+    ) -> "AsyncResponse":
+        return AsyncResponse(
+            Prompt(
+                prompt,
+                model=self.model,
+                attachments=attachments,
+                system=system,
+                options=self.model.Options(**options),
+            ),
+            self.model,
+            stream,
+            conversation=self,
+        )
+
+    @classmethod
+    def from_row(cls, row):
+        from llm import get_model
+
+        return cls(
+            model=get_model(row["model"]),
+            id=row["id"],
+            name=row["name"],
+        )
 
 
 class _BaseResponse(ABC, Generic[ModelT, ConversationT]):
@@ -168,7 +209,7 @@ def __init__(
         prompt: Prompt,
         model: ModelT,
         stream: bool,
-        conversation: ConversationT = None,
+        conversation: Optional[ConversationT] = None,
     ):
         self.prompt = prompt
         self._prompt_json = None
@@ -183,25 +224,6 @@ def __init__(
         self._end: Optional[float] = None
         self._start_utcnow: Optional[datetime.datetime] = None
 
-    def __str__(self) -> str:
-        return self.text()
-
-    def text(self) -> str:
-        self._force()
-        return "".join(self._chunks)
-
-    def json(self) -> Optional[Dict[str, Any]]:
-        self._force()
-        return self.response_json
-
-    def duration_ms(self) -> int:
-        self._force()
-        return int((self._end - self._start) * 1000)
-
-    def datetime_utc(self) -> str:
-        self._force()
-        return self._start_utcnow.isoformat()
-
     @classmethod
     def from_row(cls, db, row):
         from llm import get_model
@@ -241,10 +263,29 @@ def from_row(cls, db, row):
 
 
 class Response(_BaseResponse["Model", Optional["Conversation"]]):
+    def __str__(self) -> str:
+        return self.text()
+
     def _force(self):
         if not self._done:
             list(self)
 
+    def text(self) -> str:
+        self._force()
+        return "".join(self._chunks)
+
+    def json(self) -> Optional[Dict[str, Any]]:
+        self._force()
+        return self.response_json
+
+    def duration_ms(self) -> int:
+        self._force()
+        return int(((self._end or 0) - (self._start or 0)) * 1000)
+
+    def datetime_utc(self) -> str:
+        self._force()
+        return self._start_utcnow.isoformat() if self._start_utcnow else ""
+
     def __iter__(self) -> Iterator[str]:
         self._start = time.monotonic()
         self._start_utcnow = datetime.datetime.utcnow()
@@ -332,7 +373,7 @@ async def __aiter__(self) -> AsyncIterator[str]:
                 yield chunk
             return
 
-        async for chunk in self.model.execute(
+        async for chunk in await self.model.execute(
             self.prompt,
             stream=self.stream,
             response=self,
@@ -356,16 +397,16 @@ async def json(self) -> Optional[Dict[str, Any]]:
 
     async def duration_ms(self) -> int:
         await self._force()
-        return int((self._end - self._start) * 1000)
+        return int(((self._end or 0) - (self._start or 0)) * 1000)
 
     async def datetime_utc(self) -> str:
         await self._force()
-        return self._start_utcnow.isoformat()
+        return self._start_utcnow.isoformat() if self._start_utcnow else ""
 
     @classmethod
     def fake(
         cls,
-        model: "Model",
+        model: "AsyncModel",
         prompt: str,
         *attachments: List[Attachment],
         system: str,
@@ -468,10 +509,6 @@ def get_key(self):
         raise NeedsKeyException(message)
 
 
-ResponseT = TypeVar("ResponseT")
-ConversationT = TypeVar("ConversationT")
-
-
 class _BaseModel(ABC, _get_key_mixin, Generic[ResponseT, ConversationT]):
     model_id: str
 
@@ -597,81 +634,6 @@ def response(self, prompt: Prompt, stream: bool = True) -> "AsyncResponse":
         return AsyncResponse(prompt, self, stream)
 
 
-class Model(ABC, _get_key_mixin):
-    model_id: str
-
-    # API key handling
-    key: Optional[str] = None
-    needs_key: Optional[str] = None
-    key_env_var: Optional[str] = None
-
-    # Model characteristics
-    can_stream: bool = False
-    attachment_types: Set = set()
-
-    class Options(_Options):
-        pass
-
-    def conversation(self):
-        return Conversation(model=self)
-
-    @abstractmethod
-    def execute(
-        self,
-        prompt: Prompt,
-        stream: bool,
-        response: Response,
-        conversation: Optional[Conversation],
-    ) -> Iterator[str]:
-        """
-        Execute a prompt and yield chunks of text, or yield a single big chunk.
-        Any additional useful information about the execution should be assigned to the response.
-        """
-        pass
-
-    def prompt(
-        self,
-        prompt: str,
-        *,
-        attachments: Optional[List[Attachment]] = None,
-        system: Optional[str] = None,
-        stream: bool = True,
-        **options
-    ):
-        # Validate attachments
-        if attachments and not self.attachment_types:
-            raise ValueError(
-                "This model does not support attachments, but some were provided"
-            )
-        for attachment in attachments or []:
-            attachment_type = attachment.resolve_type()
-            if attachment_type not in self.attachment_types:
-                raise ValueError(
-                    "This model does not support attachments of type '{}', only {}".format(
-                        attachment_type, ", ".join(self.attachment_types)
-                    )
-                )
-        return self.response(
-            Prompt(
-                prompt,
-                attachments=attachments,
-                system=system,
-                model=self,
-                options=self.Options(**options),
-            ),
-            stream=stream,
-        )
-
-    def response(self, prompt: Prompt, stream: bool = True) -> Response:
-        return Response(prompt, self, stream)
-
-    def __str__(self) -> str:
-        return "{}: {}".format(self.__class__.__name__, self.model_id)
-
-    def __repr__(self):
-        return "<Model '{}'>".format(self.model_id)
-
-
 class EmbeddingModel(ABC, _get_key_mixin):
     model_id: str
     key: Optional[str] = None

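The typing pattern this commit applies can be summarized outside the diff: instead of two separate TypeVar pairs (one defined mid-file, one near the model classes), a single set of bounded TypeVars at the top of the module parameterizes the shared generic base classes, so each concrete sync or async subclass pins down exactly which model type mypy sees. Below is a minimal sketch of that idea, assuming illustrative stand-in classes (SyncModel, AsyncModel, BaseResponse) rather than llm's real API:

from typing import Generic, Optional, TypeVar, Union

# Stand-ins for llm's Model / AsyncModel; not the library's real classes.
class SyncModel:
    def execute(self) -> str:
        return "sync result"

class AsyncModel:
    async def execute(self) -> str:
        return "async result"

# Bounded TypeVar: only the two model flavours are legal parameters,
# mirroring ModelT = TypeVar("ModelT", bound=Union["Model", "AsyncModel"]).
ModelT = TypeVar("ModelT", bound=Union[SyncModel, AsyncModel])

class BaseResponse(Generic[ModelT]):
    # Shared state lives here once; subclasses fix the type parameter.
    # Note the Optional[...] on the default, as in the commit's change to
    # conversation: Optional[ConversationT] = None -- a bare TypeVar
    # annotation with a None default fails strict mypy.
    def __init__(self, model: ModelT, conversation: Optional[str] = None):
        self.model = model
        self.conversation = conversation

class Response(BaseResponse[SyncModel]):
    def run(self) -> str:
        # mypy narrows self.model to SyncModel, so execute() returns str.
        return self.model.execute()

class AsyncResponse(BaseResponse[AsyncModel]):
    async def run(self) -> str:
        # Here self.model is AsyncModel, so execute() must be awaited.
        return await self.model.execute()

print(Response(SyncModel()).run())  # -> "sync result"

The same strictness motivates the smaller edits in the diff: `(self._end or 0) - (self._start or 0)` and the `if self._start_utcnow else ""` guard satisfy mypy's checks on `Optional[float]` arithmetic and `Optional[datetime]` access without changing runtime behavior on the happy path.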