From 145004ec8cf3fa84e86faf4c0699f2104c7365f1 Mon Sep 17 00:00:00 2001 From: Nicolas Tessore Date: Tue, 10 Oct 2023 22:49:21 +0100 Subject: [PATCH] fixes and weakref caching --- heracles/io.py | 26 +++++++++++++++++++------- tests/test_io.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 7 deletions(-) diff --git a/heracles/io.py b/heracles/io.py index eeb6e5c..88127f2 100644 --- a/heracles/io.py +++ b/heracles/io.py @@ -21,8 +21,10 @@ import logging import os from collections.abc import MutableMapping +from functools import partial from pathlib import Path from types import MappingProxyType +from weakref import WeakValueDictionary import fitsio import healpy as hp @@ -782,18 +784,28 @@ def __init__(self, path, *, clobber=False): # store the ToC as a mapping self._toc = {tuple(key): str(ext) for ext, *key in toc} + # set up a weakly-referenced cache for extension data + self._cache = WeakValueDictionary() + def __len__(self): return len(self._toc) def __iter__(self): return iter(self._toc) + def __contains__(self, key): + return key in self._toc + def __getitem__(self, key): if not isinstance(key, tuple): key = (key,) - ext = self._toc[key] - with self.fits as fits: - return self.reader(fits, ext) + value = self._cache.get(key) + if value is None: + ext = self._toc[key] + with self.fits as fits: + value = self.reader(fits, ext) + self._cache[key] = value + return value def __setitem__(self, key, value): # keys are always tuples @@ -838,15 +850,15 @@ class ClsFits(TocFits): """FITS-backed mapping for cls.""" tag = "CL" - columns = {"EXT": "10A", "NAME1": "10A", "NAME2": "10A", "BIN1": "I", "BIN2": "I"} + columns = {"NAME1": "10A", "NAME2": "10A", "BIN1": "I", "BIN2": "I"} reader = staticmethod(_read_twopoint) - writer = staticmethod(_write_twopoint) + writer = partial(_write_twopoint, name=tag) class MmsFits(TocFits): """FITS-backed mapping for mixing matrices.""" tag = "MM" - columns = {"EXT": "10A", "NAME": "10A", "BIN1": "I", "BIN2": "I"} 
+ columns = {"NAME": "10A", "BIN1": "I", "BIN2": "I"} reader = staticmethod(_read_twopoint) - writer = staticmethod(_write_twopoint) + writer = partial(_write_twopoint, name=tag) diff --git a/tests/test_io.py b/tests/test_io.py index 38964ac..810fcf9 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -448,3 +448,33 @@ class TestFits(TocFits): assert tocfits2.toc == {(1, 2): "TEST0", (2, 2): "TEST1"} np.testing.assert_array_equal(tocfits2[1, 2], data12) np.testing.assert_array_equal(tocfits2[2, 2], data22) + + +def test_tocfits_is_lazy(tmp_path): + import fitsio + + from heracles.io import TocFits + + path = tmp_path / "bad.fits" + + # test keys(), values(), and items() are not eagerly reading data + tocfits = TocFits(path, clobber=True) + + # manually enter some non-existent rows into the ToC + assert tocfits._toc == {} + tocfits._toc[0,] = "BAD0" + tocfits._toc[1,] = "BAD1" + tocfits._toc[2,] = "BAD2" + + # these should not error + tocfits.keys() + tocfits.values() + tocfits.items() + + # make sure nothing is in the FITS + with fitsio.FITS(path, "r") as fits: + assert len(fits) == 2 + + # make sure there are errors when actualising the generators + with pytest.raises(OSError): + list(tocfits.values())