Skip to content

Commit

Permalink
change behavior wrt read-only caches
Browse files Browse the repository at this point in the history
  • Loading branch information
benvanwerkhoven committed Jun 20, 2024
1 parent 3c49578 commit 9468339
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 18 deletions.
38 changes: 21 additions & 17 deletions kernel_tuner/cache/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def create(
if len(cls.RESERVED_PARAM_KEYS & set(tune_params_keys)) > 0:
raise ValueError("Found a reserved key in tune_params_keys")

# main dictionary for new cache, note it is very important that 'cache' is the last key in the dict
cache_json: CacheFileJSON = {
"schema_version": str(LATEST_VERSION),
"device_name": device_name,
Expand All @@ -120,10 +121,10 @@ def create(
"objective": objective,
"cache": {},
}
# NOTE: Validate the cache just to be sure

cls.validate_json(cache_json)
write_cache(cast(dict, cache_json), filename)
return cls(filename, cache_json, readonly=False)
return cls(filename, cache_json, read_only=False)

@classmethod
def open(cls, filename: PathLike):
Expand All @@ -134,24 +135,27 @@ def open(cls, filename: PathLike):
cache_json = read_cache(filename)
assert Version.parse(cache_json["schema_version"]) == LATEST_VERSION, "Cache file is not of the latest version."
cls.validate_json(cache_json)
return cls(filename, cache_json, readonly=False)
return cls(filename, cache_json, read_only=False)

@classmethod
def read(cls, filename: PathLike):
"""Loads an existing cache. Returns a Cache instance which can only be read.
def read(cls, filename: PathLike, read_only=False):
"""Loads an existing cache. Returns a Cache instance.
If the cache file does not have the latest version, then it will be read after virtually converting it to the
latest version. The file in this case is kept the same.
"""
cache_json = read_cache(filename)
# If the cache is versioned, then validate it
if "schema_version" in cache_json:
cls.validate_json(cache_json)

cache_json = convert_cache(cache_json)
# NOTE: Validate the cache just to be sure
# convert cache to latest schema if needed, then validate
if "schema_version" not in cache_json or cache_json["schema_version"] != LATEST_VERSION:
cache_json = convert_cache(cache_json)
# if not read-only mode, update the file
if not read_only:
write_cache(cast(dict, cache_json), filename)

cls.validate_json(cache_json)
return cls(filename, cache_json, readonly=True)

return cls(filename, cache_json, read_only=read_only)

@classmethod
def validate(cls, filename: PathLike):
Expand Down Expand Up @@ -185,14 +189,14 @@ def __get_schema_for_version(cls, version: str):
with open(schema_path, "r") as file:
return json.load(file)

def __init__(self, filename: PathLike, cache_json: CacheFileJSON, *, readonly: bool):
def __init__(self, filename: PathLike, cache_json: CacheFileJSON, *, read_only: bool):
"""Inits a cache file instance, given that the file referred to by ``filename`` contains data ``cache_json``.
Argument ``cache_json`` is a cache dictionary expected to have the latest cache version.
"""
self._filename = Path(filename)
self._cache_json = cache_json
self._readonly = readonly
self._read_only = read_only

@cached_property
def filepath(self) -> Path:
Expand All @@ -207,7 +211,7 @@ def version(self) -> Version:
@cached_property
def lines(self) -> Union[Lines, ReadableLines]:
"""List of cache lines."""
if self._readonly:
if self._read_only:
return self.ReadableLines(self, self._filename, self._cache_json)
else:
return self.Lines(self, self._filename, self._cache_json)
Expand Down Expand Up @@ -422,11 +426,11 @@ def __get_line_json_object(
return line

class ReadableLines(Lines):
"""Cache lines in a readonly cache file."""
"""Cache lines in a read_only cache file."""

def append(*args, **kwargs):
"""Dummy method that does nothing."""
pass
""" Method to append lines to cache file, should not happen with read-only cache """
raise ValueError(f"Attempting to write to read-only cache")

class Line(Mapping):
"""Cache line in a cache file.
Expand Down
3 changes: 2 additions & 1 deletion kernel_tuner/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,7 +1131,8 @@ def process_cache(cache, kernel_options, tuning_options, runner):

# if file exists
else:
c = Cache.read(cache)
# Read existing cache file, when using simulation_mode file is treated read_only
c = Cache.read(cache, read_only=tuning_options.simulation_mode)

# if in simulation mode, use the device name from the cache file as the runner device name
if runner.simulation_mode:
Expand Down

0 comments on commit 9468339

Please sign in to comment.