Refactor Chat and AsyncChat to use _Shared base class
simonw committed Nov 6, 2024
1 parent 44e6be1 commit 2b6f5cc
Showing 1 changed file with 37 additions and 42 deletions.
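The net effect, sketched minimally below using only the names visible in the diff (method bodies elided): the OpenAI-specific configuration and helpers move into a _Shared base class, and Chat and AsyncChat each combine it with the appropriate model flavour.

# Minimal sketch of the class layout this commit produces
# (names taken from the diff; bodies elided).

class _Shared:
    # Key handling, Options, build_messages(), get_client() and
    # build_kwargs() live here once, shared by both subclasses.
    needs_key = "openai"
    key_env_var = "OPENAI_API_KEY"


class Chat(_Shared, Model):
    def execute(self, prompt, stream, response, conversation=None):
        ...  # synchronous generator; yields response text


class AsyncChat(_Shared, AsyncModel):
    async def execute(self, prompt, stream, response, conversation=None):
        ...  # asynchronous counterpart

Listing _Shared first in the bases means its attributes and helpers take precedence in Python's method resolution order over anything Model or AsyncModel defines.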
79 changes: 37 additions & 42 deletions llm/default_plugins/openai_models.py
@@ -270,15 +270,11 @@ def validate_logit_bias(cls, logit_bias):
         return validated_logit_bias


-class Chat(Model):
+class _Shared:
     needs_key = "openai"
     key_env_var = "OPENAI_API_KEY"

     default_max_tokens = None

-    def get_async_model(self):
-        return AsyncChat(self.model_id, self.key)
-
     class Options(SharedOptions):
         json_object: Optional[bool] = Field(
             description="Output a valid JSON object {...}. Prompt must mention JSON.",
@@ -370,40 +366,6 @@ def build_messages(self, prompt, conversation):
                 messages.append({"role": "user", "content": attachment_message})
         return messages

-    def execute(self, prompt, stream, response, conversation=None):
-        if prompt.system and not self.allows_system_prompt:
-            raise NotImplementedError("Model does not support system prompts")
-        messages = self.build_messages(prompt, conversation)
-        kwargs = self.build_kwargs(prompt, stream)
-        client = self.get_client()
-        if stream:
-            completion = client.chat.completions.create(
-                model=self.model_name or self.model_id,
-                messages=messages,
-                stream=True,
-                **kwargs,
-            )
-            chunks = []
-            for chunk in completion:
-                chunks.append(chunk)
-                try:
-                    content = chunk.choices[0].delta.content
-                except IndexError:
-                    content = None
-                if content is not None:
-                    yield content
-            response.response_json = remove_dict_none_values(combine_chunks(chunks))
-        else:
-            completion = client.chat.completions.create(
-                model=self.model_name or self.model_id,
-                messages=messages,
-                stream=False,
-                **kwargs,
-            )
-            response.response_json = remove_dict_none_values(completion.model_dump())
-            yield completion.choices[0].message.content
-        response._prompt_json = redact_data_urls({"messages": messages})
-
     def get_client(self, async_=False):
         kwargs = {}
         if self.api_base:
@@ -441,10 +403,43 @@ def build_kwargs(self, prompt, stream):
         return kwargs


-class AsyncChat(AsyncModel, Chat):
-    needs_key = "openai"
-    key_env_var = "OPENAI_API_KEY"
-
+class Chat(_Shared, Model):
+    def execute(self, prompt, stream, response, conversation=None):
+        if prompt.system and not self.allows_system_prompt:
+            raise NotImplementedError("Model does not support system prompts")
+        messages = self.build_messages(prompt, conversation)
+        kwargs = self.build_kwargs(prompt, stream)
+        client = self.get_client()
+        if stream:
+            completion = client.chat.completions.create(
+                model=self.model_name or self.model_id,
+                messages=messages,
+                stream=True,
+                **kwargs,
+            )
+            chunks = []
+            for chunk in completion:
+                chunks.append(chunk)
+                try:
+                    content = chunk.choices[0].delta.content
+                except IndexError:
+                    content = None
+                if content is not None:
+                    yield content
+            response.response_json = remove_dict_none_values(combine_chunks(chunks))
+        else:
+            completion = client.chat.completions.create(
+                model=self.model_name or self.model_id,
+                messages=messages,
+                stream=False,
+                **kwargs,
+            )
+            response.response_json = remove_dict_none_values(completion.model_dump())
+            yield completion.choices[0].message.content
+        response._prompt_json = redact_data_urls({"messages": messages})
+
+
+class AsyncChat(_Shared, AsyncModel):
     async def execute(self, prompt, stream, response, conversation=None):
         if prompt.system and not self.allows_system_prompt:
             raise NotImplementedError("Model does not support system prompts")
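The async execute body is truncated in the hunk above. As a rough sketch only (not the committed code): it would mirror the synchronous branch, assuming get_client(async_=True) returns an openai.AsyncOpenAI client, whose chat.completions.create call is awaited and whose stream is consumed with async for.

    # Sketch, not the committed code: mirrors the sync execute() shown
    # above. Assumes get_client(async_=True) yields openai.AsyncOpenAI;
    # build_messages, build_kwargs, combine_chunks, remove_dict_none_values
    # and redact_data_urls come from the surrounding module.
    async def execute(self, prompt, stream, response, conversation=None):
        if prompt.system and not self.allows_system_prompt:
            raise NotImplementedError("Model does not support system prompts")
        messages = self.build_messages(prompt, conversation)
        kwargs = self.build_kwargs(prompt, stream)
        client = self.get_client(async_=True)
        if stream:
            completion = await client.chat.completions.create(
                model=self.model_name or self.model_id,
                messages=messages,
                stream=True,
                **kwargs,
            )
            chunks = []
            async for chunk in completion:  # AsyncStream supports async iteration
                chunks.append(chunk)
                try:
                    content = chunk.choices[0].delta.content
                except IndexError:
                    content = None
                if content is not None:
                    yield content
            response.response_json = remove_dict_none_values(combine_chunks(chunks))
        else:
            completion = await client.chat.completions.create(
                model=self.model_name or self.model_id,
                messages=messages,
                stream=False,
                **kwargs,
            )
            response.response_json = remove_dict_none_values(completion.model_dump())
            yield completion.choices[0].message.content
        response._prompt_json = redact_data_urls({"messages": messages})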
