Skip to content

Commit

Permalink
add var for generator context_len and populate this for some generators (NVIDIA#616)

Browse files Browse the repository at this point in the history

* add context_len generator base attrib, and implement setting in some HF generators

* add generator context lengths for openai models

* factor out hf context len extraction to mixin
  • Loading branch information
leondz authored Apr 24, 2024
1 parent 148f0f1 commit 9b1e475
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 6 deletions.
1 change: 1 addition & 0 deletions garak/generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class Generator:
top_k = None
active = True
generator_family_name = None
context_len = None

supports_multiple_generations = (
False # can more than one generation be extracted per request?
Expand Down
33 changes: 27 additions & 6 deletions garak/generators/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,24 @@ class HFInternalServerError(Exception):
pass


class Pipeline(Generator):
class HFCompatible:
    """Mixin for generators backed by Hugging Face model configs."""

    def _set_hf_context_len(self, config):
        """Populate ``self.context_len`` from a HF config's ``n_ctx``.

        Only sets the attribute when ``n_ctx`` exists and is an int;
        otherwise ``context_len`` is left untouched.
        """
        window = getattr(config, "n_ctx", None)
        if isinstance(window, int):
            self.context_len = window


class Pipeline(Generator, HFCompatible):
"""Get text generations from a locally-run Hugging Face pipeline"""

generator_family_name = "Hugging Face 🤗 pipeline"
supports_multiple_generations = True

def _set_hf_context_len(self, config):
if hasattr(config, "n_ctx"):
if isinstance(config.n_ctx, int):
self.context_len = config.n_ctx

def __init__(self, name, do_sample=True, generations=10, device=0):
self.fullname, self.name = name, name.split("/")[-1]

Expand Down Expand Up @@ -76,6 +88,8 @@ def __init__(self, name, do_sample=True, generations=10, device=0):
if _config.run.deprefix is True:
self.deprefix_prompt = True

self._set_hf_context_len(self.generator.model.config)

def _call_model(self, prompt: str) -> List[str]:
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
Expand Down Expand Up @@ -109,7 +123,7 @@ def _call_model(self, prompt: str) -> List[str]:
return [re.sub("^" + re.escape(prompt), "", i) for i in generations]


class OptimumPipeline(Pipeline):
class OptimumPipeline(Pipeline, HFCompatible):
"""Get text generations from a locally-run Hugging Face pipeline using NVIDIA Optimum"""

generator_family_name = "NVIDIA Optimum Hugging Face 🤗 pipeline"
Expand Down Expand Up @@ -151,8 +165,10 @@ def __init__(self, name, do_sample=True, generations=10, device=0):
if _config.run.deprefix is True:
self.deprefix_prompt = True

self._set_hf_context_len(self.generator.model.config)

class ConversationalPipeline(Generator):

class ConversationalPipeline(Generator, HFCompatible):
"""Conversational text generation using HuggingFace pipelines"""

generator_family_name = "Hugging Face 🤗 pipeline for conversations"
Expand Down Expand Up @@ -188,6 +204,8 @@ def __init__(self, name, do_sample=True, generations=10, device=0):
if _config.run.deprefix is True:
self.deprefix_prompt = True

self._set_hf_context_len(self.generator.model.config)

def clear_history(self):
from transformers import Conversation

Expand Down Expand Up @@ -222,7 +240,7 @@ def _call_model(self, prompt: Union[str, list[dict]]) -> List[str]:
return [re.sub("^" + re.escape(prompt), "", i) for i in generations]


class InferenceAPI(Generator):
class InferenceAPI(Generator, HFCompatible):
"""Get text generations from Hugging Face Inference API"""

generator_family_name = "Hugging Face 🤗 Inference API"
Expand Down Expand Up @@ -340,7 +358,7 @@ def _pre_generate_hook(self):
self.wait_for_model = False


class InferenceEndpoint(InferenceAPI):
class InferenceEndpoint(InferenceAPI, HFCompatible):
"""Interface for Hugging Face private endpoints
Pass the model URL as the name, e.g. https://xxx.aws.endpoints.huggingface.cloud
"""
Expand Down Expand Up @@ -393,7 +411,7 @@ def _call_model(self, prompt: str) -> List[str]:
return output


class Model(Generator):
class Model(Generator, HFCompatible):
"""Get text generations from a locally-run Hugging Face model"""

generator_family_name = "Hugging Face 🤗 model"
Expand Down Expand Up @@ -427,10 +445,13 @@ def __init__(self, name, do_sample=True, generations=10, device=0):
self.init_device # or "cuda:0" For fast initialization directly on GPU!
)

self._set_hf_context_len(self.config)

self.model = transformers.AutoModelForCausalLM.from_pretrained(
self.fullname,
config=self.config,
).to(self.init_device)

self.deprefix_prompt = name in models_to_deprefix

if self.config.tokenizer_class:
Expand Down
26 changes: 26 additions & 0 deletions garak/generators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,29 @@
# "ada", # shutdown https://platform.openai.com/docs/deprecations
)

# Maximum context window, in tokens, for known OpenAI model names.
# Used to populate Generator.context_len at init; model names not
# listed here leave context_len as None. Values mirror OpenAI's
# published model specifications — verify against current OpenAI docs
# when adding entries.
context_lengths = {
    "gpt-3.5-turbo-0125": 16385,
    "gpt-3.5-turbo": 16385,
    "gpt-3.5-turbo-1106": 16385,
    "gpt-3.5-turbo-instruct": 4096,
    "gpt-3.5-turbo-16k": 16385,
    "gpt-3.5-turbo-0613": 4096,
    "gpt-3.5-turbo-16k-0613": 16385,
    "babbage-002": 16384,
    "davinci-002": 16384,
    "gpt-4-turbo": 128000,
    "gpt-4-turbo-2024-04-09": 128000,
    "gpt-4-turbo-preview": 128000,
    "gpt-4-0125-preview": 128000,
    "gpt-4-1106-preview": 128000,
    "gpt-4-vision-preview": 128000,
    "gpt-4-1106-vision-preview": 128000,
    "gpt-4": 8192,
    "gpt-4-0613": 8192,
    "gpt-4-32k": 32768,
    "gpt-4-32k-0613": 32768,
}


class OpenAIGenerator(Generator):
"""Generator wrapper for OpenAI text2text models. Expects API key in the OPENAI_API_KEY environment variable"""
Expand Down Expand Up @@ -95,6 +118,9 @@ def __init__(self, name, generations=10):
): # handle model names -MMDDish suffix
self.generator = self.client.completions

if self.name in context_lengths:
self.context_len = context_lengths[self.name]

elif self.name == "":
openai_model_list = sorted([m.id for m in self.client.models.list().data])
raise ValueError(
Expand Down

0 comments on commit 9b1e475

Please sign in to comment.