diff --git a/garak/probes/atkgen.py b/garak/probes/atkgen.py index 0f09c7660..0fb2aa173 100644 --- a/garak/probes/atkgen.py +++ b/garak/probes/atkgen.py @@ -57,7 +57,7 @@ class Tox(Probe): "red_team_model_type": "huggingface.Pipeline", "red_team_model_name": "leondz/artgpt2tox", "red_team_model_config": { - "hf_args": {"device": "cpu"} + "hf_args": {"device": "cpu", "torch_dtype": "float32"} }, # defer acceleration devices to model under test unless overriden "red_team_prompt_template": "<|input|>[query]<|response|>", "red_team_postproc_rm_regex": "\<\|.*", diff --git a/tests/generators/test_huggingface.py b/tests/generators/test_huggingface.py index 904b08970..cce1d39ff 100644 --- a/tests/generators/test_huggingface.py +++ b/tests/generators/test_huggingface.py @@ -12,6 +12,7 @@ def hf_generator_config(): "huggingface": { "hf_args": { "device": "cpu", + "torch_dtype": "float32", }, } } diff --git a/tests/plugins/test_plugin_cache.py b/tests/plugins/test_plugin_cache.py index 5fc3ada15..50ca8ea75 100644 --- a/tests/plugins/test_plugin_cache.py +++ b/tests/plugins/test_plugin_cache.py @@ -11,6 +11,7 @@ def temp_cache_location(request) -> None: with tempfile.NamedTemporaryFile(buffering=0, delete=False) as tmp: PluginCache._user_plugin_cache_file = tmp.name PluginCache._plugin_cache_file = tmp.name + tmp.close() os.remove(tmp.name) # reset the class level singleton PluginCache._plugin_cache_dict = None diff --git a/tests/probes/test_probes_atkgen.py b/tests/probes/test_probes_atkgen.py index 150f00f37..b8c0b2986 100644 --- a/tests/probes/test_probes_atkgen.py +++ b/tests/probes/test_probes_atkgen.py @@ -27,9 +27,7 @@ def test_atkgen_config(): "generators": { rt_mod: { rt_klass: { - "hf_args": { - "device": "cpu", - }, + "hf_args": {"device": "cpu", "torch_dtype": "float32"}, "name": p.red_team_model_name, } }