Skip to content

Commit

Permalink
generators: update openai for o1 models (#922)
Browse files Browse the repository at this point in the history
The OpenAI "o1" reasoning models have different API expectations from
other Chat API models, in terms of params and defaults, and so a new
class is added to account for these, and some switching hints placed in
the code.
  • Loading branch information
jmartin-tech authored Sep 23, 2024
2 parents 2d25cbf + 14755c2 commit bcd5b69
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 46 deletions.
2 changes: 1 addition & 1 deletion garak/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class PluginConfigurationError(GarakException):


class BadGeneratorException(PluginConfigurationError):
    """Generator invocation requested is not usable.

    Raised when a generator cannot be used as requested, e.g. a model name
    that requires a different generator class than the one invoked.
    """
    # NOTE: the merged diff left two consecutive docstrings here; only the
    # first string after the class line is the docstring, the second was a
    # dead statement. Keep the updated wording only.


class RateLimitHit(Exception):
Expand Down
123 changes: 80 additions & 43 deletions garak/generators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
* https://platform.openai.com/docs/model-index-for-researchers
"""

import inspect
import json
import logging
import re
Expand All @@ -24,25 +25,35 @@

# lists derived from https://platform.openai.com/docs/models
# Names that the OpenAI Chat Completions endpoint serves, kept in
# alphabetical order; deprecated models are commented out with their
# shutdown date. Each name appears exactly once (membership checks only
# need one entry; duplicates are dead weight).
chat_models = (
    "chatgpt-4o-latest",  # links to latest version
    "gpt-3.5-turbo",  # links to latest version
    "gpt-3.5-turbo-0125",
    "gpt-3.5-turbo-1106",
    "gpt-3.5-turbo-16k",
    "gpt-4",  # links to latest version
    "gpt-4-0125-preview",
    "gpt-4-0314",  # legacy
    "gpt-4-0613",
    "gpt-4-1106-preview",
    "gpt-4-1106-vision-preview",
    "gpt-4-32k",  # deprecated, shutdown 2025-06-06
    "gpt-4-32k-0314",  # deprecated, shutdown 2025-06-06
    "gpt-4-32k-0613",  # deprecated, shutdown 2025-06-06
    "gpt-4-turbo",  # links to latest version
    "gpt-4-turbo-2024-04-09",
    "gpt-4-turbo-preview",
    "gpt-4-vision-preview",
    "gpt-4o",  # links to latest version
    "gpt-4o-2024-05-13",
    "gpt-4o-2024-08-06",
    "gpt-4o-mini",  # links to latest version
    "gpt-4o-mini-2024-07-18",
    "o1-mini",  # links to latest version
    "o1-mini-2024-09-12",
    "o1-preview",  # links to latest version
    "o1-preview-2024-09-12",
    # "gpt-3.5-turbo-0613",  # deprecated, shutdown 2024-09-13
    # "gpt-3.5-turbo-16k-0613",  # deprecated, shutdown 2024-09-13
)

completion_models = (
Expand All @@ -64,26 +75,37 @@
)

# Maximum context window (tokens) per model, used to bound prompt sizing.
# Keys are kept in alphabetical order and unique: the merged diff had
# duplicate keys (e.g. "gpt-4", "gpt-4-turbo", "gpt-3.5-turbo-0125"),
# which Python silently resolves to the last occurrence — dead entries.
context_lengths = {
    "babbage-002": 16384,
    "chatgpt-4o-latest": 128000,
    "davinci-002": 16384,
    "gpt-3.5-turbo": 16385,
    "gpt-3.5-turbo-0125": 16385,
    "gpt-3.5-turbo-0613": 4096,
    "gpt-3.5-turbo-1106": 16385,
    "gpt-3.5-turbo-16k": 16385,
    "gpt-3.5-turbo-16k-0613": 16385,
    "gpt-3.5-turbo-instruct": 4096,
    "gpt-4": 8192,
    "gpt-4-0125-preview": 128000,
    "gpt-4-0314": 8192,
    "gpt-4-0613": 8192,
    "gpt-4-1106-preview": 128000,
    "gpt-4-1106-vision-preview": 128000,
    "gpt-4-32k": 32768,
    "gpt-4-32k-0613": 32768,
    "gpt-4-turbo": 128000,
    "gpt-4-turbo-2024-04-09": 128000,
    "gpt-4-turbo-preview": 128000,
    "gpt-4-vision-preview": 128000,
    "gpt-4o": 128000,
    "gpt-4o-2024-05-13": 128000,
    "gpt-4o-2024-08-06": 128000,
    "gpt-4o-mini": 16384,
    "gpt-4o-mini-2024-07-18": 16384,
    "o1-mini": 65536,
    "o1-mini-2024-09-12": 65536,
    "o1-preview": 32768,
    "o1-preview-2024-09-12": 32768,
}


Expand Down Expand Up @@ -172,23 +194,15 @@ def _call_model(
# reload client once when consuming the generator
self._load_client()

create_args = {
"model": self.name,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"n": generations_this_call,
"top_p": self.top_p,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"stop": self.stop,
"seed": self.seed,
}

create_args = {
k: v
for k, v in create_args.items()
if v is not None and k not in self.suppressed_params
}
create_args = {}
if "n" not in self.suppressed_params:
create_args["n"] = generations_this_call
for arg in inspect.signature(self.generator.create).parameters:
if arg == "model":
create_args[arg] = self.name
continue
if hasattr(self, arg) and arg not in self.suppressed_params:
create_args[arg] = getattr(self, arg)

if self.generator == self.client.completions:
if not isinstance(prompt, str):
Expand Down Expand Up @@ -264,11 +278,17 @@ def _load_client(self):
r"^.+-[01][0-9][0-3][0-9]$", self.name
): # handle model names -MMDDish suffix
self.generator = self.client.completions

else:
raise ValueError(
f"No {self.generator_family_name} API defined for '{self.name}' in generators/openai.py - please add one!"
)

if self.__class__.__name__ == "OpenAIGenerator" and self.name.startswith("o1-"):
msg = "'o1'-class models should use openai.OpenAIReasoningGenerator. Try e.g. `-m openai.OpenAIReasoningGenerator` instead of `-m openai`"
logging.error(msg)
raise garak.exception.BadGeneratorException("🛑 " + msg)

def _clear_client(self):
self.generator = None
self.client = None
Expand All @@ -282,4 +302,21 @@ def __init__(self, name="", config_root=_config):
super().__init__(self.name, config_root=config_root)


class OpenAIReasoningGenerator(OpenAIGenerator):
    """Generator wrapper for OpenAI reasoning models, e.g. `o1` family.

    Reasoning models diverge from the standard Chat API: several common
    params are rejected (suppressed below), and output length is bounded
    by ``max_completion_tokens`` rather than ``max_tokens``.
    """

    # reasoning endpoint returns one completion per request
    supports_multiple_generations = False

    DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
        "top_p": 1.0,
        "frequency_penalty": 0.0,
        "presence_penalty": 0.0,
        "seed": None,
        "stop": ["#", ";"],
        # params the o1 API rejects; excluded when building the request
        "suppressed_params": {"n", "temperature", "max_tokens", "stop"},
        "retry_json": True,
        # o1 replacement for max_tokens — TODO confirm default is suitable
        "max_completion_tokens": 1500,
    }


# class instantiated when this module is selected without an explicit class name
DEFAULT_CLASS = "OpenAIGenerator"
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ dependencies = [
"colorama>=0.4.3",
"tqdm>=4.64.0",
"cohere>=4.5.1,<5",
"openai>=1.14.0,<2",
"openai>=1.45.0,<2",
"replicate>=0.8.3",
"google-api-python-client>=2.0",
"backoff>=2.1.1",
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ datasets>=2.14.6,<2.17
colorama>=0.4.3
tqdm>=4.64.0
cohere>=4.5.1,<5
openai>=1.14.0,<2
openai>=1.45.0,<2
replicate>=0.8.3
google-api-python-client>=2.0
backoff>=2.1.1
Expand Down
9 changes: 9 additions & 0 deletions tests/generators/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import openai

import garak.exception
from garak.generators.openai import OpenAIGenerator


Expand Down Expand Up @@ -90,3 +91,11 @@ def test_openai_chat():
for item in output:
assert isinstance(item, str)
print("test passed!")


@pytest.mark.usefixtures("set_fake_env")
def test_reasoning_switch():
    """o1-family model names must be rejected by the plain OpenAIGenerator.

    The generator raises BadGeneratorException directing users to
    OpenAIReasoningGenerator instead; the constructed instance is never
    used, so don't bind it to a name (avoids an unused-variable lint).
    """
    with pytest.raises(garak.exception.BadGeneratorException):
        OpenAIGenerator(name="o1-mini")  # o1 models should use ReasoningGenerator

0 comments on commit bcd5b69

Please sign in to comment.