diff --git a/garak/generators/base.py b/garak/generators/base.py
index a69fa2d97..47559ca90 100644
--- a/garak/generators/base.py
+++ b/garak/generators/base.py
@@ -47,7 +47,9 @@ def __init__(self, name="", generations=10):
         )
         logging.info("generator init: %s", self)
 
-    def _call_model(self, prompt: str) -> Union[List[str], str, None]:
+    def _call_model(
+        self, prompt: str, generations_this_call: int = 1
+    ) -> Union[List[str], str, None]:
         """Takes a prompt and returns an API output
 
         _call_api() is fully responsible for the request, and should either
@@ -63,7 +65,7 @@ def _pre_generate_hook(self):
     def clear_history(self):
         pass
 
-    def generate(self, prompt: str) -> List[str]:
+    def generate(self, prompt: str, generations_this_call: int = -1) -> List[str]:
         """Manages the process of getting generations out from a prompt
 
         This will involve iterating through prompts, getting the generations
@@ -74,11 +76,22 @@
 
         self._pre_generate_hook()
 
+        assert (
+            generations_this_call >= -1
+        ), f"Unexpected value for generations_per_call: {generations_this_call}"
+
+        if generations_this_call == -1:
+            generations_this_call = self.generations
+
+        elif generations_this_call == 0:
+            logging.debug("generate() called with generations_this_call = 0")
+            return []
+
         if self.supports_multiple_generations:
-            return self._call_model(prompt)
+            return self._call_model(prompt, generations_this_call)
 
-        elif self.generations <= 1:
-            return [self._call_model(prompt)]
+        elif generations_this_call <= 1:
+            return [self._call_model(prompt, generations_this_call)]
 
         else:
             outputs = []
@@ -90,23 +103,23 @@
             ):
                 from multiprocessing import Pool
 
-                bar = tqdm.tqdm(total=self.generations, leave=False)
+                bar = tqdm.tqdm(total=generations_this_call, leave=False)
                 bar.set_description(self.fullname[:55])
 
                 with Pool(_config.system.parallel_requests) as pool:
                     for result in pool.imap_unordered(
-                        self._call_model, [prompt] * self.generations
+                        self._call_model, [prompt] * generations_this_call
                     ):
                         outputs.append(result)
                         bar.update(1)
 
             else:
                 generation_iterator = tqdm.tqdm(
-                    list(range(self.generations)), leave=False
+                    list(range(generations_this_call)), leave=False
                 )
                 generation_iterator.set_description(self.fullname[:55])
                 for i in generation_iterator:
-                    outputs.append(self._call_model(prompt))
+                    outputs.append(self._call_model(prompt, generations_this_call))
 
         cleaned_outputs = [
             o for o in outputs if o is not None
diff --git a/garak/generators/cohere.py b/garak/generators/cohere.py
index 5e6eaf40c..76d1783ab 100644
--- a/garak/generators/cohere.py
+++ b/garak/generators/cohere.py
@@ -9,6 +9,7 @@
 
 import logging
 import os
+from typing import List, Union
 
 import backoff
 import cohere
@@ -81,10 +82,12 @@ def _call_cohere_api(self, prompt, request_size=COHERE_GENERATION_LIMIT):
         )
         return [g.text for g in response]
 
-    def _call_model(self, prompt):
+    def _call_model(
+        self, prompt: str, generations_this_call: int = 1
+    ) -> Union[List[str], str, None]:
         """Cohere's _call_model does sub-batching before calling,
         and so manages chunking internally"""
-        quotient, remainder = divmod(self.generations, COHERE_GENERATION_LIMIT)
+        quotient, remainder = divmod(generations_this_call, COHERE_GENERATION_LIMIT)
         request_sizes = [COHERE_GENERATION_LIMIT] * quotient
         if remainder:
             request_sizes += [remainder]
diff --git a/garak/generators/function.py b/garak/generators/function.py
index 5f3c6ca72..f77cfd678 100644
--- a/garak/generators/function.py
+++ b/garak/generators/function.py
@@ -31,7 +31,7 @@
 
 import importlib
-from typing import List
+from typing import List, Union
 
 from garak.generators.base import Generator
 
@@ -56,8 +56,12 @@ def __init__(self, name="", **kwargs):
         #    name="", generations=self.generations)
         super().__init__(name, generations=self.generations)
 
-    def _call_model(self, prompt: str) -> str:
-        return self.generator(prompt, **self.kwargs)
+    def _call_model(
+        self, prompt: str, generations_this_call: int = 1
+    ) -> Union[List[str], str, None]:
+        return self.generator(
+            prompt, generations_this_call=generations_this_call, **self.kwargs
+        )
 
 
 class Multiple(Single):
@@ -65,8 +69,10 @@ class Multiple(Single):
 
     supports_multiple_generations = True
 
-    def _call_model(self, prompt) -> List[str]:
-        return self.generator(prompt, generations=self.generations, **self.kwargs)
+    def _call_model(
+        self, prompt: str, generations_this_call: int = 1
+    ) -> Union[List[str], str, None]:
+        return self.generator(prompt, generations=generations_this_call, **self.kwargs)
 
 
 default_class = "Single"
diff --git a/garak/generators/ggml.py b/garak/generators/ggml.py
index ed6d5acdc..cc4415321 100644
--- a/garak/generators/ggml.py
+++ b/garak/generators/ggml.py
@@ -16,6 +16,7 @@
 import os
 import re
 import subprocess
+from typing import List, Union
 
 from garak import _config
 from garak.generators.base import Generator
@@ -77,7 +78,14 @@ def __init__(self, name, generations=10):
 
         super().__init__(name, generations=generations)
 
-    def _call_model(self, prompt):
+    def _call_model(
+        self, prompt: str, generations_this_call: int = 1
+    ) -> Union[List[str], str, None]:
+        if generations_this_call != 1:
+            logging.warning(
+                "GgmlGenerator._call_model invokes with generations_this_call=%s but only 1 supported",
+                generations_this_call,
+            )
         command = [
             self.path_to_ggml_main,
             "-p",
diff --git a/garak/generators/guardrails.py b/garak/generators/guardrails.py
index 3468fd0df..3c0512129 100644
--- a/garak/generators/guardrails.py
+++ b/garak/generators/guardrails.py
@@ -5,7 +5,7 @@
 
 from contextlib import redirect_stderr
 import io
-from typing import List
+from typing import List, Union
 
 from garak.generators.base import Generator
 
@@ -35,7 +35,9 @@ def __init__(self, name, generations=1):
 
         super().__init__(name, generations=generations)
 
-    def _call_model(self, prompt: str) -> List[str]:
+    def _call_model(
+        self, prompt: str, generations_this_call: int = 1
+    ) -> Union[List[str], str, None]:
         with redirect_stderr(io.StringIO()) as f:  # quieten the tqdm
             result = self.rails.generate(prompt)
 
diff --git a/garak/generators/huggingface.py b/garak/generators/huggingface.py
index 76087b88c..11ea86e2a 100644
--- a/garak/generators/huggingface.py
+++ b/garak/generators/huggingface.py
@@ -88,9 +88,9 @@ def __init__(self, name, do_sample=True, generations=10, device=0):
         if _config.run.deprefix is True:
             self.deprefix_prompt = True
 
-        self._set_hf_context_len(self.generator.model.config)
+        self._set_hf_context_len(self.generator.model.config)
 
-    def _call_model(self, prompt: str) -> List[str]:
+    def _call_model(self, prompt: str, generations_this_call: int = 1) -> List[str]:
         with warnings.catch_warnings():
             warnings.simplefilter("ignore", category=UserWarning)
             try:
@@ -104,23 +104,22 @@ def _call_model(self, prompt: str) -> List[str]:
                     truncated_prompt,
                     pad_token_id=self.generator.tokenizer.eos_token_id,
                     max_new_tokens=self.max_tokens,
-                    num_return_sequences=self.generations,
+                    num_return_sequences=generations_this_call,
                 )
             except Exception as e:
                 logging.error(e)
                 raw_output = []  # could handle better than this
 
+        outputs = []
         if raw_output is not None:
-            generations = [
+            outputs = [
                 i["generated_text"] for i in raw_output
             ]  # generator returns 10 outputs by default in __init__
-        else:
-            generations = []
 
         if not self.deprefix_prompt:
-            return generations
+            return outputs
         else:
-            return [re.sub("^" + re.escape(prompt), "", i) for i in generations]
+            return [re.sub("^" + re.escape(prompt), "", _o) for _o in outputs]
 
 
 class OptimumPipeline(Pipeline, HFCompatible):
@@ -211,7 +210,9 @@ def clear_history(self):
 
         self.conversation = Conversation()
 
-    def _call_model(self, prompt: Union[str, list[dict]]) -> List[str]:
+    def _call_model(
+        self, prompt: Union[str, list[dict]], generations_this_call: int = 1
+    ) -> List[str]:
         """Take a conversation as a list of dictionaries and feed it to the model"""
 
         # If conversation is provided as a list of dicts, create the conversation.
@@ -230,14 +231,14 @@ def _call_model(self, prompt: Union[str, list[dict]]) -> List[str]:
 
             with torch.no_grad():
                 conversation = self.generator(conversation)
-            generations = [conversation[-1]["content"]]
+            outputs = [conversation[-1]["content"]]
 
         else:
             raise TypeError(f"Expected list or str, got {type(prompt)}")
 
         if not self.deprefix_prompt:
-            return generations
+            return outputs
         else:
-            return [re.sub("^" + re.escape(prompt), "", i) for i in generations]
+            return [re.sub("^" + re.escape(prompt), "", _o) for _o in outputs]
 
 
 class InferenceAPI(Generator, HFCompatible):
@@ -275,7 +276,7 @@ def __init__(self, name="", generations=10):
         ),
         max_value=125,
     )
-    def _call_model(self, prompt: str) -> List[str]:
+    def _call_model(self, prompt: str, generations_this_call: int = 1) -> List[str]:
         import json
         import requests
 
@@ -283,7 +284,7 @@ def _call_model(self, prompt: str) -> List[str]:
             "inputs": prompt,
             "parameters": {
                 "return_full_text": not self.deprefix_prompt,
-                "num_return_sequences": self.generations,
+                "num_return_sequences": generations_this_call,
                 "max_time": self.max_time,
             },
             "options": {
@@ -293,7 +294,7 @@ def _call_model(self, prompt: str) -> List[str]:
         if self.max_tokens:
             payload["parameters"]["max_new_tokens"] = self.max_tokens
 
-        if self.generations > 1:
+        if generations_this_call > 1:
             payload["parameters"]["do_sample"] = True
 
         req_response = requests.request(
@@ -366,6 +367,8 @@ class InferenceEndpoint(InferenceAPI, HFCompatible):
     supports_multiple_generations = False
     import requests
 
+    timeout = 120
+
     def __init__(self, name="", generations=10):
         super().__init__(name, generations=generations)
         self.api_url = name
@@ -380,7 +383,7 @@ def __init__(self, name="", generations=10):
         ),
         max_value=125,
     )
-    def _call_model(self, prompt: str) -> List[str]:
+    def _call_model(self, prompt: str, generations_this_call: int = 1) -> List[str]:
         import requests
 
         payload = {
@@ -396,18 +399,18 @@ def _call_model(self, prompt: str) -> List[str]:
         if self.max_tokens:
             payload["parameters"]["max_new_tokens"] = self.max_tokens
 
-        if self.generations > 1:
+        if generations_this_call > 1:
             payload["parameters"]["do_sample"] = True
 
         response = requests.post(
-            self.api_url, headers=self.headers, json=payload
+            self.api_url, headers=self.headers, json=payload, timeout=self.timeout
         ).json()
         try:
             output = response[0]["generated_text"]
-        except:
+        except Exception as exc:
             raise IOError(
                 "Hugging Face 🤗 endpoint didn't generate a response. Make sure the endpoint is active."
-            )
+            ) from exc
 
         return output
 
@@ -471,10 +474,10 @@ def __init__(self, name, do_sample=True, generations=10, device=0):
         self.generation_config.eos_token_id = self.model.config.eos_token_id
         self.generation_config.pad_token_id = self.model.config.eos_token_id
 
-    def _call_model(self, prompt):
+    def _call_model(self, prompt: str, generations_this_call: int = 1):
         self.generation_config.max_new_tokens = self.max_tokens
         self.generation_config.do_sample = self.do_sample
-        self.generation_config.num_return_sequences = self.generations
+        self.generation_config.num_return_sequences = generations_this_call
         if self.temperature is not None:
             self.generation_config.temperature = self.temperature
         if self.top_k is not None:
@@ -494,7 +497,7 @@ def _call_model(self, prompt):
             )
         except IndexError as e:
             if len(prompt) == 0:
-                return [""] * self.generations
+                return [""] * generations_this_call
             else:
                 raise e
         text_output = self.tokenizer.batch_decode(
diff --git a/garak/generators/langchain.py b/garak/generators/langchain.py
index 5d10b30d7..9c9332783 100644
--- a/garak/generators/langchain.py
+++ b/garak/generators/langchain.py
@@ -56,7 +56,7 @@ def __init__(self, name, generations=10):
 
         self.generator = llm
 
-    def _call_model(self, prompt: str) -> str:
+    def _call_model(self, prompt: str, generations_this_call: int = 1) -> str:
         """
         Continuation generation method for LangChain LLM integrations.
 
diff --git a/garak/generators/litellm.py b/garak/generators/litellm.py
index 4511692cb..b6367f615 100644
--- a/garak/generators/litellm.py
+++ b/garak/generators/litellm.py
@@ -135,7 +135,7 @@ def __init__(self, name: str, generations: int = 10):
 
     @backoff.on_exception(backoff.fibo, Exception, max_value=70)
     def _call_model(
-        self, prompt: Union[str, List[dict]]
+        self, prompt: str, generations_this_call: int = 1
     ) -> Union[List[str], str, None]:
         if isinstance(prompt, str):
             prompt = [{"role": "user", "content": prompt}]
@@ -155,7 +155,7 @@ def _call_model(
             messages=prompt,
             temperature=self.temperature,
             top_p=self.top_p,
-            n=self.generations,
+            n=generations_this_call,
             stop=self.stop,
             max_tokens=self.max_tokens,
             frequency_penalty=self.frequency_penalty,
diff --git a/garak/generators/nemo.py b/garak/generators/nemo.py
index 694e170b6..dadc3fd3f 100644
--- a/garak/generators/nemo.py
+++ b/garak/generators/nemo.py
@@ -70,7 +70,7 @@ def __init__(self, name=None, generations=10):
         ),
         max_value=70,
     )
-    def _call_model(self, prompt):
+    def _call_model(self, prompt: str, generations_this_call: int = 1):
         # avoid:
         # doesn't match schema #/components/schemas/CompletionRequestBody: Error at "/prompt": minimum string length is 1
         if prompt == "":
@@ -80,7 +80,7 @@ def _call_model(self, prompt):
         if self.seed is None:  # nemo gives the same result every time
             reset_none_seed = True
             self.seed = random.randint(0, 2147483648 - 1)
-        elif self.generations > 1:
+        elif generations_this_call > 1:
             logging.info(
                 "fixing a seed means nemollm gives the same result every time, recommend setting generations=1"
             )
diff --git a/garak/generators/nvcf.py b/garak/generators/nvcf.py
index 41d9599d8..f8a3db4da 100644
--- a/garak/generators/nvcf.py
+++ b/garak/generators/nvcf.py
@@ -64,7 +64,7 @@ def __init__(self, name=None, generations=10):
         ),
         max_value=70,
     )
-    def _call_model(self, prompt: str) -> str:
+    def _call_model(self, prompt: str, generations_this_call: int = 1) -> str:
         if prompt == "":
             return ""
 
diff --git a/garak/generators/octo.py b/garak/generators/octo.py
index f22ebf502..d76e17309 100644
--- a/garak/generators/octo.py
+++ b/garak/generators/octo.py
@@ -48,7 +48,7 @@ def __init__(self, name, generations=10):
         self.client = Client(token=octoai_token)
 
     @backoff.on_exception(backoff.fibo, octoai.errors.OctoAIServerError, max_value=70)
-    def _call_model(self, prompt):
+    def _call_model(self, prompt: str, generations_this_call: int = 1):
         outputs = self.client.chat.completions.create(
             messages=[
                 {
@@ -84,7 +84,7 @@ def __init__(self, name, generations=10):
         )
 
     @backoff.on_exception(backoff.fibo, octoai.errors.OctoAIServerError, max_value=70)
-    def _call_model(self, prompt):
+    def _call_model(self, prompt: str, generations_this_call: int = 1):
         outputs = self.client.infer(
             endpoint_url=self.name,
             inputs={
diff --git a/garak/generators/openai.py b/garak/generators/openai.py
index 013635586..191547b4e 100644
--- a/garak/generators/openai.py
+++ b/garak/generators/openai.py
@@ -147,7 +147,9 @@ def __init__(self, name, generations=10):
         ),
         max_value=70,
     )
-    def _call_model(self, prompt: Union[str, list[dict]]) -> List[str]:
+    def _call_model(
+        self, prompt: Union[str, list[dict]], generations_this_call: int = 1
+    ) -> List[str]:
         if self.generator == self.client.completions:
             if not isinstance(prompt, str):
                 msg = (
@@ -163,7 +165,7 @@ def _call_model(self, prompt: Union[str, list[dict]]) -> List[str]:
                 prompt=prompt,
                 temperature=self.temperature,
                 max_tokens=self.max_tokens,
-                n=self.generations,
+                n=generations_this_call,
                 top_p=self.top_p,
                 frequency_penalty=self.frequency_penalty,
                 presence_penalty=self.presence_penalty,
@@ -189,7 +191,7 @@ def _call_model(self, prompt: Union[str, list[dict]]) -> List[str]:
                 messages=messages,
                 temperature=self.temperature,
                 top_p=self.top_p,
-                n=self.generations,
+                n=generations_this_call,
                 stop=self.stop,
                 max_tokens=self.max_tokens,
                 presence_penalty=self.presence_penalty,
diff --git a/garak/generators/openai_v0.py b/garak/generators/openai_v0.py
index 43acac259..ba07fb871 100644
--- a/garak/generators/openai_v0.py
+++ b/garak/generators/openai_v0.py
@@ -133,14 +133,14 @@ def __init__(self, name, generations=10):
         ),
         max_value=70,
     )
-    def _call_model(self, prompt: str) -> List[str]:
+    def _call_model(self, prompt: str, generations_this_call: int = 1) -> List[str]:
         if self.generator == openai.Completion:
             response = self.generator.create(
                 model=self.name,
                 prompt=prompt,
                 temperature=self.temperature,
                 max_tokens=self.max_tokens,
-                n=self.generations,
+                n=generations_this_call,
                 top_p=self.top_p,
                 frequency_penalty=self.frequency_penalty,
                 presence_penalty=self.presence_penalty,
@@ -154,7 +154,7 @@ def _call_model(self, prompt: str) -> List[str]:
                 messages=[{"role": "user", "content": prompt}],
                 temperature=self.temperature,
                 top_p=self.top_p,
-                n=self.generations,
+                n=generations_this_call,
                 stop=self.stop,
                 max_tokens=self.max_tokens,
                 presence_penalty=self.presence_penalty,
diff --git a/garak/generators/rasa.py b/garak/generators/rasa.py
index e4f572927..3515c657f 100644
--- a/garak/generators/rasa.py
+++ b/garak/generators/rasa.py
@@ -161,7 +161,7 @@ def __init__(self, uri=None, generations=10):
 
     # we'll overload IOError for recoverable server errors
    @backoff.on_exception(backoff.fibo, (RESTRateLimitError, IOError), max_value=70)
-    def _call_model(self, prompt):
+    def _call_model(self, prompt: str, generations_this_call: int = 1):
         """Individual call to get a rest from the REST API
 
         :param prompt: the input to be placed into the request template and sent to the endpoint
diff --git a/garak/generators/replicate.py b/garak/generators/replicate.py
index 5d164bb46..0ce3579b9 100644
--- a/garak/generators/replicate.py
+++ b/garak/generators/replicate.py
@@ -50,7 +50,7 @@ def __init__(self, name, generations=10):
     @backoff.on_exception(
         backoff.fibo, replicate.exceptions.ReplicateError, max_value=70
     )
-    def _call_model(self, prompt):
+    def _call_model(self, prompt: str, generations_this_call: int = 1):
         response_iterator = self.replicate.run(
             self.name,
             input={
diff --git a/garak/generators/rest.py b/garak/generators/rest.py
index b54a3ffe2..614841841 100644
--- a/garak/generators/rest.py
+++ b/garak/generators/rest.py
@@ -210,7 +210,7 @@ def _populate_template(
 
     # we'll overload IOError as the rate limit exception
     @backoff.on_exception(backoff.fibo, RESTRateLimitError, max_value=70)
-    def _call_model(self, prompt):
+    def _call_model(self, prompt: str, generations_this_call: int = 1):
         """Individual call to get a rest from the REST API
 
         :param prompt: the input to be placed into the request template and sent to the endpoint
diff --git a/garak/generators/test.py b/garak/generators/test.py
index b1e2492c9..ed02448a2 100644
--- a/garak/generators/test.py
+++ b/garak/generators/test.py
@@ -15,8 +15,8 @@ class Blank(Generator):
     generator_family_name = "Test"
     name = "Blank"
 
-    def _call_model(self, prompt: str) -> List[str]:
-        return [""] * self.generations
+    def _call_model(self, prompt: str, generations_this_call: int = 1) -> List[str]:
+        return [""] * generations_this_call
 
 
 class Repeat(Generator):
@@ -26,8 +26,8 @@ class Repeat(Generator):
     generator_family_name = "Test"
     name = "Repeat"
 
-    def _call_model(self, prompt: str) -> List[str]:
-        return [prompt] * self.generations
+    def _call_model(self, prompt: str, generations_this_call: int = 1) -> List[str]:
+        return [prompt] * generations_this_call
 
 
 default_class = "Blank"
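
Reviewer note: the sketch below is not part of the patch; it illustrates how the new generations_this_call contract in base.Generator.generate() is expected to behave, assuming the Generator base class as patched above. EchoGenerator is a hypothetical subclass written purely for illustration, mirroring the test generators in garak/generators/test.py.

# Illustrative sketch only -- not part of the patch. Assumes garak's
# Generator base class behaves as in garak/generators/base.py above.
from typing import List

from garak.generators.base import Generator


class EchoGenerator(Generator):
    """Hypothetical generator honouring the per-call generation count."""

    generator_family_name = "Example"
    name = "Echo"
    supports_multiple_generations = True  # _call_model returns a full batch

    def _call_model(self, prompt: str, generations_this_call: int = 1) -> List[str]:
        # Size the batch on the per-call argument, not on self.generations,
        # which now only supplies the run-level default.
        return [prompt] * generations_this_call


gen = EchoGenerator(name="Echo", generations=10)
print(len(gen.generate("hi")))     # 10 -- default of -1 falls back to self.generations
print(len(gen.generate("hi", 3)))  # 3  -- explicit per-call count
print(gen.generate("hi", 0))       # [] -- zero generations short-circuits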