diff --git a/garak/generators/azure.py b/garak/generators/azure.py index f355fa7f..503f176c 100644 --- a/garak/generators/azure.py +++ b/garak/generators/azure.py @@ -11,17 +11,23 @@ import os import openai -from garak.generators.openai import OpenAICompatible, chat_models, completion_models, context_lengths +from garak.generators.openai import ( + OpenAICompatible, + chat_models, + completion_models, + context_lengths, +) # lists derived from https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models # some azure openai model names should be mapped to openai names openai_model_mapping = { - "gpt-4": "gpt-4-turbo-2024-04-09", - "gpt-35-turbo": "gpt-3.5-turbo-0125", - "gpt-35-turbo-16k": "gpt-3.5-turbo-16k", - "gpt-35-turbo-instruct": "gpt-3.5-turbo-instruct" + "gpt-4": "gpt-4-turbo-2024-04-09", + "gpt-35-turbo": "gpt-3.5-turbo-0125", + "gpt-35-turbo-16k": "gpt-3.5-turbo-16k", + "gpt-35-turbo-instruct": "gpt-3.5-turbo-instruct", } + class AzureOpenAIGenerator(OpenAICompatible): """Wrapper for Azure Open AI. Expects AZURE_API_KEY, AZURE_ENDPOINT and AZURE_MODEL_NAME environment variables. @@ -31,7 +37,7 @@ class AzureOpenAIGenerator(OpenAICompatible): To get started with this generator: #. Visit [https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models](https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models) and find the LLM you'd like to use. #. [Deploy a model](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/create-resource?pivots=web-portal#deploy-a-model) and copy paste the model and deployment names. - #. On the Azure portal page for the Azure OpenAI you want to use click on "Resource Management -> Keys and Endpoint" and copy paste the API Key and endpoint. + #. On the Azure portal page for the Azure OpenAI you want to use click on "Resource Management -> Keys and Endpoint" and copy paste the API Key and endpoint. #. In your console, Set the ``AZURE_API_KEY``, ``AZURE_ENDPOINT`` and ``AZURE_MODEL_NAME`` variables. #. Run garak, setting ``--model_type`` to ``azure`` and ``--model_name`` to the name **of the deployment**. - e.g. ``gpt-4o``. @@ -44,7 +50,7 @@ class AzureOpenAIGenerator(OpenAICompatible): active = True generator_family_name = "Azure" api_version = "2024-06-01" - + DEFAULT_PARAMS = OpenAICompatible.DEFAULT_PARAMS | { "model_name": None, "uri": None, @@ -54,23 +60,23 @@ def _validate_env_var(self): if self.model_name is None: if not hasattr(self, "model_name_env_var"): self.model_name_env_var = self.MODEL_NAME_ENV_VAR - + self.model_name = os.getenv(self.model_name_env_var, None) if self.model_name is None: raise ValueError( - f'The {self.MODEL_NAME_ENV_VAR} environment variable is required.\n' + f"The {self.MODEL_NAME_ENV_VAR} environment variable is required.\n" ) - + if self.uri is None: if not hasattr(self, "endpoint_env_var"): self.endpoint_env_var = self.ENDPOINT_ENV_VAR - + self.uri = os.getenv(self.endpoint_env_var, None) if self.uri is None: raise ValueError( - f'The {self.ENDPOINT_ENV_VAR} environment variable is required.\n' + f"The {self.ENDPOINT_ENV_VAR} environment variable is required.\n" ) return super()._validate_env_var() @@ -79,7 +85,9 @@ def _load_client(self): if self.model_name in openai_model_mapping: self.model_name = openai_model_mapping[self.model_name] - self.client = openai.AzureOpenAI(azure_endpoint=self.uri, api_key=self.api_key, api_version=self.api_version) + self.client = openai.AzureOpenAI( + azure_endpoint=self.uri, api_key=self.api_key, api_version=self.api_version + ) if self.name == "": raise ValueError( @@ -102,8 +110,5 @@ def _load_client(self): if self.model_name in context_lengths: self.context_len = context_lengths[self.model_name] - def _clear_client(self): - self.generator = None - self.client = None DEFAULT_CLASS = "AzureOpenAIGenerator" diff --git a/garak/generators/groq.py b/garak/generators/groq.py index 28635965..6b7ae14d 100644 --- a/garak/generators/groq.py +++ b/garak/generators/groq.py @@ -49,10 +49,6 @@ def _load_client(self): ) self.generator = self.client.chat.completions - def _clear_client(self): - self.generator = None - self.client = None - def _call_model( self, prompt: str | List[dict], generations_this_call: int = 1 ) -> List[Union[str, None]]: diff --git a/garak/generators/nim.py b/garak/generators/nim.py index 0379aab2..19298556 100644 --- a/garak/generators/nim.py +++ b/garak/generators/nim.py @@ -63,10 +63,6 @@ def _load_client(self): ) self.generator = self.client.chat.completions - def _clear_client(self): - self.generator = None - self.client = None - def _prepare_prompt(self, prompt): return prompt diff --git a/garak/generators/openai.py b/garak/generators/openai.py index 5c27d1db..41c2ab79 100644 --- a/garak/generators/openai.py +++ b/garak/generators/openai.py @@ -114,7 +114,7 @@ class OpenAICompatible(Generator): ENV_VAR = "OpenAICompatible_API_KEY".upper() # Placeholder override when extending - active = False # this interface class is not active + active = True supports_multiple_generations = True generator_family_name = "OpenAICompatible" # Placeholder override when extending @@ -122,6 +122,7 @@ class OpenAICompatible(Generator): DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | { "temperature": 0.7, "top_p": 1.0, + "uri": "http://localhost:8000/v1/", "frequency_penalty": 0.0, "presence_penalty": 0.0, "seed": None, @@ -141,13 +142,18 @@ def __setstate__(self, d) -> object: self._load_client() def _load_client(self): - # Required stub implemented when extending `OpenAICompatible` - # should populate self.generator with an openai api compliant object - raise NotImplementedError + # When extending `OpenAICompatible` this method is a likely location for target application specific + # customization and must populate self.generator with an openai api compliant object + self.client = openai.OpenAI(base_url=self.uri, api_key=self.api_key) + if self.name in ("", None): + raise ValueError( + f"{self.generator_family_name} requires model name to be set, e.g. --model_name org/private-model-name" + ) + self.generator = self.client.chat.completions def _clear_client(self): - # Required stub implemented when extending `OpenAICompatible` - raise NotImplementedError + self.generator = None + self.client = None def _validate_config(self): pass @@ -257,6 +263,11 @@ class OpenAIGenerator(OpenAICompatible): active = True generator_family_name = "OpenAI" + # remove uri as it is not overridable in this class. + DEFAULT_PARAMS = { + k: val for k, val in OpenAICompatible.DEFAULT_PARAMS.items() if k != "uri" + } + def _load_client(self): self.client = openai.OpenAI(api_key=self.api_key) @@ -289,10 +300,6 @@ def _load_client(self): logging.error(msg) raise garak.exception.BadGeneratorException("🛑 " + msg) - def _clear_client(self): - self.generator = None - self.client = None - def __init__(self, name="", config_root=_config): self.name = name self._load_config(config_root) diff --git a/tests/generators/test_generators.py b/tests/generators/test_generators.py index 132dcee2..74c2a153 100644 --- a/tests/generators/test_generators.py +++ b/tests/generators/test_generators.py @@ -133,8 +133,12 @@ def test_parallel_requests(): result = g.generate(prompt="this is a test", generations_this_call=3) assert isinstance(result, list), "Generator generate() should return a list" assert len(result) == 3, "Generator should return 3 results as requested" - assert all(isinstance(item, str) for item in result), "All items in the generate result should be strings" - assert all(len(item) > 0 for item in result), "All generated strings should be non-empty" + assert all( + isinstance(item, str) for item in result + ), "All items in the generate result should be strings" + assert all( + len(item) > 0 for item in result + ), "All generated strings should be non-empty" @pytest.mark.parametrize("classname", GENERATORS) @@ -190,7 +194,6 @@ def test_generator_structure(classname): "generators.huggingface.OptimumPipeline", # model name restrictions and cuda required "generators.huggingface.Pipeline", # model name restrictions "generators.langchain.LangChainLLMGenerator", # model name restrictions - "generators.openai.OpenAICompatible", # template class not intended to ever be `Active` ] ]