Skip to content

Commit

Permalink
update workflow to treat generators more defensively, casting to list…
Browse files Browse the repository at this point in the history
… if there's a risk of multiple consumption
  • Loading branch information
leondz committed Jul 5, 2024
1 parent c3ad525 commit 3b59576
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 10 deletions.
7 changes: 5 additions & 2 deletions garak/attempt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Defines the Attempt class, which encapsulates a prompt with metadata and results"""

from collections.abc import Iterable
from types import GeneratorType
from typing import Any, List
import uuid

Expand Down Expand Up @@ -179,8 +181,9 @@ def __setattr__(self, name: str, value: Any) -> None:
self._add_first_turn("user", value)

elif name == "outputs":
if not isinstance(value, list):
raise TypeError("Value for attempt.outputs must be a list")
if not (isinstance(value, list) or isinstance(value, GeneratorType)):
raise TypeError("Value for attempt.outputs must be a list or generator")
value = list(value)
if len(self.messages) == 0:
raise TypeError("A prompt must be set before outputs are given")
# do we have only the initial prompt? in which case, let's flesh out messages a bit
Expand Down
10 changes: 7 additions & 3 deletions garak/evaluators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import json
import logging
from typing import List
from typing import Iterable

from colorama import Fore, Style

Expand Down Expand Up @@ -33,19 +33,23 @@ def test(self, test_value: float) -> bool:
"""
return False # fail everything by default

def evaluate(self, attempts: List[garak.attempt.Attempt]) -> None:
def evaluate(self, attempts: Iterable[garak.attempt.Attempt]) -> None:
"""
evaluate feedback from detectors
expects a list of attempts that correspond to one probe
outputs results once per detector
"""

if len(attempts) == 0:
if isinstance(attempts, list) and len(attempts) == 0:
logging.debug(
"evaluators.base.Evaluator.evaluate called with 0 attempts, expected 1+"
)
return

attempts = list(
attempts
) # disprefer this but getting detector_names from first one for the loop below is a pain

self.probename = attempts[0].probe_classname
detector_names = attempts[0].detector_results.keys()

Expand Down
4 changes: 3 additions & 1 deletion garak/harnesses/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ def run(self, model, probes, detectors, evaluator, announce_probe=True) -> None:
detector_probe_name = d.detectorname.replace("garak.detectors.", "")
attempt_iterator.set_description("detectors." + detector_probe_name)
for attempt in attempt_iterator:
attempt.detector_results[detector_probe_name] = d.detect(attempt)
attempt.detector_results[detector_probe_name] = list(
d.detect(attempt)
)

if first_detector:
eval_outputs += attempt.outputs
Expand Down
17 changes: 13 additions & 4 deletions tests/test_internal_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,16 @@
import garak._config
import garak._plugins
import garak.attempt
import garak.evaluators.base
import garak.generators.test

# probes should be able to return a generator of attempts
# -> probes.base.Probe._execute_all (1) should be able to consume a generator of attempts
# generators should be able to return a generator of outputs
# -> attempts (2) should be able to consume a generator of outputs
# -> detectors (3) should be able to consume a generator of outputs
# detectors should be able to return generators of results
# -> evaluators (4) should be able to consume generators of results
# -> attempts (5) should be able to consume generators of detector results
# -> attempt.as_dict (6) should be able to relay a generator of detector results
# -> evaluators (3) should be able to consume generators of results --> enforced in harness; cast to list, multiple consumption



@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -52,3 +51,13 @@ def test_generator_consume_attempt_generator():
assert (
result_len == count
), "there should be the same number of attempts in the passed generator as results returned in _execute_all"

def test_attempt_outputs_can_consume_generator():
    """Check that Attempt.outputs can be populated from a generator.

    Assigns a generator expression to ``outputs`` and reads the attribute
    back twice, asserting that the first read has as many items as the
    generator produced and that the second read has the same length as the
    first.
    """
    expected_count = 5
    attempt = garak.attempt.Attempt(prompt="fish")
    # populate outputs from a generator expression rather than a list
    attempt.outputs = ("abc" for _ in range(expected_count))

    first_read = list(attempt.outputs)
    assert (
        len(first_read) == expected_count
    ), "attempt.outputs should have same cardinality as generator used to populate it"

    # read the attribute a second time; the length must not change
    second_read = list(attempt.outputs)
    assert len(second_read) == len(
        first_read
    ), "attempt.outputs should have the same cardinality every time"

0 comments on commit 3b59576

Please sign in to comment.