
Commit

Black and mypy and ruff all happy
simonw committed Nov 13, 2024
1 parent 145b5cd commit 8ab5ea3
Showing 3 changed files with 20 additions and 21 deletions.
llm/cli.py (4 changes: 1 addition & 3 deletions)
@@ -31,7 +31,7 @@
 )
 
 from .migrations import migrate
-from .plugins import pm
+from .plugins import pm, load_plugins
 import base64
 import httpx
 import pathlib
@@ -1817,8 +1817,6 @@ def render_errors(errors):
     return "\n".join(output)
 
 
-from .plugins import load_plugins
-
 load_plugins()
 
 pm.hook.register_commands(cli=cli)
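
The cli.py change only consolidates the `load_plugins` import into the existing `from .plugins import pm` line; the module still calls `load_plugins()` and then fires the `register_commands` hook so installed plugins can attach extra CLI commands. A minimal, self-contained sketch of that pluggy pattern, using invented hook-spec and plugin names rather than llm's actual plugin machinery:

import click
import pluggy

hookspec = pluggy.HookspecMarker("demo")
hookimpl = pluggy.HookimplMarker("demo")


class Spec:
    @hookspec
    def register_commands(self, cli):
        """Plugins add extra commands to the click group."""


class ExamplePlugin:
    @hookimpl
    def register_commands(self, cli):
        @cli.command()
        def hello():
            click.echo("hello from a plugin")


@click.group()
def cli():
    "Demo CLI"


pm = pluggy.PluginManager("demo")
pm.add_hookspecs(Spec)


def load_plugins():
    # Stand-in for llm's load_plugins(); the real function discovers installed plugins.
    pm.register(ExamplePlugin())


load_plugins()
pm.hook.register_commands(cli=cli)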
llm/default_plugins/openai_models.py (6 changes: 4 additions & 2 deletions)
@@ -16,7 +16,7 @@
     from pydantic.fields import Field
     from pydantic.class_validators import validator as field_validator  # type: ignore [no-redef]
 
-from typing import List, Iterable, Iterator, Optional, Union
+from typing import AsyncGenerator, List, Iterable, Iterator, Optional, Union
 import json
 import yaml
 
@@ -483,7 +483,9 @@ class Options(SharedOptions):
             default=None,
         )
 
-    async def execute(self, prompt, stream, response, conversation=None):
+    async def execute(
+        self, prompt, stream, response, conversation=None
+    ) -> AsyncGenerator[str, None]:
         if prompt.system and not self.allows_system_prompt:
             raise NotImplementedError("Model does not support system prompts")
         messages = self.build_messages(prompt, conversation)
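
The return annotation is the substantive part of this hunk: an async def whose body contains yield is an async generator function, and AsyncGenerator[str, None] (yield type, send type) is the precise type for it, which is what mypy wants here. A small standalone illustration of the same shape, using a toy function rather than the OpenAI plugin code:

import asyncio
from typing import AsyncGenerator


async def stream_words(text: str) -> AsyncGenerator[str, None]:
    # An async function containing `yield` is an async generator; the annotation
    # describes the yielded values (str) and the send type (None).
    for word in text.split():
        await asyncio.sleep(0)  # pretend each chunk arrives asynchronously
        yield word


async def main() -> None:
    async for chunk in stream_words("streaming chunks of text"):
        print(chunk)


asyncio.run(main())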
llm/models.py (31 changes: 15 additions & 16 deletions)
@@ -10,7 +10,7 @@
 import time
 from typing import (
     Any,
-    AsyncIterator,
+    AsyncGenerator,
     Dict,
     Generic,
     Iterable,
@@ -380,23 +380,22 @@ async def __anext__(self) -> str:
             if not self._chunks:
                 raise StopAsyncIteration
             return chunk
-        try:
-            iterator = self.model.execute(
+
+        # Get and store the generator if we don't have it yet
+        if not hasattr(self, "_generator"):
+            generator = self.model.execute(
                 self.prompt,
                 stream=self.stream,
                 response=self,
                 conversation=self.conversation,
             )
-            async for chunk in iterator:
-                self._chunks.append(chunk)
-                return chunk
-
-            if self.conversation:
-                self.conversation.responses.append(self)
-            self._end = time.monotonic()
-            self._done = True
+            self._generator = generator
 
-            raise StopAsyncIteration
+        # Use the generator
+        try:
+            chunk = await self._generator.__anext__()
+            self._chunks.append(chunk)
+            return chunk
         except StopAsyncIteration:
             if self.conversation:
                 self.conversation.responses.append(self)
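
The reworked __anext__ creates the model's async generator on the first call, caches it on the instance, and afterwards simply forwards each __anext__ call to it, doing the end-of-response bookkeeping when the generator raises StopAsyncIteration. A minimal sketch of that lazy-wrapper pattern, with a made-up LazyAsyncResponse class and produce() generator rather than the real AsyncResponse:

import asyncio
from typing import AsyncGenerator


async def produce() -> AsyncGenerator[str, None]:
    # Stand-in for model.execute(); yields chunks of text.
    for chunk in ("Hello", ", ", "world"):
        yield chunk


class LazyAsyncResponse:
    """Iterates a generator that is only created on the first __anext__ call."""

    def __init__(self):
        self._chunks: list[str] = []
        self._done = False

    def __aiter__(self):
        return self

    async def __anext__(self) -> str:
        # Get and store the generator if we don't have it yet
        if not hasattr(self, "_generator"):
            self._generator = produce()
        try:
            chunk = await self._generator.__anext__()
            self._chunks.append(chunk)
            return chunk
        except StopAsyncIteration:
            # Bookkeeping happens once the stream is exhausted, then re-raise.
            self._done = True
            raise


async def main() -> None:
    response = LazyAsyncResponse()
    async for chunk in response:
        print(chunk, end="")
    print()


asyncio.run(main())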
@@ -619,12 +618,12 @@ async def execute(
         stream: bool,
         response: "AsyncResponse",
         conversation: Optional["AsyncConversation"],
-    ) -> AsyncIterator[str]:
+    ) -> AsyncGenerator[str, None]:
         """
-        Execute a prompt and yield chunks of text, or yield a single big chunk.
-        Any additional useful information about the execution should be assigned to the response.
+        Returns an async generator that executes the prompt and yields chunks of text,
+        or yields a single big chunk.
         """
-        pass
+        yield ""
 
     def prompt(
         self,
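
Changing the abstract method's body from pass to yield "" matters for the types: with a yield in the body, the base execute is itself an async generator function, so it genuinely matches the AsyncGenerator[str, None] annotation that subclasses implement. A hedged sketch of what an implementation of this interface can look like, with an invented EchoModel rather than a real llm plugin:

import asyncio
from typing import AsyncGenerator


class AsyncModelSketch:
    # Stand-in for the abstract base; `yield ""` makes this method an async generator too.
    async def execute(
        self, prompt, stream, response, conversation=None
    ) -> AsyncGenerator[str, None]:
        yield ""


class EchoModel(AsyncModelSketch):
    # Toy subclass: streams the prompt back one word at a time.
    async def execute(
        self, prompt, stream, response, conversation=None
    ) -> AsyncGenerator[str, None]:
        for word in str(prompt).split():
            yield word


async def main() -> None:
    async for chunk in EchoModel().execute("echo these words", stream=True, response=None):
        print(chunk)


asyncio.run(main())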
