Skip to content

Commit

Permalink
generator: promote OpenAICompatible as first class generator (#1021)
Browse files Browse the repository at this point in the history
fix #1008 

Various generators extend this base template, usage and feedback suggest
UX and usability improvements will be gained by promoting this class to
be a `Configurable` generic OpenAI client based generator.

## Verification

Steps to verify that this change works:

- [ ] Test against a mimic of the `nim` configuration, supplied via generator
config options:
compatible.json:
``` json
{
    "openai": {
        "OpenAICompatible": {
            "uri": "https://integrate.api.nvidia.com/v1/",
            "suppressed_params": ["n", "frequency_penalty", "presence_penalty"]
        }
    }
}
```
``` bash
python -m garak -m openai.OpenAICompatible -n meta/llama3-8b-instruct -p lmrc --generator_option_file compatible.json
```
- [ ] **Verify** at least one additional modified generator (`groq`,
`openai.OpenAIGenerator`, or `nim`) with `--parallel_attempts` enabled.
  • Loading branch information
leondz authored Dec 6, 2024
2 parents 9b33870 + 5befbbe commit 7499da1
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 37 deletions.
37 changes: 21 additions & 16 deletions garak/generators/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,23 @@
import os
import openai

from garak.generators.openai import OpenAICompatible, chat_models, completion_models, context_lengths
from garak.generators.openai import (
OpenAICompatible,
chat_models,
completion_models,
context_lengths,
)

# lists derived from https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
# some azure openai model names should be mapped to openai names
openai_model_mapping = {
"gpt-4": "gpt-4-turbo-2024-04-09",
"gpt-35-turbo": "gpt-3.5-turbo-0125",
"gpt-35-turbo-16k": "gpt-3.5-turbo-16k",
"gpt-35-turbo-instruct": "gpt-3.5-turbo-instruct"
"gpt-4": "gpt-4-turbo-2024-04-09",
"gpt-35-turbo": "gpt-3.5-turbo-0125",
"gpt-35-turbo-16k": "gpt-3.5-turbo-16k",
"gpt-35-turbo-instruct": "gpt-3.5-turbo-instruct",
}


class AzureOpenAIGenerator(OpenAICompatible):
"""Wrapper for Azure Open AI. Expects AZURE_API_KEY, AZURE_ENDPOINT and AZURE_MODEL_NAME environment variables.
Expand All @@ -31,7 +37,7 @@ class AzureOpenAIGenerator(OpenAICompatible):
To get started with this generator:
#. Visit [https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models](https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models) and find the LLM you'd like to use.
#. [Deploy a model](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/create-resource?pivots=web-portal#deploy-a-model) and copy paste the model and deployment names.
#. On the Azure portal page for the Azure OpenAI you want to use click on "Resource Management -> Keys and Endpoint" and copy paste the API Key and endpoint.
#. On the Azure portal page for the Azure OpenAI you want to use click on "Resource Management -> Keys and Endpoint" and copy paste the API Key and endpoint.
#. In your console, Set the ``AZURE_API_KEY``, ``AZURE_ENDPOINT`` and ``AZURE_MODEL_NAME`` variables.
#. Run garak, setting ``--model_type`` to ``azure`` and ``--model_name`` to the name **of the deployment**.
- e.g. ``gpt-4o``.
Expand All @@ -44,7 +50,7 @@ class AzureOpenAIGenerator(OpenAICompatible):
active = True
generator_family_name = "Azure"
api_version = "2024-06-01"

DEFAULT_PARAMS = OpenAICompatible.DEFAULT_PARAMS | {
"model_name": None,
"uri": None,
Expand All @@ -54,23 +60,23 @@ def _validate_env_var(self):
if self.model_name is None:
if not hasattr(self, "model_name_env_var"):
self.model_name_env_var = self.MODEL_NAME_ENV_VAR

self.model_name = os.getenv(self.model_name_env_var, None)

if self.model_name is None:
raise ValueError(
f'The {self.MODEL_NAME_ENV_VAR} environment variable is required.\n'
f"The {self.MODEL_NAME_ENV_VAR} environment variable is required.\n"
)

if self.uri is None:
if not hasattr(self, "endpoint_env_var"):
self.endpoint_env_var = self.ENDPOINT_ENV_VAR

self.uri = os.getenv(self.endpoint_env_var, None)

if self.uri is None:
raise ValueError(
f'The {self.ENDPOINT_ENV_VAR} environment variable is required.\n'
f"The {self.ENDPOINT_ENV_VAR} environment variable is required.\n"
)

return super()._validate_env_var()
Expand All @@ -79,7 +85,9 @@ def _load_client(self):
if self.model_name in openai_model_mapping:
self.model_name = openai_model_mapping[self.model_name]

self.client = openai.AzureOpenAI(azure_endpoint=self.uri, api_key=self.api_key, api_version=self.api_version)
self.client = openai.AzureOpenAI(
azure_endpoint=self.uri, api_key=self.api_key, api_version=self.api_version
)

if self.name == "":
raise ValueError(
Expand All @@ -102,8 +110,5 @@ def _load_client(self):
if self.model_name in context_lengths:
self.context_len = context_lengths[self.model_name]

def _clear_client(self):
self.generator = None
self.client = None

DEFAULT_CLASS = "AzureOpenAIGenerator"
4 changes: 0 additions & 4 deletions garak/generators/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,6 @@ def _load_client(self):
)
self.generator = self.client.chat.completions

def _clear_client(self):
self.generator = None
self.client = None

def _call_model(
self, prompt: str | List[dict], generations_this_call: int = 1
) -> List[Union[str, None]]:
Expand Down
4 changes: 0 additions & 4 deletions garak/generators/nim.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,6 @@ def _load_client(self):
)
self.generator = self.client.chat.completions

def _clear_client(self):
self.generator = None
self.client = None

def _prepare_prompt(self, prompt):
return prompt

Expand Down
27 changes: 17 additions & 10 deletions garak/generators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,15 @@ class OpenAICompatible(Generator):

ENV_VAR = "OpenAICompatible_API_KEY".upper() # Placeholder override when extending

active = False # this interface class is not active
active = True
supports_multiple_generations = True
generator_family_name = "OpenAICompatible" # Placeholder override when extending

# template defaults optionally override when extending
DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
"temperature": 0.7,
"top_p": 1.0,
"uri": "http://localhost:8000/v1/",
"frequency_penalty": 0.0,
"presence_penalty": 0.0,
"seed": None,
Expand All @@ -141,13 +142,18 @@ def __setstate__(self, d) -> object:
self._load_client()

def _load_client(self):
# Required stub implemented when extending `OpenAICompatible`
# should populate self.generator with an openai api compliant object
raise NotImplementedError
# When extending `OpenAICompatible` this method is a likely location for target application specific
# customization and must populate self.generator with an openai api compliant object
self.client = openai.OpenAI(base_url=self.uri, api_key=self.api_key)
if self.name in ("", None):
raise ValueError(
f"{self.generator_family_name} requires model name to be set, e.g. --model_name org/private-model-name"
)
self.generator = self.client.chat.completions

def _clear_client(self):
# Required stub implemented when extending `OpenAICompatible`
raise NotImplementedError
self.generator = None
self.client = None

def _validate_config(self):
pass
Expand Down Expand Up @@ -257,6 +263,11 @@ class OpenAIGenerator(OpenAICompatible):
active = True
generator_family_name = "OpenAI"

# remove uri as it is not overridable in this class.
DEFAULT_PARAMS = {
k: val for k, val in OpenAICompatible.DEFAULT_PARAMS.items() if k != "uri"
}

def _load_client(self):
self.client = openai.OpenAI(api_key=self.api_key)

Expand Down Expand Up @@ -289,10 +300,6 @@ def _load_client(self):
logging.error(msg)
raise garak.exception.BadGeneratorException("🛑 " + msg)

def _clear_client(self):
self.generator = None
self.client = None

def __init__(self, name="", config_root=_config):
self.name = name
self._load_config(config_root)
Expand Down
9 changes: 6 additions & 3 deletions tests/generators/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,12 @@ def test_parallel_requests():
result = g.generate(prompt="this is a test", generations_this_call=3)
assert isinstance(result, list), "Generator generate() should return a list"
assert len(result) == 3, "Generator should return 3 results as requested"
assert all(isinstance(item, str) for item in result), "All items in the generate result should be strings"
assert all(len(item) > 0 for item in result), "All generated strings should be non-empty"
assert all(
isinstance(item, str) for item in result
), "All items in the generate result should be strings"
assert all(
len(item) > 0 for item in result
), "All generated strings should be non-empty"


@pytest.mark.parametrize("classname", GENERATORS)
Expand Down Expand Up @@ -190,7 +194,6 @@ def test_generator_structure(classname):
"generators.huggingface.OptimumPipeline", # model name restrictions and cuda required
"generators.huggingface.Pipeline", # model name restrictions
"generators.langchain.LangChainLLMGenerator", # model name restrictions
"generators.openai.OpenAICompatible", # template class not intended to ever be `Active`
]
]

Expand Down

0 comments on commit 7499da1

Please sign in to comment.