Skip to content

Commit

Permalink
ENH(io): TocFits class for FITS-backed tocdicts (#53)
Browse files Browse the repository at this point in the history
Add a `TocFits` class that implements a FITS-backed mapping compatible
with `TocDict`. The base class is generic, concrete implementations are
`AlmFits`, `ClsFits`, and `MmsFits`.

Closes: #47
  • Loading branch information
ntessore authored Nov 11, 2023
1 parent abae466 commit 880acb3
Show file tree
Hide file tree
Showing 3 changed files with 358 additions and 59 deletions.
264 changes: 214 additions & 50 deletions heracles/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@

import logging
import os
from collections.abc import MutableMapping
from functools import partial
from pathlib import Path
from types import MappingProxyType
from weakref import WeakValueDictionary

import fitsio
import healpy as hp
Expand Down Expand Up @@ -61,7 +66,29 @@ def _read_metadata(hdu):
return md


def _as_twopoint(arr, name):
def _write_complex(fits, ext, arr):
"""write complex-valued data to FITS table"""
# write the data
fits.write_table([arr.real, arr.imag], names=["real", "imag"], extname=ext)

# write the metadata
_write_metadata(fits[ext], arr.dtype.metadata)


def _read_complex(fits, ext):
"""read complex-valued data from FITS table"""
# read structured data as complex array
raw = fits[ext].read()
arr = np.empty(len(raw), dtype=complex)
arr.real = raw["real"]
arr.imag = raw["imag"]
del raw
# read and attach metadata
arr.dtype = np.dtype(arr.dtype, metadata=_read_metadata(fits[ext]))
return arr


def _write_twopoint(fits, ext, arr, name):
"""convert two-point data (i.e. one L column) to structured array"""

arr = np.asanyarray(arr)
Expand Down Expand Up @@ -89,6 +116,19 @@ def _as_twopoint(arr, name):
arr["LMAX"] = arr["L"] + 1
arr["W"] = 1

# write the twopoint data
fits.write_table(arr, extname=ext)

# write the metadata
_write_metadata(fits[ext], arr.dtype.metadata)


def _read_twopoint(fits, ext):
"""read two-point data from FITS"""
# read data from extension
arr = fits[ext].read()
# read and attach metadata
arr.dtype = np.dtype(arr.dtype, metadata=_read_metadata(fits[ext]))
return arr


Expand Down Expand Up @@ -322,11 +362,8 @@ def write_alms(
ext = f"ALM{almn}"
almn += 1

# write the data
fits.write_table([alm.real, alm.imag], names=["real", "imag"], extname=ext)

# write the metadata
_write_metadata(fits[ext], alm.dtype.metadata)
# write the alm as structured data with metadata
_write_complex(fits, ext, alm)

# write the TOC entry
tocentry[0] = (ext, n, i)
Expand Down Expand Up @@ -361,18 +398,8 @@ def read_alms(filename, workdir=".", *, include=None, exclude=None):

logger.info("reading %s alm for bin %s", n, i)

# read the alm from the extension
raw = fits[ext].read()
alm = np.empty(len(raw), dtype=complex)
alm.real = raw["real"]
alm.imag = raw["imag"]
del raw

# read and attach metadata
alm.dtype = np.dtype(alm.dtype, metadata=_read_metadata(fits[ext]))

# store in set of alms
alms[n, i] = alm
# read the alm from the extension and store in set of alms
alms[n, i] = _read_complex(fits, ext)

logger.info("done with %d alms", len(alms))

Expand Down Expand Up @@ -428,14 +455,8 @@ def write_cls(filename, cls, *, clobber=False, workdir=".", include=None, exclud
ext = f"CL{cln}"
cln += 1

# get the data into structured format if not already
cl = _as_twopoint(cl, "CL")

# write the data columns
fits.write_table(cl, extname=ext)

# write the metadata
_write_metadata(fits[ext], cl.dtype.metadata)
# write the data in structured format
_write_twopoint(fits, ext, cl, "CL")

# write the TOC entry
tocentry[0] = (ext, k1, k2, i1, i2)
Expand Down Expand Up @@ -470,14 +491,8 @@ def read_cls(filename, workdir=".", *, include=None, exclude=None):

logger.info("reading %s x %s cl for bins %s, %s", k1, k2, i1, i2)

# read the cl from the extension
cl = fits[ext].read()

# read and attach metadata
cl.dtype = np.dtype(cl.dtype, metadata=_read_metadata(fits[ext]))

# store in set of cls
cls[k1, k2, i1, i2] = cl
# read the cl from the extension and store in set of cls
cls[k1, k2, i1, i2] = _read_twopoint(fits, ext)

logger.info("done with %d cls", len(cls))

Expand Down Expand Up @@ -533,14 +548,8 @@ def write_mms(filename, mms, *, clobber=False, workdir=".", include=None, exclud
ext = f"MM{mmn}"
mmn += 1

# get the data into structured format if not already
mm = _as_twopoint(mm, "MM")

# write the data columns
fits.write_table(mm, extname=ext)

# write the metadata
_write_metadata(fits[ext], mm.dtype.metadata)
# write the data in structured format
_write_twopoint(fits, ext, mm, "MM")

# write the TOC entry
tocentry[0] = (ext, n, i1, i2)
Expand Down Expand Up @@ -575,14 +584,8 @@ def read_mms(filename, workdir=".", *, include=None, exclude=None):

logger.info("reading mixing matrix %s for bins %s, %s", n, i1, i2)

# read the mixing matrix from the extension
mm = fits[ext].read()

# read and attach metadata
mm.dtype = np.dtype(mm.dtype, metadata=_read_metadata(fits[ext]))

# store in set of mms
mms[n, i1, i2] = mm
# read the mixing matrix from the extension and store in set of mms
mms[n, i1, i2] = _read_twopoint(fits, ext)

logger.info("done with %d mm(s)", len(mms))

Expand Down Expand Up @@ -711,3 +714,164 @@ def read_cov(filename, workdir=".", *, include=None, exclude=None):

# return the toc dict of covariances
return cov


class TocFits(MutableMapping):
"""A FITS-backed TocDict."""

tag = "EXT"
"""Tag for FITS extensions."""

columns = {}
"""Columns and their formats in the FITS table of contents."""

@staticmethod
def reader(fits, ext):
"""Read data from FITS extension."""
return fits[ext].read()

@staticmethod
def writer(fits, ext, data):
"""Write data to FITS extension."""
if data.dtype.names is None:
msg = "data must be structured array"
raise TypeError(msg)
fits.write_table(data, extname=ext)

@property
def fits(self):
"""Return an opened FITS context manager."""
return fitsio.FITS(self.path, mode="rw", clobber=False)

@property
def toc(self):
"""Return a view of the FITS table of contents."""
return MappingProxyType(self._toc)

def __init__(self, path, *, clobber=False):
self.path = Path(path)

# FITS extension for table of contents
self.ext = f"{self.tag.upper()}TOC"

# if new or overwriting, create an empty FITS with primary HDU
if not self.path.exists() or clobber:
with fitsio.FITS(self.path, mode="rw", clobber=True) as fits:
fits.write(None)

# reopen FITS for writing data
with self.fits as fits:
# write a new ToC extension if FITS doesn't already contain one
if self.ext not in fits:
fits.create_table_hdu(
names=["EXT", *self.columns.keys()],
formats=["10A", *self.columns.values()],
extname=self.ext,
)

# get the dtype for ToC entries
self.dtype = fits[self.ext].get_rec_dtype()[0]

# empty ToC
self._toc = TocDict()
else:
# read existing ToC from FITS
toc = fits[self.ext].read()

# store the dtype for ToC entries
toc.dtype = toc.dtype

# store the ToC as a mapping
self._toc = TocDict({tuple(key): str(ext) for ext, *key in toc})

# set up a weakly-referenced cache for extension data
self._cache = WeakValueDictionary()

def __len__(self):
return len(self._toc)

def __iter__(self):
return iter(self._toc)

def __contains__(self, key):
if not isinstance(key, tuple):
key = (key,)
return key in self._toc

def __getitem__(self, key):
ext = self._toc[key]

# if a TocDict is returned, we have the result of a selection
if isinstance(ext, TocDict):
# make a new instance and copy attributes
selected = object.__new__(self.__class__)
selected.path = self.path
# shared cache since both instances read the same file
selected._cache = self._cache
# the new toc contains the result of the selection
selected._toc = ext
return selected

# a specific extension was requested, fetch data
data = self._cache.get(ext)
if data is None:
with self.fits as fits:
data = self.reader(fits, ext)
self._cache[ext] = data
return data

def __setitem__(self, key, value):
# keys are always tuples
if not isinstance(key, tuple):
key = (key,)

# check if an extension with the given key already exists
# otherwise, get the first free extension with the given tag
if key in self._toc:
ext = self._toc[key]
else:
extn = len(self._toc)
exts = set(self._toc.values())
while (ext := f"{self.tag.upper()}{extn}") in exts:
extn += 1

# write data using the class writer, and update ToC as necessary
with self.fits as fits:
self.writer(fits, ext, value)
if key not in self._toc:
tocentry = np.empty(1, dtype=self.dtype)
tocentry[0] = (ext, *key)
fits[self.ext].append(tocentry)
self._toc[key] = ext

def __delitem__(self, key):
# fitsio does not support deletion of extensions
msg = "deleting FITS extensions is not supported"
raise NotImplementedError(msg)


class AlmFits(TocFits):
"""FITS-backed mapping for alms."""

tag = "ALM"
columns = {"NAME": "10A", "BIN": "I"}
reader = staticmethod(_read_complex)
writer = staticmethod(_write_complex)


class ClsFits(TocFits):
"""FITS-backed mapping for cls."""

tag = "CL"
columns = {"NAME1": "10A", "NAME2": "10A", "BIN1": "I", "BIN2": "I"}
reader = staticmethod(_read_twopoint)
writer = partial(_write_twopoint, name=tag)


class MmsFits(TocFits):
"""FITS-backed mapping for mixing matrices."""

tag = "MM"
columns = {"NAME": "10A", "BIN1": "I", "BIN2": "I"}
reader = staticmethod(_read_twopoint)
writer = partial(_write_twopoint, name=tag)
Loading

0 comments on commit 880acb3

Please sign in to comment.