diff --git a/garak/_plugins.py b/garak/_plugins.py index 5cca55033..e805e8745 100644 --- a/garak/_plugins.py +++ b/garak/_plugins.py @@ -9,6 +9,8 @@ import logging import shutil import os +from datetime import datetime, timezone +from threading import Lock from typing import List, Callable, Union from pathlib import Path @@ -43,6 +45,8 @@ class PluginCache: ) _plugin_cache_dict = None + _mutex = Lock() + def __init__(self) -> None: if PluginCache._plugin_cache_dict is None: PluginCache._plugin_cache_dict = self._load_plugin_cache() @@ -56,18 +60,82 @@ def _extract_modules_klasses(base_klass): ] def _load_plugin_cache(self): - if not os.path.exists(self._plugin_cache_filename): - self._build_plugin_cache() + assert os.path.exists( + self._plugin_cache_filename + ), f"{self._plugin_cache_filename} is missing or corrupt." + + base_time = datetime.fromtimestamp( + os.path.getmtime(self._plugin_cache_filename), tz=timezone.utc + ) + update_user_file = False if not os.path.exists(self._user_plugin_cache_filename): + update_user_file = True + else: + user_time = datetime.fromtimestamp( + os.path.getmtime(self._user_plugin_cache_filename), tz=timezone.utc + ) + if base_time > user_time: + update_user_file = True + if update_user_file: self._user_plugin_cache_filename.parent.mkdir( mode=0o740, parents=True, exist_ok=True ) shutil.copy2(self._plugin_cache_filename, self._user_plugin_cache_filename) + user_time = base_time + with open( self._user_plugin_cache_filename, "r", encoding="utf-8" ) as cache_file: local_cache = json.load(cache_file) - return local_cache + + # validate cache state on startup + if not self._valid_loaded_cache(local_cache, user_time): + self._build_plugin_cache() + with open( + self._user_plugin_cache_filename, "r", encoding="utf-8" + ) as cache_file: + local_cache = json.load(cache_file) + + return local_cache + + def _valid_loaded_cache(self, local_cache, user_time): + # verify all known plugins + for plugin_type in PLUGIN_TYPES: + validated_plugin_filenames = set() + prev_mod = None + for k in local_cache[plugin_type]: + category, module, klass = k.split(".") + if module != prev_mod: + prev_mod = module + # TODO: only one possible location for now, expand when `user` plugins are supported + file_path = ( + _config.transient.package_dir / category / f"{module}.py" + ) + if os.path.exists(file_path): + validated_plugin_filenames.add(file_path) + mod_time = datetime.fromtimestamp( + os.path.getmtime(file_path), tz=timezone.utc + ) + else: + # file not found force cache rebuild + mod_time = datetime.now(timezone.utc) + if mod_time > user_time: + # rebuild all for now, this can be more made more selective later + return False + + # if all known are up-to-date check filesystem for missing + found_filenames = set() + plugin_path = _config.transient.package_dir / category + for module_filename in sorted(os.listdir(plugin_path)): + if not module_filename.endswith(".py"): + continue + if module_filename.startswith("__"): + continue + found_filenames.add(plugin_path / module_filename) + if found_filenames != validated_plugin_filenames: + return False + + return True # all checks passed the cache is valid def _build_plugin_cache(self): """build a plugin cache file to improve access times @@ -75,23 +143,25 @@ def _build_plugin_cache(self): This method writes only to the user's cache (currently the same as the system cache) TODO: Enhance location of user cache to enable support for in development plugins. """ - local_cache = {} - - for plugin_type in PLUGIN_TYPES: - plugin_dict = {} - for plugin in self._enumerate_plugin_klasses(plugin_type): - plugin_name = ".".join([plugin.__module__, plugin.__name__]).replace( - "garak.", "" - ) - plugin_dict[plugin_name] = PluginCache.plugin_info(plugin) - - sorted_keys = sorted(list(plugin_dict.keys())) - local_cache[plugin_type] = {i: plugin_dict[i] for i in sorted_keys} - - with open( - self._user_plugin_cache_filename, "w", encoding="utf-8" - ) as cache_file: - json.dump(local_cache, cache_file, cls=PluginEncoder, indent=2) + logging.debug("rebuilding plugin cache") + with PluginCache._mutex: + local_cache = {} + + for plugin_type in PLUGIN_TYPES: + plugin_dict = {} + for plugin in self._enumerate_plugin_klasses(plugin_type): + plugin_name = ".".join( + [plugin.__module__, plugin.__name__] + ).replace("garak.", "") + plugin_dict[plugin_name] = PluginCache.plugin_info(plugin) + + sorted_keys = sorted(list(plugin_dict.keys())) + local_cache[plugin_type] = {i: plugin_dict[i] for i in sorted_keys} + + with open( + self._user_plugin_cache_filename, "w", encoding="utf-8" + ) as cache_file: + json.dump(local_cache, cache_file, cls=PluginEncoder, indent=2) def _enumerate_plugin_klasses(self, category: str) -> List[Callable]: """obtain all""" @@ -136,7 +206,10 @@ def plugin_info(plugin: Union[Callable, str]) -> dict: if category not in PLUGIN_TYPES: raise ValueError(f"Not a recognised plugin type: {category}") - plugin_metadata = PluginCache.instance()[category].get(plugin_name, {}) + cache = PluginCache.instance() + with PluginCache._mutex: + plugin_metadata = cache[category].get(plugin_name, {}) + if len(plugin_metadata) > 0: return plugin_metadata else: @@ -213,8 +286,6 @@ def plugin_info(plugin: Union[Callable, str]) -> dict: logging.error(f"Plugin {plugin_name} not found.") logging.exception(e) - from datetime import datetime, timezone - # adding last class modification time to cache allows for targeted update in future current_mod = importlib.import_module(plugin.__module__) mod_time = datetime.fromtimestamp( diff --git a/tests/plugins/test_plugin_cache.py b/tests/plugins/test_plugin_cache.py index 7124c58f9..2e3baeb7d 100644 --- a/tests/plugins/test_plugin_cache.py +++ b/tests/plugins/test_plugin_cache.py @@ -1,8 +1,11 @@ +import json import pytest import os +import sys import tempfile -from garak._plugins import PluginCache +from pathlib import Path +from garak._plugins import PluginCache, PluginEncoder @pytest.fixture(autouse=True) @@ -17,8 +20,7 @@ def reset_plugin_cache(): def temp_cache_location(request) -> None: # override the cache file with a tmp location with tempfile.NamedTemporaryFile(buffering=0, delete=False) as tmp: - PluginCache._user_plugin_cache_filename = tmp.name - PluginCache._plugin_cache_filename = tmp.name + PluginCache._user_plugin_cache_filename = Path(tmp.name) tmp.close() os.remove(tmp.name) # reset the class level singleton @@ -33,6 +35,25 @@ def remove_cache_file(): return tmp.name +@pytest.fixture +def remove_package_cache(request) -> None: + # override the cache file with a tmp location + original_path = PluginCache._plugin_cache_filename + with tempfile.NamedTemporaryFile(buffering=0, delete=False) as tmp: + PluginCache._plugin_cache_filename = Path(tmp.name) + tmp.close() + os.remove(tmp.name) + # reset the class level singleton + PluginCache._plugin_cache_dict = None + + def restore_package_path(): + PluginCache._plugin_cache_filename = original_path + + request.addfinalizer(restore_package_path) + + return tmp.name + + def test_create(temp_cache_location): cache = PluginCache.instance() assert os.path.isfile(temp_cache_location) @@ -74,3 +95,27 @@ def test_unknown_module(): with pytest.raises(ValueError) as exc_info: info = PluginCache.plugin_info("probes.invalid.format.length") assert "plugin class" in str(exc_info.value) + + +@pytest.mark.skipif( + sys.platform == "win32", + reason="Windows Github executor does not raise the ValueError", +) +def test_module_removed(temp_cache_location): + cache = PluginCache.instance() + cache["probes"]["probes.invalid.Removed"] = { + "description": "Testing value to be purged" + } + with open(temp_cache_location, "w", encoding="utf-8") as cache_file: + json.dump(cache, cache_file, cls=PluginEncoder, indent=2) + PluginCache._plugin_cache_dict = None + with pytest.raises(ValueError) as exc_info: + PluginCache.plugin_info("probes.invalid.Removed") + assert "plugin module" in str(exc_info.value) + + +@pytest.mark.usefixtures("remove_package_cache") +def test_report_missing_package_file(): + with pytest.raises(AssertionError) as exc_info: + PluginCache.instance() + assert "is missing or corrupt" in str(exc_info.value)