From 9b1e4756d860553aa28672e443d2039cb01f53ec Mon Sep 17 00:00:00 2001
From: Leon Derczynski
Date: Wed, 24 Apr 2024 10:05:02 +0200
Subject: [PATCH] add var for generator context_len and populate this for some
 generators (#616)

* add context_len generator base attrib, and implement setting in some HF generators

* add generator context lengths for openai models

* factor out hf context len extraction to mixin
---
 garak/generators/base.py        |  1 +
 garak/generators/huggingface.py | 28 ++++++++++++++++++++++------
 garak/generators/openai.py      | 26 ++++++++++++++++++++++++++
 3 files changed, 49 insertions(+), 6 deletions(-)
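
A brief usage note (this sits above the first diff marker, so git am will not
apply it): a minimal sketch of how a caller might consult the new context_len
attribute. The truncate_prompt helper and its one-word-per-token estimate are
illustrative assumptions, not garak API:

    from garak.generators.base import Generator

    def truncate_prompt(generator: Generator, prompt: str) -> str:
        """Trim a prompt to the generator's context window, when known."""
        if generator.context_len is None:
            # window unknown: neither the HF config's n_ctx nor the OpenAI
            # context_lengths table supplied a value, so pass through as-is
            return prompt
        words = prompt.split()  # crude estimate: one token per word
        return " ".join(words[: generator.context_len])

Since context_len defaults to None on the base Generator, callers must handle
the unset case, as above.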
diff --git a/garak/generators/base.py b/garak/generators/base.py
index 990415d9e..a69fa2d97 100644
--- a/garak/generators/base.py
+++ b/garak/generators/base.py
@@ -23,6 +23,7 @@ class Generator:
     top_k = None
     active = True
     generator_family_name = None
+    context_len = None
 
     supports_multiple_generations = (
         False  # can more than one generation be extracted per request?
diff --git a/garak/generators/huggingface.py b/garak/generators/huggingface.py
index 6d5ca5b1f..76087b88c 100644
--- a/garak/generators/huggingface.py
+++ b/garak/generators/huggingface.py
@@ -43,12 +43,19 @@
 class HFInternalServerError(Exception):
     pass
 
 
-class Pipeline(Generator):
+class HFCompatible:
+    def _set_hf_context_len(self, config):
+        if hasattr(config, "n_ctx"):
+            if isinstance(config.n_ctx, int):
+                self.context_len = config.n_ctx
+
+
+class Pipeline(Generator, HFCompatible):
     """Get text generations from a locally-run Hugging Face pipeline"""
 
     generator_family_name = "Hugging Face 🤗 pipeline"
     supports_multiple_generations = True
 
     def __init__(self, name, do_sample=True, generations=10, device=0):
         self.fullname, self.name = name, name.split("/")[-1]
@@ -76,6 +83,8 @@ def __init__(self, name, do_sample=True, generations=10, device=0):
         if _config.run.deprefix is True:
             self.deprefix_prompt = True
 
+        self._set_hf_context_len(self.generator.model.config)
+
     def _call_model(self, prompt: str) -> List[str]:
         with warnings.catch_warnings():
             warnings.simplefilter("ignore", category=UserWarning)
@@ -109,7 +118,7 @@ def _call_model(self, prompt: str) -> List[str]:
         return [re.sub("^" + re.escape(prompt), "", i) for i in generations]
 
 
-class OptimumPipeline(Pipeline):
+class OptimumPipeline(Pipeline, HFCompatible):
     """Get text generations from a locally-run Hugging Face pipeline using NVIDIA Optimum"""
 
     generator_family_name = "NVIDIA Optimum Hugging Face 🤗 pipeline"
@@ -151,8 +160,10 @@ def __init__(self, name, do_sample=True, generations=10, device=0):
         if _config.run.deprefix is True:
             self.deprefix_prompt = True
 
+        self._set_hf_context_len(self.generator.model.config)
+
 
-class ConversationalPipeline(Generator):
+class ConversationalPipeline(Generator, HFCompatible):
     """Conversational text generation using HuggingFace pipelines"""
 
     generator_family_name = "Hugging Face 🤗 pipeline for conversations"
@@ -188,6 +199,8 @@ def __init__(self, name, do_sample=True, generations=10, device=0):
         if _config.run.deprefix is True:
             self.deprefix_prompt = True
 
+        self._set_hf_context_len(self.generator.model.config)
+
     def clear_history(self):
         from transformers import Conversation
 
@@ -222,7 +235,7 @@ def _call_model(self, prompt: Union[str, list[dict]]) -> List[str]:
         return [re.sub("^" + re.escape(prompt), "", i) for i in generations]
 
 
-class InferenceAPI(Generator):
+class InferenceAPI(Generator, HFCompatible):
     """Get text generations from Hugging Face Inference API"""
 
     generator_family_name = "Hugging Face 🤗 Inference API"
@@ -340,7 +353,7 @@ def _pre_generate_hook(self):
         self.wait_for_model = False
 
-class InferenceEndpoint(InferenceAPI):
+class InferenceEndpoint(InferenceAPI, HFCompatible):
     """Interface for Hugging Face private endpoints
 
     Pass the model URL as the name, e.g. https://xxx.aws.endpoints.huggingface.cloud
     """
@@ -393,7 +406,7 @@ def _call_model(self, prompt: str) -> List[str]:
         return output
 
 
-class Model(Generator):
+class Model(Generator, HFCompatible):
     """Get text generations from a locally-run Hugging Face model"""
 
     generator_family_name = "Hugging Face 🤗 model"
@@ -427,10 +440,13 @@ def __init__(self, name, do_sample=True, generations=10, device=0):
             self.init_device  # or "cuda:0" For fast initialization directly on GPU!
         )
 
+        self._set_hf_context_len(self.config)
+
         self.model = transformers.AutoModelForCausalLM.from_pretrained(
             self.fullname,
             config=self.config,
         ).to(self.init_device)
+
         self.deprefix_prompt = name in models_to_deprefix
 
         if self.config.tokenizer_class:
diff --git a/garak/generators/openai.py b/garak/generators/openai.py
index d3eb432b0..013635586 100644
--- a/garak/generators/openai.py
+++ b/garak/generators/openai.py
@@ -58,6 +58,29 @@
     # "ada", # shutdown https://platform.openai.com/docs/deprecations
 )
 
+context_lengths = {
+    "gpt-3.5-turbo-0125": 16385,
+    "gpt-3.5-turbo": 16385,
+    "gpt-3.5-turbo-1106": 16385,
+    "gpt-3.5-turbo-instruct": 4096,
+    "gpt-3.5-turbo-16k": 16385,
+    "gpt-3.5-turbo-0613": 4096,
+    "gpt-3.5-turbo-16k-0613": 16385,
+    "babbage-002": 16384,
+    "davinci-002": 16384,
+    "gpt-4-turbo": 128000,
+    "gpt-4-turbo-2024-04-09": 128000,
+    "gpt-4-turbo-preview": 128000,
+    "gpt-4-0125-preview": 128000,
+    "gpt-4-1106-preview": 128000,
+    "gpt-4-vision-preview": 128000,
+    "gpt-4-1106-vision-preview": 128000,
+    "gpt-4": 8192,
+    "gpt-4-0613": 8192,
+    "gpt-4-32k": 32768,
+    "gpt-4-32k-0613": 32768,
+}
+
 
 class OpenAIGenerator(Generator):
     """Generator wrapper for OpenAI text2text models. Expects API key in the OPENAI_API_KEY environment variable"""
@@ -95,6 +118,9 @@ def __init__(self, name, generations=10):
         ):  # handle model names -MMDDish suffix
             self.generator = self.client.chat.completions
 
+        if self.name in context_lengths:
+            self.context_len = context_lengths[self.name]
+
         if self.name == "":
             openai_model_list = sorted([m.id for m in self.client.models.list().data])
             raise ValueError(