paraphrase fast consistent model device (NVIDIA#898)
When running tests on a platform with a `cuda`-capable device, errors can be reported due to how the model is instantiated.

Recent changes supporting configuration of plugins and huggingface models, as used in this class, led to identifying further possible exceptions: when processing an `HFCompatible` class that does not define `hf_args`, and when passing options to huggingface model constructors that favor the newer `device_map` argument over `device`.
jmartin-tech authored Sep 9, 2024
2 parents ff4ee33 + 230e6bf commit ee22b5d
Showing 2 changed files with 35 additions and 12 deletions.
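As an illustration of the first failure mode described in the commit message, here is a minimal, hedged sketch (the class names are hypothetical, not from the repository): a mixin that reads `hf_args` unconditionally raises `AttributeError` on any host class that never defined it, which the guarded lookup added in this commit avoids.

    class HFParamsMixin:
        # hypothetical pre-fix shape: assumes every host class defines hf_args
        def gather_params_unguarded(self):
            return dict(self.hf_args)  # AttributeError when hf_args is absent

        # shape of the fix: fall back to an empty dict when hf_args is missing
        def gather_params_guarded(self):
            return dict(
                self.hf_args
                if hasattr(self, "hf_args") and isinstance(self.hf_args, dict)
                else {}
            )

    class NoArgs(HFParamsMixin):
        pass

    print(NoArgs().gather_params_guarded())  # {} rather than an AttributeError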
garak/buffs/paraphrase.py: 23 changes (15 additions & 8 deletions)
@@ -7,6 +7,7 @@
 
 import garak.attempt
 from garak import _config
+from garak.generators.huggingface import HFCompatible
 from garak.buffs.base import Buff
 
 
@@ -69,14 +70,17 @@ def transform(
             yield paraphrased_attempt
 
 
-class Fast(Buff):
+class Fast(Buff, HFCompatible):
     """CPU-friendly paraphrase buff based on Humarin's T5 paraphraser"""
 
+    DEFAULT_PARAMS = Buff.DEFAULT_PARAMS | {
+        "para_model_name": "garak-llm/chatgpt_paraphraser_on_T5_base",
+        "hf_args": {"device": "cpu", "torch_dtype": "float32"},
+    }
     bcp47 = "en"
     doc_uri = "https://huggingface.co/humarin/chatgpt_paraphraser_on_T5_base"
 
     def __init__(self, config_root=_config) -> None:
-        self.para_model_name = "garak-llm/chatgpt_paraphraser_on_T5_base"
         self.num_beams = 5
         self.num_beam_groups = 5
         self.num_return_sequences = 5
@@ -85,20 +89,23 @@ def __init__(self, config_root=_config) -> None:
         self.no_repeat_ngram_size = 2
         # self.temperature = 0.7
         self.max_length = 128
-        self.torch_device = None
+        self.device = None
         self.tokenizer = None
         self.para_model = None
         super().__init__(config_root=config_root)
 
     def _load_model(self):
-        import torch
         from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
-        self.torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.tokenizer = AutoTokenizer.from_pretrained(self.para_model_name)
+        self.device = self._select_hf_device()
+        model_kwargs = self._gather_hf_params(
+            hf_constructor=AutoModelForSeq2SeqLM.from_pretrained
+        )  # defers to device_map; a device_map of `auto` may not match self.device
 
         self.para_model = AutoModelForSeq2SeqLM.from_pretrained(
-            self.para_model_name
-        ).to(self.torch_device)
+            self.para_model_name, **model_kwargs
+        ).to(self.device)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.para_model_name)
 
     def _get_response(self, input_text):
         if self.para_model is None:
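The body of `_select_hf_device` is not shown in this diff. As a hedged sketch of the selection it centralizes, assuming it honors an explicitly configured `hf_args` device (the buff's new default is `cpu`) before probing hardware the way the deleted inline check did:

    import torch

    def select_device(requested):
        # honor an explicit request (e.g. the buff's default "cpu");
        # otherwise prefer cuda when a capable device is present
        if requested is not None:
            return requested
        return "cuda" if torch.cuda.is_available() else "cpu"

    print(select_device("cpu"))  # always "cpu", the documented default above
    print(select_device(None))   # "cuda" on a GPU host, else "cpu"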
garak/generators/huggingface.py: 24 changes (20 additions & 4 deletions)
@@ -53,10 +53,17 @@ def _set_hf_context_len(self, config):
             self.context_len = config.n_ctx
 
     def _gather_hf_params(self, hf_constructor: Callable):
-        # this may be a bit too naive as it will pass any parameter valid for the pipeline signature
-        # this falls over when passed `from_pretrained` methods as the callable model params are not explicit
-        params = self.hf_args
-        if params["device"] is None:
+        """Identify arguments that impact huggingface transformers resources and behavior"""
+
+        # this may be a bit too naive, as it will pass any parameter valid for the hf_constructor signature
+        # it falls over when passed some `from_pretrained` methods, as the callable model params are not always explicit
+        params = (
+            self.hf_args
+            if hasattr(self, "hf_args") and isinstance(self.hf_args, dict)
+            else {}
+        )
+        if "device" not in params and hasattr(self, "device"):
+            # consider setting self.device in all cases, or raising an error when `_select_hf_device` has not been called
             params["device"] = self.device
 
         args = {}
@@ -96,6 +103,15 @@ def _gather_hf_params(self, hf_constructor: Callable):
                 continue
             args[k] = params[k]
 
+        if (
+            "device_map" not in args
+            and "device_map" in params_to_process
+            and "device" in params_to_process
+            and "device" in args
+        ):
+            del args["device"]
+            args["device_map"] = self.device
+
         return args
 
     def _select_hf_device(self):
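To make the final hunk concrete, a self-contained sketch of the handoff it implements; the toy constructor below is hypothetical, standing in for the signature inspection that builds `params_to_process` in the real method:

    import inspect

    def resolve_device_args(hf_constructor, device):
        # keyword names the constructor explicitly accepts, per its signature
        accepted = set(inspect.signature(hf_constructor).parameters)
        args = {"device": device}
        # when the constructor understands both spellings, hand the value
        # to the newer device_map argument instead of device
        if "device_map" in accepted and "device" in accepted:
            del args["device"]
            args["device_map"] = device
        return args

    def fake_from_pretrained(name, device=None, device_map=None):  # hypothetical
        return name, device, device_map

    print(resolve_device_args(fake_from_pretrained, "cuda"))
    # {'device_map': 'cuda'}: device dropped in favor of device_map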
