diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py
index 2559fd2b..a4ebadcc 100644
--- a/llm/default_plugins/openai_models.py
+++ b/llm/default_plugins/openai_models.py
@@ -299,16 +299,6 @@ def _attachment(attachment):
 
 
 class _Shared:
-    needs_key = "openai"
-    key_env_var = "OPENAI_API_KEY"
-    default_max_tokens = None
-
-    class Options(SharedOptions):
-        json_object: Optional[bool] = Field(
-            description="Output a valid JSON object {...}. Prompt must mention JSON.",
-            default=None,
-        )
-
     def __init__(
         self,
         model_id,
@@ -437,6 +427,16 @@ def build_kwargs(self, prompt, stream):
 
 
 class Chat(_Shared, Model):
+    needs_key = "openai"
+    key_env_var = "OPENAI_API_KEY"
+    default_max_tokens = None
+
+    class Options(SharedOptions):
+        json_object: Optional[bool] = Field(
+            description="Output a valid JSON object {...}. Prompt must mention JSON.",
+            default=None,
+        )
+
     def execute(self, prompt, stream, response, conversation=None):
         if prompt.system and not self.allows_system_prompt:
             raise NotImplementedError("Model does not support system prompts")
@@ -473,6 +473,16 @@ def execute(self, prompt, stream, response, conversation=None):
 
 
 class AsyncChat(_Shared, AsyncModel):
+    needs_key = "openai"
+    key_env_var = "OPENAI_API_KEY"
+    default_max_tokens = None
+
+    class Options(SharedOptions):
+        json_object: Optional[bool] = Field(
+            description="Output a valid JSON object {...}. Prompt must mention JSON.",
+            default=None,
+        )
+
     async def execute(self, prompt, stream, response, conversation=None):
        if prompt.system and not self.allows_system_prompt:
            raise NotImplementedError("Model does not support system prompts")
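
The effect of the move: once the key metadata and the Options class live on Chat and AsyncChat directly rather than on the shared mixin, each concrete class can be subclassed and its options extended independently of the other. A minimal sketch of that, using a pydantic stand-in for SharedOptions and a made-up VisionChat subclass with a hypothetical detail option, none of which are in the diff itself:

from typing import Optional

from pydantic import BaseModel, Field


class SharedOptions(BaseModel):
    # Stand-in for the library's shared options base (assumption).
    temperature: Optional[float] = None


class Chat:
    needs_key = "openai"
    key_env_var = "OPENAI_API_KEY"
    default_max_tokens = None

    class Options(SharedOptions):
        json_object: Optional[bool] = Field(
            description="Output a valid JSON object {...}. Prompt must mention JSON.",
            default=None,
        )


class AsyncChat:
    needs_key = "openai"
    key_env_var = "OPENAI_API_KEY"
    default_max_tokens = None

    class Options(SharedOptions):
        json_object: Optional[bool] = Field(
            description="Output a valid JSON object {...}. Prompt must mention JSON.",
            default=None,
        )


# With Options attached to each concrete class, extending one class's
# options no longer leaks into the other:
class VisionChat(Chat):  # hypothetical subclass, not part of the diff
    class Options(Chat.Options):
        detail: Optional[str] = None  # hypothetical extra option


print(sorted(VisionChat.Options.model_fields))
# ['detail', 'json_object', 'temperature']
print(sorted(AsyncChat.Options.model_fields))
# unchanged: ['json_object', 'temperature']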