First WIP prototype of async mode, refs #507
simonw committed Nov 6, 2024
1 parent fe1e097 commit e26e7f7
Showing 4 changed files with 328 additions and 43 deletions.
6 changes: 6 additions & 0 deletions llm/__init__.py
@@ -4,6 +4,7 @@
NeedsKeyException,
)
from .models import (
AsyncModel,
Attachment,
Conversation,
Model,
@@ -26,6 +27,7 @@

__all__ = [
"hookimpl",
"get_async_model",
"get_model",
"get_key",
"user_dir",
@@ -143,6 +145,10 @@ def get_model_aliases() -> Dict[str, Model]:
return model_aliases


def get_async_model(model_id: str) -> AsyncModel:
return get_model(model_id).get_async_model()


class UnknownModelError(KeyError):
pass

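The new get_async_model() helper in llm/__init__.py resolves a model ID and hands back that model's async counterpart. A minimal sketch of how it might be used from Python, assuming the async prompt surface implied by the CLI changes below (awaiting model.prompt(...) for a full response, async-iterating it for streaming); the model ID and prompt text are examples only:

import asyncio
import llm

async def demo():
    # get_async_model() just calls get_model(model_id).get_async_model(),
    # so it only works for models that implement an async variant
    # (in this commit, the OpenAI chat models).
    model = llm.get_async_model("gpt-4o-mini")  # example model ID

    # Non-streaming: await the prompt, then read the full text.
    response = await model.prompt("Three names for a pet otter")
    print(response.text())

    # Streaming: consume the same call as an async iterator.
    async for chunk in model.prompt("Three more names"):
        print(chunk, end="", flush=True)
    print()

asyncio.run(demo())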
56 changes: 45 additions & 11 deletions llm/cli.py
@@ -1,3 +1,4 @@
import asyncio
import click
from click_default_group import DefaultGroup
from dataclasses import asdict
@@ -11,6 +12,7 @@
Template,
UnknownModelError,
encode,
get_async_model,
get_default_model,
get_default_embedding_model,
get_embedding_models_with_aliases,
@@ -193,6 +195,7 @@ def cli():
)
@click.option("--key", help="API key to use")
@click.option("--save", help="Save prompt with this template name")
@click.option("async_", "--async", is_flag=True, help="Run prompt asynchronously")
def prompt(
prompt,
system,
@@ -209,6 +212,7 @@
conversation_id,
key,
save,
async_,
):
"""
Execute a prompt
@@ -325,7 +329,10 @@ def read_prompt():

# Now resolve the model
try:
model = model_aliases[model_id]
if async_:
model = get_async_model(model_id)
else:
model = get_model(model_id)
except KeyError:
raise click.ClickException("'{}' is not a known model".format(model_id))

@@ -363,21 +370,48 @@ def read_prompt():
prompt_method = conversation.prompt

try:
response = prompt_method(
prompt, attachments=resolved_attachments, system=system, **validated_options
)
if should_stream:
for chunk in response:
print(chunk, end="")
sys.stdout.flush()
print("")
if async_:

async def inner():
if should_stream:
async for chunk in prompt_method(
prompt,
attachments=resolved_attachments,
system=system,
**validated_options,
):
print(chunk, end="")
sys.stdout.flush()
print("")
else:
response = await prompt_method(
prompt,
attachments=resolved_attachments,
system=system,
**validated_options,
)
print(response.text())

asyncio.run(inner())
else:
print(response.text())
response = prompt_method(
prompt,
attachments=resolved_attachments,
system=system,
**validated_options,
)
if should_stream:
for chunk in response:
print(chunk, end="")
sys.stdout.flush()
print("")
else:
print(response.text())
except Exception as ex:
raise click.ClickException(str(ex))

# Log to the database
if (logs_on() or log) and not no_log:
if (logs_on() or log) and not no_log and not async_:
log_path = logs_db_path()
(log_path.parent).mkdir(parents=True, exist_ok=True)
db = sqlite_utils.Database(log_path)
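The async branch in cli.py funnels the whole prompt through a local async def inner() coroutine and a single asyncio.run() call, and database logging is skipped for async prompts in this prototype (the condition gains "and not async_"). A self-contained sketch of the same consume-an-async-iterator shape, with a stand-in generator where the real prompt response would be:

import asyncio

async def fake_stream():
    # Stand-in for the async chunk iterator that prompt_method() yields
    # in streaming async mode; purely illustrative.
    for word in ("async ", "mode ", "prototype"):
        yield word

async def inner():
    async for chunk in fake_stream():
        print(chunk, end="", flush=True)
    print("")

asyncio.run(inner())  # same shape as the CLI's async branch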
60 changes: 54 additions & 6 deletions llm/default_plugins/openai_models.py
@@ -1,4 +1,4 @@
from llm import EmbeddingModel, Model, hookimpl
from llm import AsyncModel, EmbeddingModel, Model, hookimpl
import llm
from llm.utils import dicts_to_table_string, remove_dict_none_values, logging_client
import click
@@ -254,6 +254,9 @@ class Chat(Model):

default_max_tokens = None

def get_async_model(self):
return AsyncChat(self.model_id, self.key)

class Options(SharedOptions):
json_object: Optional[bool] = Field(
description="Output a valid JSON object {...}. Prompt must mention JSON.",
@@ -297,10 +300,8 @@ def __init__(
def __str__(self):
return "OpenAI Chat: {}".format(self.model_id)

def execute(self, prompt, stream, response, conversation=None):
def build_messages(self, prompt, conversation):
messages = []
if prompt.system and not self.allows_system_prompt:
raise NotImplementedError("Model does not support system prompts")
current_system = None
if conversation is not None:
for prev_response in conversation.responses:
@@ -345,7 +346,12 @@ def execute(self, prompt, stream, response, conversation=None):
{"type": "image_url", "image_url": {"url": url}}
)
messages.append({"role": "user", "content": attachment_message})
return messages

def execute(self, prompt, stream, response, conversation=None):
if prompt.system and not self.allows_system_prompt:
raise NotImplementedError("Model does not support system prompts")
messages = self.build_messages(prompt, conversation)
kwargs = self.build_kwargs(prompt, stream)
client = self.get_client()
if stream:
@@ -376,7 +382,7 @@ def execute(self, prompt, stream, response, conversation=None):
yield completion.choices[0].message.content
response._prompt_json = redact_data_urls({"messages": messages})

def get_client(self):
def get_client(self, async_=False):
kwargs = {}
if self.api_base:
kwargs["base_url"] = self.api_base
@@ -396,7 +402,10 @@ def get_client(self):
kwargs["default_headers"] = self.headers
if os.environ.get("LLM_OPENAI_SHOW_RESPONSES"):
kwargs["http_client"] = logging_client()
return openai.OpenAI(**kwargs)
if async_:
return openai.AsyncOpenAI(**kwargs)
else:
return openai.OpenAI(**kwargs)

def build_kwargs(self, prompt, stream):
kwargs = dict(not_nulls(prompt.options))
@@ -410,6 +419,45 @@ def build_kwargs(self, prompt, stream):
return kwargs


class AsyncChat(AsyncModel, Chat):
needs_key = "openai"
key_env_var = "OPENAI_API_KEY"

async def execute(self, prompt, stream, response, conversation=None):
if prompt.system and not self.allows_system_prompt:
raise NotImplementedError("Model does not support system prompts")
messages = self.build_messages(prompt, conversation)
kwargs = self.build_kwargs(prompt, stream)
client = self.get_client(async_=True)
if stream:
completion = await client.chat.completions.create(
model=self.model_name or self.model_id,
messages=messages,
stream=True,
**kwargs,
)
chunks = []
async for chunk in completion:
chunks.append(chunk)
try:
content = chunk.choices[0].delta.content
except IndexError:
content = None
if content is not None:
yield content
response.response_json = remove_dict_none_values(combine_chunks(chunks))
else:
completion = await client.chat.completions.create(
model=self.model_name or self.model_id,
messages=messages,
stream=False,
**kwargs,
)
response.response_json = remove_dict_none_values(completion.model_dump())
yield completion.choices[0].message.content
response._prompt_json = redact_data_urls({"messages": messages})


class Completion(Chat):
class Options(SharedOptions):
logprobs: Optional[int] = Field(
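AsyncChat.execute() is an async generator built on the openai.AsyncOpenAI client: with stream=True it yields delta content as chunks arrive and records them for response_json; otherwise it awaits the full completion and yields its text once. A stripped-down sketch of the streaming call it wraps, assuming the openai 1.x client and an OPENAI_API_KEY in the environment; the model name is only an example:

import asyncio
import openai

async def main():
    client = openai.AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    completion = await client.chat.completions.create(
        model="gpt-4o-mini",  # example model name
        messages=[{"role": "user", "content": "Say hello"}],
        stream=True,
    )
    async for chunk in completion:
        try:
            content = chunk.choices[0].delta.content
        except IndexError:
            content = None
        if content is not None:
            print(content, end="", flush=True)
    print("")

asyncio.run(main())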
