validate cache consistency on first access (#815)
* validate cache consistency

* validate on startup
* update on any plugin file change

Signed-off-by: Jeffrey Martin <[email protected]>

* assert when package distributed cache missing

* raises assertion when package plugin cache is not found
* test for removed module file
* test missing package plugin cache

Signed-off-by: Jeffrey Martin <[email protected]>

---------

Signed-off-by: Jeffrey Martin <[email protected]>
jmartin-tech authored Aug 1, 2024
1 parent 466ea05 commit a78c2f8
Showing 2 changed files with 142 additions and 26 deletions.
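
The change in this commit boils down to an mtime comparison: the user-level plugin cache is treated as stale whenever any plugin module file on disk is newer than the cache file, or when a cached module no longer exists on disk. A minimal standalone sketch of that staleness check, using hypothetical cache_file and plugin_dir paths rather than garak's real configuration:

import os
from datetime import datetime, timezone
from pathlib import Path

def cache_is_stale(cache_file: Path, plugin_dir: Path) -> bool:
    """Sketch only: cache_file stands in for the user cache path and
    plugin_dir for a plugin category directory (e.g. probes/)."""
    cache_time = datetime.fromtimestamp(os.path.getmtime(cache_file), tz=timezone.utc)
    for module_file in plugin_dir.glob("*.py"):
        if module_file.name.startswith("__"):
            continue  # skip __init__.py and similar
        mod_time = datetime.fromtimestamp(os.path.getmtime(module_file), tz=timezone.utc)
        if mod_time > cache_time:
            return True  # a plugin module changed after the cache was written
    return False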
117 changes: 94 additions & 23 deletions garak/_plugins.py
@@ -9,6 +9,8 @@
import logging
import shutil
import os
from datetime import datetime, timezone
from threading import Lock
from typing import List, Callable, Union
from pathlib import Path

@@ -43,6 +45,8 @@ class PluginCache:
)
_plugin_cache_dict = None

_mutex = Lock()

def __init__(self) -> None:
if PluginCache._plugin_cache_dict is None:
PluginCache._plugin_cache_dict = self._load_plugin_cache()
@@ -56,42 +60,108 @@ def _extract_modules_klasses(base_klass):
]

def _load_plugin_cache(self):
if not os.path.exists(self._plugin_cache_filename):
self._build_plugin_cache()
assert os.path.exists(
self._plugin_cache_filename
), f"{self._plugin_cache_filename} is missing or corrupt."

base_time = datetime.fromtimestamp(
os.path.getmtime(self._plugin_cache_filename), tz=timezone.utc
)
update_user_file = False
if not os.path.exists(self._user_plugin_cache_filename):
update_user_file = True
else:
user_time = datetime.fromtimestamp(
os.path.getmtime(self._user_plugin_cache_filename), tz=timezone.utc
)
if base_time > user_time:
update_user_file = True
if update_user_file:
self._user_plugin_cache_filename.parent.mkdir(
mode=0o740, parents=True, exist_ok=True
)
shutil.copy2(self._plugin_cache_filename, self._user_plugin_cache_filename)
user_time = base_time

with open(
self._user_plugin_cache_filename, "r", encoding="utf-8"
) as cache_file:
local_cache = json.load(cache_file)
return local_cache

# validate cache state on startup
if not self._valid_loaded_cache(local_cache, user_time):
self._build_plugin_cache()
with open(
self._user_plugin_cache_filename, "r", encoding="utf-8"
) as cache_file:
local_cache = json.load(cache_file)

return local_cache

def _valid_loaded_cache(self, local_cache, user_time):
# verify all known plugins
for plugin_type in PLUGIN_TYPES:
validated_plugin_filenames = set()
prev_mod = None
for k in local_cache[plugin_type]:
category, module, klass = k.split(".")
if module != prev_mod:
prev_mod = module
# TODO: only one possible location for now, expand when `user` plugins are supported
file_path = (
_config.transient.package_dir / category / f"{module}.py"
)
if os.path.exists(file_path):
validated_plugin_filenames.add(file_path)
mod_time = datetime.fromtimestamp(
os.path.getmtime(file_path), tz=timezone.utc
)
else:
                        # file not found; force a cache rebuild
mod_time = datetime.now(timezone.utc)
if mod_time > user_time:
                        # rebuild everything for now; this can be made more selective later
return False

            # all known plugins are up to date; check the filesystem for new modules missing from the cache
found_filenames = set()
plugin_path = _config.transient.package_dir / category
for module_filename in sorted(os.listdir(plugin_path)):
if not module_filename.endswith(".py"):
continue
if module_filename.startswith("__"):
continue
found_filenames.add(plugin_path / module_filename)
if found_filenames != validated_plugin_filenames:
return False

        return True  # all checks passed; the cache is valid

def _build_plugin_cache(self):
"""build a plugin cache file to improve access times
        This method writes only to the user's cache (currently the same as the system cache).
        TODO: Enhance location of user cache to enable support for in-development plugins.
"""
local_cache = {}

for plugin_type in PLUGIN_TYPES:
plugin_dict = {}
for plugin in self._enumerate_plugin_klasses(plugin_type):
plugin_name = ".".join([plugin.__module__, plugin.__name__]).replace(
"garak.", ""
)
plugin_dict[plugin_name] = PluginCache.plugin_info(plugin)

sorted_keys = sorted(list(plugin_dict.keys()))
local_cache[plugin_type] = {i: plugin_dict[i] for i in sorted_keys}

with open(
self._user_plugin_cache_filename, "w", encoding="utf-8"
) as cache_file:
json.dump(local_cache, cache_file, cls=PluginEncoder, indent=2)
logging.debug("rebuilding plugin cache")
with PluginCache._mutex:
local_cache = {}

for plugin_type in PLUGIN_TYPES:
plugin_dict = {}
for plugin in self._enumerate_plugin_klasses(plugin_type):
plugin_name = ".".join(
[plugin.__module__, plugin.__name__]
).replace("garak.", "")
plugin_dict[plugin_name] = PluginCache.plugin_info(plugin)

sorted_keys = sorted(list(plugin_dict.keys()))
local_cache[plugin_type] = {i: plugin_dict[i] for i in sorted_keys}

with open(
self._user_plugin_cache_filename, "w", encoding="utf-8"
) as cache_file:
json.dump(local_cache, cache_file, cls=PluginEncoder, indent=2)

def _enumerate_plugin_klasses(self, category: str) -> List[Callable]:
"""obtain all"""
@@ -136,7 +206,10 @@ def plugin_info(plugin: Union[Callable, str]) -> dict:
if category not in PLUGIN_TYPES:
raise ValueError(f"Not a recognised plugin type: {category}")

plugin_metadata = PluginCache.instance()[category].get(plugin_name, {})
cache = PluginCache.instance()
with PluginCache._mutex:
plugin_metadata = cache[category].get(plugin_name, {})

if len(plugin_metadata) > 0:
return plugin_metadata
else:
@@ -213,8 +286,6 @@ def plugin_info(plugin: Union[Callable, str]) -> dict:
logging.error(f"Plugin {plugin_name} not found.")
logging.exception(e)

from datetime import datetime, timezone

# adding last class modification time to cache allows for targeted update in future
current_mod = importlib.import_module(plugin.__module__)
mod_time = datetime.fromtimestamp(
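
From a caller's perspective the public surface is unchanged: the first access still goes through PluginCache.instance(), which now validates the cache against the plugin modules' mtimes and rebuilds the user cache if it is stale, and plugin_info() reads the shared dict under the new class-level lock. A hedged usage sketch (the probe name below is a placeholder, not a real plugin):

from garak._plugins import PluginCache

cache = PluginCache.instance()  # validates against plugin file mtimes; rebuilds the user cache if stale
print(sorted(cache["probes"])[:5])  # a few cached probe plugin names

# per-plugin metadata still comes from plugin_info(); the read now happens under PluginCache._mutex
# info = PluginCache.plugin_info("probes.some_module.SomeProbe")  # placeholder name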
51 changes: 48 additions & 3 deletions tests/plugins/test_plugin_cache.py
@@ -1,8 +1,11 @@
import json
import pytest
import os
import sys
import tempfile

from garak._plugins import PluginCache
from pathlib import Path
from garak._plugins import PluginCache, PluginEncoder


@pytest.fixture(autouse=True)
@@ -17,8 +20,7 @@ def reset_plugin_cache():
def temp_cache_location(request) -> None:
# override the cache file with a tmp location
with tempfile.NamedTemporaryFile(buffering=0, delete=False) as tmp:
PluginCache._user_plugin_cache_filename = tmp.name
PluginCache._plugin_cache_filename = tmp.name
PluginCache._user_plugin_cache_filename = Path(tmp.name)
tmp.close()
os.remove(tmp.name)
# reset the class level singleton
@@ -33,6 +35,25 @@ def remove_cache_file():
return tmp.name


@pytest.fixture
def remove_package_cache(request) -> None:
    # override the package cache file with a tmp location
original_path = PluginCache._plugin_cache_filename
with tempfile.NamedTemporaryFile(buffering=0, delete=False) as tmp:
PluginCache._plugin_cache_filename = Path(tmp.name)
tmp.close()
os.remove(tmp.name)
# reset the class level singleton
PluginCache._plugin_cache_dict = None

def restore_package_path():
PluginCache._plugin_cache_filename = original_path

request.addfinalizer(restore_package_path)

return tmp.name


def test_create(temp_cache_location):
cache = PluginCache.instance()
assert os.path.isfile(temp_cache_location)
@@ -74,3 +95,27 @@ def test_unknown_module():
with pytest.raises(ValueError) as exc_info:
info = PluginCache.plugin_info("probes.invalid.format.length")
assert "plugin class" in str(exc_info.value)


@pytest.mark.skipif(
sys.platform == "win32",
reason="Windows Github executor does not raise the ValueError",
)
def test_module_removed(temp_cache_location):
cache = PluginCache.instance()
cache["probes"]["probes.invalid.Removed"] = {
"description": "Testing value to be purged"
}
with open(temp_cache_location, "w", encoding="utf-8") as cache_file:
json.dump(cache, cache_file, cls=PluginEncoder, indent=2)
PluginCache._plugin_cache_dict = None
with pytest.raises(ValueError) as exc_info:
PluginCache.plugin_info("probes.invalid.Removed")
assert "plugin module" in str(exc_info.value)


@pytest.mark.usefixtures("remove_package_cache")
def test_report_missing_package_file():
with pytest.raises(AssertionError) as exc_info:
PluginCache.instance()
assert "is missing or corrupt" in str(exc_info.value)
