allow generators.Base.generate() to take an optional param specifying generation count (NVIDIA#600)

* generate now takes an optional parameter specifying how many strings to get as output

* logging and value checking for generations_this_call

---------

Signed-off-by: Leon Derczynski <[email protected]>
leondz authored Apr 24, 2024
1 parent 9b1e475 commit a36e276
Showing 17 changed files with 100 additions and 63 deletions.
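
The headline change: Generator.generate() now takes an optional generations_this_call argument, defaulting to -1 ("use self.generations"), with 0 short-circuiting to an empty list. A minimal, hedged sketch of the call-side semantics, assuming garak is installed and its default configuration loads on import; EchoGenerator is an invented toy subclass, not part of this diff:

# Hedged, self-contained sketch of the new per-call count semantics.
# EchoGenerator is invented for illustration and is not part of the garak codebase.
from garak.generators.base import Generator


class EchoGenerator(Generator):
    """Toy generator that returns the prompt itself as its single output."""

    def _call_model(self, prompt: str, generations_this_call: int = 1):
        return prompt


g = EchoGenerator("toy", generations=5)

print(len(g.generate("Tell me a joke")))                           # 5: default of -1 falls back to g.generations
print(len(g.generate("Tell me a joke", generations_this_call=2)))  # 2: per-call override
print(g.generate("Tell me a joke", generations_this_call=0))       # []: zero logs a debug message and returns early
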
31 changes: 22 additions & 9 deletions garak/generators/base.py
@@ -47,7 +47,9 @@ def __init__(self, name="", generations=10):
)
logging.info("generator init: %s", self)

def _call_model(self, prompt: str) -> Union[List[str], str, None]:
def _call_model(
self, prompt: str, generations_this_call: int = 1
) -> Union[List[str], str, None]:
"""Takes a prompt and returns an API output
_call_api() is fully responsible for the request, and should either
@@ -63,7 +65,7 @@ def _pre_generate_hook(self):
def clear_history(self):
pass

def generate(self, prompt: str) -> List[str]:
def generate(self, prompt: str, generations_this_call: int = -1) -> List[str]:
"""Manages the process of getting generations out from a prompt
This will involve iterating through prompts, getting the generations
@@ -74,11 +76,22 @@ def generate(self, prompt: str) -> List[str]:

self._pre_generate_hook()

assert (
generations_this_call >= -1
), f"Unexpected value for generations_per_call: {generations_this_call}"

if generations_this_call == -1:
generations_this_call = self.generations

elif generations_this_call == 0:
logging.debug("generate() called with generations_this_call = 0")
return []

if self.supports_multiple_generations:
return self._call_model(prompt)
return self._call_model(prompt, generations_this_call)

elif self.generations <= 1:
return [self._call_model(prompt)]
elif generations_this_call <= 1:
return [self._call_model(prompt, generations_this_call)]

else:
outputs = []
@@ -90,23 +103,23 @@ def generate(self, prompt: str) -> List[str]:
):
from multiprocessing import Pool

bar = tqdm.tqdm(total=self.generations, leave=False)
bar = tqdm.tqdm(total=generations_this_call, leave=False)
bar.set_description(self.fullname[:55])

with Pool(_config.system.parallel_requests) as pool:
for result in pool.imap_unordered(
self._call_model, [prompt] * self.generations
self._call_model, [prompt] * generations_this_call
):
outputs.append(result)
bar.update(1)

else:
generation_iterator = tqdm.tqdm(
list(range(self.generations)), leave=False
list(range(generations_this_call)), leave=False
)
generation_iterator.set_description(self.fullname[:55])
for i in generation_iterator:
outputs.append(self._call_model(prompt))
outputs.append(self._call_model(prompt, generations_this_call))

cleaned_outputs = [
o for o in outputs if o is not None
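
Because the base class now forwards generations_this_call into _call_model(), subclasses that set supports_multiple_generations are expected to honour it in a single batched call. A hypothetical sketch of that contract; FakeBatchGenerator is invented for illustration and not part of garak:

# Hypothetical multi-output subclass sketch: when supports_multiple_generations
# is True, generate() makes one call and expects a list of
# generations_this_call strings back.
from typing import List, Union

from garak.generators.base import Generator


class FakeBatchGenerator(Generator):
    supports_multiple_generations = True  # generate() delegates the whole batch

    def _call_model(
        self, prompt: str, generations_this_call: int = 1
    ) -> Union[List[str], str, None]:
        # A real backend would issue one batched API request here.
        return [f"{prompt} #{i}" for i in range(generations_this_call)]
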
7 changes: 5 additions & 2 deletions garak/generators/cohere.py
@@ -9,6 +9,7 @@

import logging
import os
from typing import List, Union

import backoff
import cohere
@@ -81,10 +82,12 @@ def _call_cohere_api(self, prompt, request_size=COHERE_GENERATION_LIMIT):
)
return [g.text for g in response]

def _call_model(self, prompt):
def _call_model(
self, prompt: str, generations_this_call: int = 1
) -> Union[List[str], str, None]:
"""Cohere's _call_model does sub-batching before calling,
and so manages chunking internally"""
quotient, remainder = divmod(self.generations, COHERE_GENERATION_LIMIT)
quotient, remainder = divmod(generations_this_call, COHERE_GENERATION_LIMIT)
request_sizes = [COHERE_GENERATION_LIMIT] * quotient
if remainder:
request_sizes += [remainder]
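
Cohere's API caps how many generations a single request may return, so _call_model() above splits the per-call count into full-sized chunks plus a remainder with divmod. A standalone sketch of that chunking; a limit of 5 is assumed here purely for illustration:

# Standalone sketch of the sub-batching logic above; the limit value is assumed.
COHERE_GENERATION_LIMIT = 5


def request_sizes(generations_this_call: int) -> list[int]:
    quotient, remainder = divmod(generations_this_call, COHERE_GENERATION_LIMIT)
    sizes = [COHERE_GENERATION_LIMIT] * quotient
    if remainder:
        sizes.append(remainder)
    return sizes


print(request_sizes(12))  # [5, 5, 2] -> three API calls to cover twelve generations
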
16 changes: 11 additions & 5 deletions garak/generators/function.py
@@ -31,7 +31,7 @@


import importlib
from typing import List
from typing import List, Union

from garak.generators.base import Generator

@@ -56,17 +56,23 @@ def __init__(self, name="", **kwargs): # name="", generations=self.generations)

super().__init__(name, generations=self.generations)

def _call_model(self, prompt: str) -> str:
return self.generator(prompt, **self.kwargs)
def _call_model(
self, prompt: str, generations_this_call: int = 1
) -> Union[List[str], str, None]:
return self.generator(
prompt, generations_this_call=generations_this_call, **self.kwargs
)


class Multiple(Single):
"""pass a module#function to be called as generator, with format function(prompt:str, generations:int, **kwargs)->List[str]"""

supports_multiple_generations = True

def _call_model(self, prompt) -> List[str]:
return self.generator(prompt, generations=self.generations, **self.kwargs)
def _call_model(
self, prompt: str, generations_this_call: int = 1
) -> Union[List[str], str, None]:
return self.generator(prompt, generations=generations_this_call, **self.kwargs)


default_class = "Single"
10 changes: 9 additions & 1 deletion garak/generators/ggml.py
@@ -16,6 +16,7 @@
import os
import re
import subprocess
from typing import List, Union

from garak import _config
from garak.generators.base import Generator
@@ -77,7 +78,14 @@ def __init__(self, name, generations=10):

super().__init__(name, generations=generations)

def _call_model(self, prompt):
def _call_model(
self, prompt: str, generations_this_call: int = 1
) -> Union[List[str], str, None]:
if generations_this_call != 1:
logging.warning(
"GgmlGenerator._call_model invokes with generations_this_call=%s but only 1 supported",
generations_this_call,
)
command = [
self.path_to_ggml_main,
"-p",
6 changes: 4 additions & 2 deletions garak/generators/guardrails.py
@@ -5,7 +5,7 @@

from contextlib import redirect_stderr
import io
from typing import List
from typing import List, Union

from garak.generators.base import Generator

@@ -35,7 +35,9 @@ def __init__(self, name, generations=1):

super().__init__(name, generations=generations)

def _call_model(self, prompt: str) -> List[str]:
def _call_model(
self, prompt: str, generations_this_call: int = 1
) -> Union[List[str], str, None]:
with redirect_stderr(io.StringIO()) as f: # quieten the tqdm
result = self.rails.generate(prompt)

49 changes: 26 additions & 23 deletions garak/generators/huggingface.py
@@ -88,9 +88,9 @@ def __init__(self, name, do_sample=True, generations=10, device=0):
if _config.run.deprefix is True:
self.deprefix_prompt = True

self._set_hf_context_len(self.generator.model.config)
self._set_hf_context_len(self.generator.model.config)

def _call_model(self, prompt: str) -> List[str]:
def _call_model(self, prompt: str, generations_this_call: int = 1) -> List[str]:
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
try:
@@ -104,23 +104,22 @@ def _call_model(self, prompt: str) -> List[str]:
truncated_prompt,
pad_token_id=self.generator.tokenizer.eos_token_id,
max_new_tokens=self.max_tokens,
num_return_sequences=self.generations,
num_return_sequences=generations_this_call,
)
except Exception as e:
logging.error(e)
raw_output = [] # could handle better than this

outputs = []
if raw_output is not None:
generations = [
outputs = [
i["generated_text"] for i in raw_output
] # generator returns 10 outputs by default in __init__
else:
generations = []

if not self.deprefix_prompt:
return generations
return outputs
else:
return [re.sub("^" + re.escape(prompt), "", i) for i in generations]
return [re.sub("^" + re.escape(prompt), "", _o) for _o in outputs]


class OptimumPipeline(Pipeline, HFCompatible):
@@ -211,7 +210,9 @@ def clear_history(self):

self.conversation = Conversation()

def _call_model(self, prompt: Union[str, list[dict]]) -> List[str]:
def _call_model(
self, prompt: Union[str, list[dict]], generations_this_call: int = 1
) -> List[str]:
"""Take a conversation as a list of dictionaries and feed it to the model"""

# If conversation is provided as a list of dicts, create the conversation.
@@ -230,14 +231,14 @@ def _call_model(self, prompt: Union[str, list[dict]]) -> List[str]:
with torch.no_grad():
conversation = self.generator(conversation)

generations = [conversation[-1]["content"]]
outputs = [conversation[-1]["content"]]
else:
raise TypeError(f"Expected list or str, got {type(prompt)}")

if not self.deprefix_prompt:
return generations
return outputs
else:
return [re.sub("^" + re.escape(prompt), "", i) for i in generations]
return [re.sub("^" + re.escape(prompt), "", _o) for _o in outputs]


class InferenceAPI(Generator, HFCompatible):
@@ -275,15 +276,15 @@ def __init__(self, name="", generations=10):
),
max_value=125,
)
def _call_model(self, prompt: str) -> List[str]:
def _call_model(self, prompt: str, generations_this_call: int = 1) -> List[str]:
import json
import requests

payload = {
"inputs": prompt,
"parameters": {
"return_full_text": not self.deprefix_prompt,
"num_return_sequences": self.generations,
"num_return_sequences": generations_this_call,
"max_time": self.max_time,
},
"options": {
@@ -293,7 +294,7 @@ def _call_model(self, prompt: str) -> List[str]:
if self.max_tokens:
payload["parameters"]["max_new_tokens"] = self.max_tokens

if self.generations > 1:
if generations_this_call > 1:
payload["parameters"]["do_sample"] = True

req_response = requests.request(
@@ -366,6 +367,8 @@ class InferenceEndpoint(InferenceAPI, HFCompatible):
supports_multiple_generations = False
import requests

timeout = 120

def __init__(self, name="", generations=10):
super().__init__(name, generations=generations)
self.api_url = name
@@ -380,7 +383,7 @@ def __init__(self, name="", generations=10):
),
max_value=125,
)
def _call_model(self, prompt: str) -> List[str]:
def _call_model(self, prompt: str, generations_this_call: int = 1) -> List[str]:
import requests

payload = {
@@ -396,18 +399,18 @@ def __init__(self, name="", generations=10):
if self.max_tokens:
payload["parameters"]["max_new_tokens"] = self.max_tokens

if self.generations > 1:
if generations_this_call > 1:
payload["parameters"]["do_sample"] = True

response = requests.post(
self.api_url, headers=self.headers, json=payload
self.api_url, headers=self.headers, json=payload, timeout=self.timeout
).json()
try:
output = response[0]["generated_text"]
except:
except Exception as exc:
raise IOError(
"Hugging Face 🤗 endpoint didn't generate a response. Make sure the endpoint is active."
)
) from exc
return output


@@ -471,10 +474,10 @@ def __init__(self, name, do_sample=True, generations=10, device=0):
self.generation_config.eos_token_id = self.model.config.eos_token_id
self.generation_config.pad_token_id = self.model.config.eos_token_id

def _call_model(self, prompt):
def _call_model(self, prompt: str, generations_this_call: int = 1):
self.generation_config.max_new_tokens = self.max_tokens
self.generation_config.do_sample = self.do_sample
self.generation_config.num_return_sequences = self.generations
self.generation_config.num_return_sequences = generations_this_call
if self.temperature is not None:
self.generation_config.temperature = self.temperature
if self.top_k is not None:
@@ -494,7 +497,7 @@ def _call_model(self, prompt):
)
except IndexError as e:
if len(prompt) == 0:
return [""] * self.generations
return [""] * generations_this_call
else:
raise e
text_output = self.tokenizer.batch_decode(
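
In the Pipeline-based Hugging Face generators, the requested count now flows into num_return_sequences instead of the fixed self.generations. A minimal sketch of that pattern written directly against the transformers pipeline API, outside garak; the gpt2 checkpoint is illustrative and is downloaded on first use:

# Minimal sketch of the num_return_sequences pattern used above; "gpt2" is an
# illustrative checkpoint, not something this commit prescribes.
from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")

generations_this_call = 3
raw_output = generator(
    "Once upon a time",
    do_sample=True,
    max_new_tokens=20,
    num_return_sequences=generations_this_call,
    pad_token_id=generator.tokenizer.eos_token_id,
)
outputs = [item["generated_text"] for item in raw_output]
assert len(outputs) == generations_this_call
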
2 changes: 1 addition & 1 deletion garak/generators/langchain.py
@@ -56,7 +56,7 @@ def __init__(self, name, generations=10):

self.generator = llm

def _call_model(self, prompt: str) -> str:
def _call_model(self, prompt: str, generations_this_call: int = 1) -> str:
"""
Continuation generation method for LangChain LLM integrations.
Expand Down
4 changes: 2 additions & 2 deletions garak/generators/litellm.py
@@ -135,7 +135,7 @@ def __init__(self, name: str, generations: int = 10):

@backoff.on_exception(backoff.fibo, Exception, max_value=70)
def _call_model(
self, prompt: Union[str, List[dict]]
self, prompt: str, generations_this_call: int = 1
) -> Union[List[str], str, None]:
if isinstance(prompt, str):
prompt = [{"role": "user", "content": prompt}]
@@ -155,7 +155,7 @@ def _call_model(
messages=prompt,
temperature=self.temperature,
top_p=self.top_p,
n=self.generations,
n=generations_this_call,
stop=self.stop,
max_tokens=self.max_tokens,
frequency_penalty=self.frequency_penalty,
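
The LiteLLM generator now maps the per-call count onto the completion API's n parameter. A hedged sketch of the underlying call outside garak; the model name is illustrative and a provider API key is assumed to be configured in the environment:

# Sketch of the underlying litellm call with n tied to the per-call count;
# model name is illustrative, and an OPENAI_API_KEY (or equivalent provider
# credential) is assumed to be set.
import litellm

generations_this_call = 4
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Name a colour."}],
    temperature=0.7,
    n=generations_this_call,
)
outputs = [choice.message.content for choice in response.choices]
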
4 changes: 2 additions & 2 deletions garak/generators/nemo.py
@@ -70,7 +70,7 @@ def __init__(self, name=None, generations=10):
),
max_value=70,
)
def _call_model(self, prompt):
def _call_model(self, prompt: str, generations_this_call: int = 1):
# avoid:
# doesn't match schema #/components/schemas/CompletionRequestBody: Error at "/prompt": minimum string length is 1
if prompt == "":
@@ -80,7 +80,7 @@ def _call_model(self, prompt):
if self.seed is None: # nemo gives the same result every time
reset_none_seed = True
self.seed = random.randint(0, 2147483648 - 1)
elif self.generations > 1:
elif generations_this_call > 1:
logging.info(
"fixing a seed means nemollm gives the same result every time, recommend setting generations=1"
)
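
The NeMo change keeps the existing seed logic but keys the warning off generations_this_call: with no fixed seed, a fresh one is drawn per call; with a fixed seed and more than one generation requested, every output would be identical, so the code only logs a hint. A small standalone sketch of that decision; choose_seed is an invented helper, not part of the diff:

# Standalone sketch of the seed-handling decision above; choose_seed is
# invented for illustration.
import logging
import random


def choose_seed(configured_seed, generations_this_call):
    if configured_seed is None:
        return random.randint(0, 2147483648 - 1)  # fresh seed each call
    if generations_this_call > 1:
        logging.info(
            "fixing a seed means nemollm gives the same result every time, "
            "recommend setting generations=1"
        )
    return configured_seed
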
2 changes: 1 addition & 1 deletion garak/generators/nvcf.py
@@ -64,7 +64,7 @@ def __init__(self, name=None, generations=10):
),
max_value=70,
)
def _call_model(self, prompt: str) -> str:
def _call_model(self, prompt: str, generations_this_call: int = 1) -> str:
if prompt == "":
return ""
