Skip to content

Commit

Permalink
validate cache consistency
Browse files Browse the repository at this point in the history
* validate on startup
* update on any plugin file change

Signed-off-by: Jeffrey Martin <[email protected]>
  • Loading branch information
jmartin-tech committed Aug 1, 2024
1 parent 4aa85ac commit 3ead7c8
Showing 1 changed file with 91 additions and 21 deletions.
112 changes: 91 additions & 21 deletions garak/_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import logging
import shutil
import os
from datetime import datetime, timezone
from threading import Lock
from typing import List, Callable, Union
from pathlib import Path

Expand Down Expand Up @@ -43,6 +45,8 @@ class PluginCache:
)
_plugin_cache_dict = None

_mutex = Lock()

def __init__(self) -> None:
if PluginCache._plugin_cache_dict is None:
PluginCache._plugin_cache_dict = self._load_plugin_cache()
Expand All @@ -58,40 +62,105 @@ def _extract_modules_klasses(base_klass):
def _load_plugin_cache(self):
if not os.path.exists(self._plugin_cache_filename):
self._build_plugin_cache()

base_time = datetime.fromtimestamp(
os.path.getmtime(self._plugin_cache_filename), tz=timezone.utc
)
update_user_file = False
if not os.path.exists(self._user_plugin_cache_filename):
update_user_file = True
else:
user_time = datetime.fromtimestamp(
os.path.getmtime(self._user_plugin_cache_filename), tz=timezone.utc
)
if base_time > user_time:
update_user_file = True
if update_user_file:
self._user_plugin_cache_filename.parent.mkdir(
mode=0o740, parents=True, exist_ok=True
)
shutil.copy2(self._plugin_cache_filename, self._user_plugin_cache_filename)
user_time = base_time

with open(
self._user_plugin_cache_filename, "r", encoding="utf-8"
) as cache_file:
local_cache = json.load(cache_file)
return local_cache

# validate cache state on startup
if not self._valid_loaded_cache(local_cache, user_time):
self._build_plugin_cache()
with open(
self._user_plugin_cache_filename, "r", encoding="utf-8"
) as cache_file:
local_cache = json.load(cache_file)

return local_cache

def _valid_loaded_cache(self, local_cache, user_time):
# verify all known plugins
for plugin_type in PLUGIN_TYPES:
validated_plugin_filenames = set()
prev_mod = None
for k in local_cache[plugin_type]:
category, module, klass = k.split(".")
if module != prev_mod:
prev_mod = module
# TODO: only one possible location for now, expand when `user` plugins are supported
file_path = (
_config.transient.package_dir / category / f"{module}.py"
)
if os.path.exists(file_path):
validated_plugin_filenames.add(file_path)
mod_time = datetime.fromtimestamp(
os.path.getmtime(file_path), tz=timezone.utc
)
else:
# file not found force cache rebuild
mod_time = datetime.now(timezone.utc)
if mod_time > user_time:
# rebuild all for now, this can be more made more selective later
return False

# if all known are up-to-date check filesystem for missing
found_filenames = set()
plugin_path = _config.transient.package_dir / category
for module_filename in sorted(os.listdir(plugin_path)):
if not module_filename.endswith(".py"):
continue
if module_filename.startswith("__"):
continue
found_filenames.add(plugin_path / module_filename)
if found_filenames != validated_plugin_filenames:
return False

return True # all checks passed the cache is valid

def _build_plugin_cache(self):
"""build a plugin cache file to improve access times
This method writes only to the user's cache (currently the same as the system cache)
TODO: Enhance location of user cache to enable support for in development plugins.
"""
local_cache = {}

for plugin_type in PLUGIN_TYPES:
plugin_dict = {}
for plugin in self._enumerate_plugin_klasses(plugin_type):
plugin_name = ".".join([plugin.__module__, plugin.__name__]).replace(
"garak.", ""
)
plugin_dict[plugin_name] = PluginCache.plugin_info(plugin)

sorted_keys = sorted(list(plugin_dict.keys()))
local_cache[plugin_type] = {i: plugin_dict[i] for i in sorted_keys}

with open(
self._user_plugin_cache_filename, "w", encoding="utf-8"
) as cache_file:
json.dump(local_cache, cache_file, cls=PluginEncoder, indent=2)
logging.debug("rebuilding plugin cache")
with PluginCache._mutex:
local_cache = {}

for plugin_type in PLUGIN_TYPES:
plugin_dict = {}
for plugin in self._enumerate_plugin_klasses(plugin_type):
plugin_name = ".".join(
[plugin.__module__, plugin.__name__]
).replace("garak.", "")
plugin_dict[plugin_name] = PluginCache.plugin_info(plugin)

sorted_keys = sorted(list(plugin_dict.keys()))
local_cache[plugin_type] = {i: plugin_dict[i] for i in sorted_keys}

with open(
self._user_plugin_cache_filename, "w", encoding="utf-8"
) as cache_file:
json.dump(local_cache, cache_file, cls=PluginEncoder, indent=2)

def _enumerate_plugin_klasses(self, category: str) -> List[Callable]:
"""obtain all"""
Expand Down Expand Up @@ -136,7 +205,10 @@ def plugin_info(plugin: Union[Callable, str]) -> dict:
if category not in PLUGIN_TYPES:
raise ValueError(f"Not a recognised plugin type: {category}")

plugin_metadata = PluginCache.instance()[category].get(plugin_name, {})
cache = PluginCache.instance()
with PluginCache._mutex:
plugin_metadata = cache[category].get(plugin_name, {})

if len(plugin_metadata) > 0:
return plugin_metadata
else:
Expand Down Expand Up @@ -213,8 +285,6 @@ def plugin_info(plugin: Union[Callable, str]) -> dict:
logging.error(f"Plugin {plugin_name} not found.")
logging.exception(e)

from datetime import datetime, timezone

# adding last class modification time to cache allows for targeted update in future
current_mod = importlib.import_module(plugin.__module__)
mod_time = datetime.fromtimestamp(
Expand Down

0 comments on commit 3ead7c8

Please sign in to comment.