diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 2cc80e202..459fb3b03 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -1,3 +1,4 @@ +import json import re import pytest import os @@ -57,8 +58,31 @@ def test_buff_list(capsys): def test_run_all_active_probes(capsys): + # force atkgen red_team generator to use `cpu` in CI/CD + probe_config = { + "atkgen": { + "red_team_model_config": { + "hf_args": { + "device": "cpu", + } + } + } + } + cli.main( - ["-m", "test", "-p", "all", "-d", "always.Pass", "-g", "1", "--narrow_output"] + [ + "-m", + "test", + "-p", + "all", + "-d", + "always.Pass", + "-g", + "1", + "--narrow_output", + "--probe_options", + json.dumps(probe_config), + ] ) result = capsys.readouterr() last_line = result.out.strip().split("\n")[-1] diff --git a/tests/generators/test_huggingface.py b/tests/generators/test_huggingface.py index 6f6d19ec1..904b08970 100644 --- a/tests/generators/test_huggingface.py +++ b/tests/generators/test_huggingface.py @@ -6,21 +6,22 @@ DEFAULT_GENERATIONS_QTY = 10 -def test_pipeline(): +@pytest.fixture +def hf_generator_config(): gen_config = { "huggingface": { - "Pipeline": { - "name": "gpt2", - "hf_args": { - "device": "cpu", - }, - } + "hf_args": { + "device": "cpu", + }, } } config_root = GarakSubConfig() setattr(config_root, "generators", gen_config) + return config_root - g = garak.generators.huggingface.Pipeline("gpt2", config_root=config_root) + +def test_pipeline(hf_generator_config): + g = garak.generators.huggingface.Pipeline("gpt2", config_root=hf_generator_config) assert g.name == "gpt2" assert g.generations == DEFAULT_GENERATIONS_QTY assert isinstance(g.generator, transformers.pipelines.text_generation.Pipeline) @@ -53,8 +54,8 @@ def test_inference(): assert isinstance(item, str) -def test_model(): - g = garak.generators.huggingface.Model("gpt2") +def test_model(hf_generator_config): + g = garak.generators.huggingface.Model("gpt2", config_root=hf_generator_config) assert g.name == "gpt2" assert g.generations == DEFAULT_GENERATIONS_QTY assert isinstance(g, garak.generators.huggingface.Model)