Skip to content

Commit

Permalink
generators: add option to specify HTTP codes to skip generation on, f…
Browse files Browse the repository at this point in the history
…or `RestGenerator` (#999)

Some prompts are difficult for some endpoints to respond to. This patch
offers a way for garak to skip those prompts, if the generation failure
is expressed in an endpoint HTTP response code (e.g. `400`)

## Verification
- [ ] Run the tests and ensure they pass `python -m pytest tests/`
- [ ] Run a restgenerator against an endpoint that returns 400; there
should be a `None` returned by the generator and the probe should
present as SKIP on the CLI
- [ ] Docs added in PR

## Context
Targets issue raised in
https://discord.com/channels/1121536128269422654/1121536129099907190/1305844548181692416
  • Loading branch information
leondz authored Nov 14, 2024
2 parents fa2d57c + effa943 commit 96a099d
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/source/garak.generators.rest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Uses the following options from ``_config.plugins.generators["rest.RestGenerator
* ``response_json_field`` - (optional) Which field of the response JSON should be used as the output string? Default ``text``. Can also be a JSONPath value, and ``response_json_field`` is used as such if it starts with ``$``.
* ``request_timeout`` - How many seconds should we wait before timing out? Default 20
* ``ratelimit_codes`` - Which endpoint HTTP response codes should be caught as indicative of rate limiting and retried? ``List[int]``, default ``[429]``
* ``skip_codes`` - Which endpoint HTTP response code should lead to the generation being treated as not possible and skipped for this query. Takes precedence over ``ratelimit_codes``.

Templates can be either a string or a JSON-serialisable Python object.
Instance of ``$INPUT`` here are replaced with the prompt; instances of ``$KEY``
Expand Down
29 changes: 22 additions & 7 deletions garak/generators/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class RestGenerator(Generator):
"headers": {},
"method": "post",
"ratelimit_codes": [429],
"skip_codes": [],
"response_json": False,
"response_json_field": None,
"req_template": "$INPUT",
Expand All @@ -55,6 +56,7 @@ class RestGenerator(Generator):
"req_template_json_object",
"request_timeout",
"ratelimit_codes",
"skip_codes",
"temperature",
"top_k",
)
Expand Down Expand Up @@ -121,7 +123,7 @@ def __init__(self, uri=None, config_root=_config):
try:
self.json_expr = jsonpath_ng.parse(self.response_json_field)
except JsonPathParserError as e:
logging.CRITICAL(
logging.critical(
"Couldn't parse response_json_field %s", self.response_json_field
)
raise e
Expand Down Expand Up @@ -193,31 +195,44 @@ def _call_model(
"timeout": self.request_timeout,
}
resp = self.http_function(self.uri, **req_kArgs)

if resp.status_code in self.skip_codes:
logging.debug(
"REST skip prompt: %s - %s, uri: %s",
resp.status_code,
resp.reason,
self.uri,
)
return [None]

if resp.status_code in self.ratelimit_codes:
raise RateLimitHit(f"Rate limited: {resp.status_code} - {resp.reason}, uri: {self.uri}")
raise RateLimitHit(
f"Rate limited: {resp.status_code} - {resp.reason}, uri: {self.uri}"
)

elif str(resp.status_code)[0] == "3":
if str(resp.status_code)[0] == "3":
raise NotImplementedError(
f"REST URI redirection: {resp.status_code} - {resp.reason}, uri: {self.uri}"
)

elif str(resp.status_code)[0] == "4":
if str(resp.status_code)[0] == "4":
raise ConnectionError(
f"REST URI client error: {resp.status_code} - {resp.reason}, uri: {self.uri}"
)

elif str(resp.status_code)[0] == "5":
if str(resp.status_code)[0] == "5":
error_msg = f"REST URI server error: {resp.status_code} - {resp.reason}, uri: {self.uri}"
if self.retry_5xx:
raise IOError(error_msg)
else:
raise ConnectionError(error_msg)
raise ConnectionError(error_msg)

if not self.response_json:
return [str(resp.text)]

response_object = json.loads(resp.content)

response = [None]

# if response_json_field starts with a $, treat is as a JSONPath
assert (
self.response_json
Expand Down
28 changes: 27 additions & 1 deletion tests/generators/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import requests_mock
from sympy import is_increasing

from garak import _config
from garak import _config, _plugins

from garak.generators.rest import RestGenerator

Expand Down Expand Up @@ -95,3 +95,29 @@ def test_json_rest_deeper(requests_mock):
generator = RestGenerator()
output = generator._call_model("Who is Enabran Tain's son?")
assert output == [DEFAULT_TEXT_RESPONSE]


@pytest.mark.usefixtures("set_rest_config")
def test_rest_skip_code(requests_mock):
generator = _plugins.load_plugin(
"generators.rest.RestGenerator", config_root=_config
)
generator.skip_codes = [200]
requests_mock.post(
DEFAULT_URI,
text=json.dumps(
{
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": DEFAULT_TEXT_RESPONSE,
},
}
]
}
),
)
output = generator._call_model("Who is Enabran Tain's son?")
assert output == [None]

0 comments on commit 96a099d

Please sign in to comment.