diff --git a/docs/source/configurable.rst b/docs/source/configurable.rst index 69294ef62..35ed8fba4 100644 --- a/docs/source/configurable.rst +++ b/docs/source/configurable.rst @@ -101,6 +101,7 @@ such as ``show_100_pass_modules``. * ``deprefix`` - Remove the prompt from the start of the output (some models return the prompt as part of their output) * ``seed`` - An optional random seed * ``eval_threshold`` - At what point in the 0..1 range output by detectors does a result count as a successful attack / hit +* ``user_agent`` - What HTTP user agent string should garak use? ``{version}`` can be used to signify where garak version ID should go ``plugins`` config items """""""""""""""""""""""" diff --git a/garak/_config.py b/garak/_config.py index f420d5484..7e3480ae3 100644 --- a/garak/_config.py +++ b/garak/_config.py @@ -23,7 +23,7 @@ DICT_CONFIG_AFTER_LOAD = False -version = -1 # eh why this is here? hm. who references it +from garak import __version__ as version system_params = ( "verbose narrow_output parallel_requests parallel_attempts skip_unknown".split() @@ -144,14 +144,66 @@ def _load_yaml_config(settings_filenames) -> dict: def _store_config(settings_files) -> None: - global system, run, plugins, reporting + global system, run, plugins, reporting, version settings = _load_yaml_config(settings_files) system = _set_settings(system, settings["system"]) run = _set_settings(run, settings["run"]) + run.user_agent = run.user_agent.replace("{version}", version) plugins = _set_settings(plugins, settings["plugins"]) reporting = _set_settings(reporting, settings["reporting"]) +# not my favourite solution in this module, but if +# _config.set_http_lib_agents() to be predicated on a param instead of +# a _config.run value (i.e. user_agent) - which it needs to be if it can be +# used when the values are popped back to originals - then a separate way +# of passing the UA string to _garak_user_agent() needs to exist, outside of +# _config.run.user_agent +REQUESTS_AGENT = "" + + +def _garak_user_agent(dummy=None): + return str(REQUESTS_AGENT) + + +def set_all_http_lib_agents(agent_string): + set_http_lib_agents( + {"requests": agent_string, "httpx": agent_string, "aiohttp": agent_string} + ) + + +def set_http_lib_agents(agent_strings: dict): + + global REQUESTS_AGENT + + if "requests" in agent_strings: + from requests import utils + + REQUESTS_AGENT = agent_strings["requests"] + utils.default_user_agent = _garak_user_agent + if "httpx" in agent_strings: + import httpx + + httpx._client.USER_AGENT = agent_strings["httpx"] + if "aiohttp" in agent_strings: + import aiohttp + + aiohttp.client_reqrep.SERVER_SOFTWARE = agent_strings["aiohttp"] + + +def get_http_lib_agents(): + from requests import utils + import httpx + import aiohttp + + agent_strings = {} + agent_strings["requests"] = utils.default_user_agent + agent_strings["httpx"] = httpx._client.USER_AGENT + agent_strings["aiohttp"] = aiohttp.client_reqrep.SERVER_SOFTWARE + + return agent_strings + + def load_base_config() -> None: global loaded settings_files = [str(transient.package_dir / "resources" / "garak.core.yaml")] @@ -193,6 +245,7 @@ def load_config( logging.debug("Loading configs from: %s", ",".join(settings_files)) _store_config(settings_files=settings_files) + if DICT_CONFIG_AFTER_LOAD: _lock_config_as_dict() loaded = True diff --git a/garak/cli.py b/garak/cli.py index 5894ce0dd..02d0b9833 100644 --- a/garak/cli.py +++ b/garak/cli.py @@ -10,13 +10,12 @@ def main(arguments=None) -> None: """Main entry point for garak runs invoked from the CLI""" import datetime - from garak import __version__, __description__ + from garak import __description__ from garak import _config from garak.exception import GarakException _config.transient.starttime = datetime.datetime.now() _config.transient.starttime_iso = _config.transient.starttime.isoformat() - _config.version = __version__ if arguments is None: arguments = [] @@ -255,6 +254,7 @@ def main(arguments=None) -> None: # load site config before loading CLI config _cli_config_supplied = args.config is not None + prior_user_agents = _config.get_http_lib_agents() _config.load_config(run_config_filename=args.config) # extract what was actually passed on CLI; use a masking argparser @@ -554,3 +554,5 @@ def main(arguments=None) -> None: except (ValueError, GarakException) as e: logging.exception(e) print(e) + + _config.set_http_lib_agents(prior_user_agents) diff --git a/garak/harnesses/base.py b/garak/harnesses/base.py index 79e9c63a3..898684bf6 100644 --- a/garak/harnesses/base.py +++ b/garak/harnesses/base.py @@ -64,6 +64,13 @@ def _load_buffs(self, buff_names: List) -> None: logging.warning(err_msg) continue + def _start_run_hook(self): + self._http_lib_user_agents = _config.get_http_lib_agents() + _config.set_all_http_lib_agents(_config.run.user_agent) + + def _end_run_hook(self): + _config.set_http_lib_agents(self._http_lib_user_agents) + def run(self, model, probes, detectors, evaluator, announce_probe=True) -> None: """Core harness method @@ -92,6 +99,8 @@ def run(self, model, probes, detectors, evaluator, announce_probe=True) -> None: print(msg) raise ValueError(msg) + self._start_run_hook() + for probe in probes: logging.debug("harness: probe start for %s", probe.probename) if not probe: @@ -135,4 +144,6 @@ def run(self, model, probes, detectors, evaluator, announce_probe=True) -> None: else: evaluator.evaluate(attempt_results) + self._end_run_hook() + logging.debug("harness: probe list iteration completed") diff --git a/garak/resources/garak.core.yaml b/garak/resources/garak.core.yaml index 98a1fa2e9..a3d948935 100644 --- a/garak/resources/garak.core.yaml +++ b/garak/resources/garak.core.yaml @@ -13,6 +13,7 @@ run: eval_threshold: 0.5 generations: 5 probe_tags: + user_agent: "garak/{version} (LLM vulnerability scanner https://garak.ai)" plugins: model_type: diff --git a/pyproject.toml b/pyproject.toml index 681ffc72b..e7d0ca3dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,7 +86,8 @@ tests = [ "pytest>=8.0", "requests-mock==1.12.1", "respx>=0.21.1", - "pytest-cov>=5.0.0" + "pytest-cov>=5.0.0", + "pytest_httpserver>=1.1.0" ] lint = [ "black==24.4.2", diff --git a/requirements.txt b/requirements.txt index 8eb5a3ee0..8b30110a1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -41,6 +41,7 @@ pytest>=8.0 requests-mock==1.12.1 respx>=0.21.1 pytest-cov>=5.0.0 +pytest_httpserver>=1.1.0 # lint black==24.4.2 pylint>=3.1.0 diff --git a/tests/generators/test_rest.py b/tests/generators/test_rest.py index 932473ba8..1fcf67175 100644 --- a/tests/generators/test_rest.py +++ b/tests/generators/test_rest.py @@ -14,6 +14,7 @@ @pytest.fixture def set_rest_config(): + _config.run.user_agent = "test user agent, garak.ai" _config.plugins.generators["rest"] = {} _config.plugins.generators["rest"]["RestGenerator"] = { "name": DEFAULT_NAME, diff --git a/tests/test_config.py b/tests/test_config.py index 3892e6774..60d675562 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -4,14 +4,15 @@ import importlib import json import os +from pathlib import Path +import pytest import re import shutil import sys import tempfile -import pytest +from pytest_httpserver import HTTPServer -from pathlib import Path from garak import _config import garak.cli @@ -764,3 +765,63 @@ def test_nested(): _config.plugins.generators["a"]["b"]["c"]["d"] = "e" assert _config.plugins.generators["a"]["b"]["c"]["d"] == "e" + + +def test_get_user_agents(): + agents = _config.get_http_lib_agents() + assert isinstance(agents, dict) + + +AGENT_TEST = "garak/9 - only simple tailors edition" + + +def test_set_agents(): + from requests import utils + import httpx + import aiohttp + + _config.set_all_http_lib_agents(AGENT_TEST) + + assert str(utils.default_user_agent()) == AGENT_TEST + assert httpx._client.USER_AGENT == AGENT_TEST + assert aiohttp.client_reqrep.SERVER_SOFTWARE == AGENT_TEST + +def httpserver(): + return HTTPServer() + + +def test_agent_is_used_requests(httpserver: HTTPServer): + import requests + + _config.set_http_lib_agents({"requests": AGENT_TEST}) + httpserver.expect_request( + "/", headers={"User-Agent": AGENT_TEST} + ).respond_with_data("") + assert requests.get(httpserver.url_for("/")).status_code == 200 + + +def test_agent_is_used_httpx(httpserver: HTTPServer): + import httpx + + _config.set_http_lib_agents({"httpx": AGENT_TEST}) + httpserver.expect_request( + "/", headers={"User-Agent": AGENT_TEST} + ).respond_with_data("") + assert httpx.get(httpserver.url_for("/")).status_code == 200 + + +def test_agent_is_used_aiohttp(httpserver: HTTPServer): + import aiohttp + import asyncio + + _config.set_http_lib_agents({"aiohttp": AGENT_TEST}) + + async def main(): + async with aiohttp.ClientSession() as session: + async with session.get(httpserver.url_for("/")) as response: + html = await response.text() + + httpserver.expect_request( + "/", headers={"User-Agent": AGENT_TEST} + ).respond_with_data("") + asyncio.run(main())