Skip to content

Commit

Permalink
force huggingface to use cpu in test run
Browse files Browse the repository at this point in the history
* cli test with atkgen red_team forced to `cpu`
* huggingface tests force to `cpu`

Signed-off-by: Jeffrey Martin <[email protected]>
  • Loading branch information
jmartin-tech committed Jul 18, 2024
1 parent 1ea2aeb commit 5d62309
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 11 deletions.
26 changes: 25 additions & 1 deletion tests/cli/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import re
import pytest
import os
Expand Down Expand Up @@ -57,8 +58,31 @@ def test_buff_list(capsys):


def test_run_all_active_probes(capsys):
# force atkgen red_team generator to use `cpu` in CI/CD
probe_config = {
"atkgen": {
"red_team_model_config": {
"hf_args": {
"device": "cpu",
}
}
}
}

cli.main(
["-m", "test", "-p", "all", "-d", "always.Pass", "-g", "1", "--narrow_output"]
[
"-m",
"test",
"-p",
"all",
"-d",
"always.Pass",
"-g",
"1",
"--narrow_output",
"--probe_options",
json.dumps(probe_config),
]
)
result = capsys.readouterr()
last_line = result.out.strip().split("\n")[-1]
Expand Down
21 changes: 11 additions & 10 deletions tests/generators/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,22 @@
DEFAULT_GENERATIONS_QTY = 10


def test_pipeline():
@pytest.fixture
def hf_generator_config():
gen_config = {
"huggingface": {
"Pipeline": {
"name": "gpt2",
"hf_args": {
"device": "cpu",
},
}
"hf_args": {
"device": "cpu",
},
}
}
config_root = GarakSubConfig()
setattr(config_root, "generators", gen_config)
return config_root

g = garak.generators.huggingface.Pipeline("gpt2", config_root=config_root)

def test_pipeline(hf_generator_config):
g = garak.generators.huggingface.Pipeline("gpt2", config_root=hf_generator_config)
assert g.name == "gpt2"
assert g.generations == DEFAULT_GENERATIONS_QTY
assert isinstance(g.generator, transformers.pipelines.text_generation.Pipeline)
Expand Down Expand Up @@ -53,8 +54,8 @@ def test_inference():
assert isinstance(item, str)


def test_model():
g = garak.generators.huggingface.Model("gpt2")
def test_model(hf_generator_config):
g = garak.generators.huggingface.Model("gpt2", config_root=hf_generator_config)
assert g.name == "gpt2"
assert g.generations == DEFAULT_GENERATIONS_QTY
assert isinstance(g, garak.generators.huggingface.Model)
Expand Down

0 comments on commit 5d62309

Please sign in to comment.