Skip to content

Commit

Permalink
FITS-backed mappings
Browse files Browse the repository at this point in the history
  • Loading branch information
ntessore committed Oct 10, 2023
1 parent abae466 commit 437848c
Show file tree
Hide file tree
Showing 2 changed files with 262 additions and 50 deletions.
239 changes: 189 additions & 50 deletions heracles/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@

import logging
import os
from collections.abc import MutableMapping
from pathlib import Path
from types import MappingProxyType

import fitsio
import healpy as hp
Expand Down Expand Up @@ -61,7 +64,29 @@ def _read_metadata(hdu):
return md


def _as_twopoint(arr, name):
def _write_complex(fits, ext, arr):
"""write complex-valued data to FITS table"""
# write the data
fits.write_table([arr.real, arr.imag], names=["real", "imag"], extname=ext)

# write the metadata
_write_metadata(fits[ext], arr.dtype.metadata)


def _read_complex(fits, ext):
"""read complex-valued data from FITS table"""
# read structured data as complex array
raw = fits[ext].read()
arr = np.empty(len(raw), dtype=complex)
arr.real = raw["real"]
arr.imag = raw["imag"]
del raw
# read and attach metadata
arr.dtype = np.dtype(arr.dtype, metadata=_read_metadata(fits[ext]))
return arr


def _write_twopoint(fits, ext, arr, name):
"""convert two-point data (i.e. one L column) to structured array"""

arr = np.asanyarray(arr)
Expand Down Expand Up @@ -89,6 +114,19 @@ def _as_twopoint(arr, name):
arr["LMAX"] = arr["L"] + 1
arr["W"] = 1

# write the twopoint data
fits.write_table(arr, extname=ext)

# write the metadata
_write_metadata(fits[ext], arr.dtype.metadata)


def _read_twopoint(fits, ext):
"""read two-point data from FITS"""
# read data from extension
arr = fits[ext].read()
# read and attach metadata
arr.dtype = np.dtype(arr.dtype, metadata=_read_metadata(fits[ext]))
return arr


Expand Down Expand Up @@ -322,11 +360,8 @@ def write_alms(
ext = f"ALM{almn}"
almn += 1

# write the data
fits.write_table([alm.real, alm.imag], names=["real", "imag"], extname=ext)

# write the metadata
_write_metadata(fits[ext], alm.dtype.metadata)
# write the alm as structured data with metadata
_write_complex(fits, ext, alm)

# write the TOC entry
tocentry[0] = (ext, n, i)
Expand Down Expand Up @@ -361,18 +396,8 @@ def read_alms(filename, workdir=".", *, include=None, exclude=None):

logger.info("reading %s alm for bin %s", n, i)

# read the alm from the extension
raw = fits[ext].read()
alm = np.empty(len(raw), dtype=complex)
alm.real = raw["real"]
alm.imag = raw["imag"]
del raw

# read and attach metadata
alm.dtype = np.dtype(alm.dtype, metadata=_read_metadata(fits[ext]))

# store in set of alms
alms[n, i] = alm
# read the alm from the extension and store in set of alms
alms[n, i] = _read_complex(fits, ext)

logger.info("done with %d alms", len(alms))

Expand Down Expand Up @@ -428,14 +453,8 @@ def write_cls(filename, cls, *, clobber=False, workdir=".", include=None, exclud
ext = f"CL{cln}"
cln += 1

# get the data into structured format if not already
cl = _as_twopoint(cl, "CL")

# write the data columns
fits.write_table(cl, extname=ext)

# write the metadata
_write_metadata(fits[ext], cl.dtype.metadata)
# write the data in structured format
_write_twopoint(fits, ext, cl, "CL")

# write the TOC entry
tocentry[0] = (ext, k1, k2, i1, i2)
Expand Down Expand Up @@ -470,14 +489,8 @@ def read_cls(filename, workdir=".", *, include=None, exclude=None):

logger.info("reading %s x %s cl for bins %s, %s", k1, k2, i1, i2)

# read the cl from the extension
cl = fits[ext].read()

# read and attach metadata
cl.dtype = np.dtype(cl.dtype, metadata=_read_metadata(fits[ext]))

# store in set of cls
cls[k1, k2, i1, i2] = cl
# read the cl from the extension and store in set of cls
cls[k1, k2, i1, i2] = _read_twopoint(fits, ext)

logger.info("done with %d cls", len(cls))

Expand Down Expand Up @@ -533,14 +546,8 @@ def write_mms(filename, mms, *, clobber=False, workdir=".", include=None, exclud
ext = f"MM{mmn}"
mmn += 1

# get the data into structured format if not already
mm = _as_twopoint(mm, "MM")

# write the data columns
fits.write_table(mm, extname=ext)

# write the metadata
_write_metadata(fits[ext], mm.dtype.metadata)
# write the data in structured format
_write_twopoint(fits, ext, mm, "MM")

# write the TOC entry
tocentry[0] = (ext, n, i1, i2)
Expand Down Expand Up @@ -575,14 +582,8 @@ def read_mms(filename, workdir=".", *, include=None, exclude=None):

logger.info("reading mixing matrix %s for bins %s, %s", n, i1, i2)

# read the mixing matrix from the extension
mm = fits[ext].read()

# read and attach metadata
mm.dtype = np.dtype(mm.dtype, metadata=_read_metadata(fits[ext]))

# store in set of mms
mms[n, i1, i2] = mm
# read the mixing matrix from the extension and store in set of mms
mms[n, i1, i2] = _read_twopoint(fits, ext)

logger.info("done with %d mm(s)", len(mms))

Expand Down Expand Up @@ -711,3 +712,141 @@ def read_cov(filename, workdir=".", *, include=None, exclude=None):

# return the toc dict of covariances
return cov


class TocFits(MutableMapping):
"""A FITS-backed TocDict."""

tag = ""
"""Tag for FITS extensions."""

columns = {}
"""Columns and formats in the FITS table of contents."""

@staticmethod
def reader(fits, ext):
"""Read data from FITS extension."""
return fits[ext].read()

@staticmethod
def writer(fits, ext, data):
"""Write data to FITS extension."""
if data.dtype.names is None:
msg = "data must be structured array"
raise TypeError(msg)
fits.write_table(data, extname=ext)

@property
def fits(self):
"""Return an opened FITS context manager."""
return fitsio.FITS(self.path, mode="rw", clobber=False)

@property
def toc(self):
"""Return a view of the FITS table of contents."""
return MappingProxyType(self._toc)

def __init__(self, path, *, clobber=False):
self.path = Path(path)

# FITS extension for table of contents
self.ext = f"{self.tag.upper()}TOC"

# if new or overwriting, create an empty FITS with primary HDU
if not self.path.exists() or clobber:
with fitsio.FITS(self.path, mode="rw", clobber=True) as fits:
fits.write(None)

# reopen FITS for writing data
with self.fits as fits:
# write a new ToC extension if FITS doesn't already contain one
if self.ext not in fits:
fits.create_table_hdu(
names=["EXT", *self.columns.keys()],
formats=["10A", *self.columns.values()],
extname=self.ext,
)

# get the dtype for ToC entries
self.dtype = fits[self.ext].get_rec_dtype()[0]

# empty ToC
self._toc = {}
else:
# read existing ToC from FITS
toc = fits[self.ext].read()

# store the dtype for ToC entries
toc.dtype = toc.dtype

# store the ToC as a mapping
self._toc = {tuple(key): str(ext) for ext, *key in toc}

def __len__(self):
return len(self._toc)

def __iter__(self):
return iter(self._toc)

def __getitem__(self, key):
if not isinstance(key, tuple):
key = (key,)
ext = self._toc[key]
with self.fits as fits:
return self.reader(fits, ext)

def __setitem__(self, key, value):
# keys are always tuples
if not isinstance(key, tuple):
key = (key,)

# check if an extension with the given key already exists
# otherwise, get the first free extension with the given tag
if key in self._toc:
ext = self._toc[key]
else:
extn = len(self._toc)
exts = set(self._toc.values())
while (ext := f"{self.tag.upper()}{extn}") in exts:
extn += 1

# write data using the class writer, and update ToC as necessary
with self.fits as fits:
self.writer(fits, ext, value)
if key not in self._toc:
tocentry = np.empty(1, dtype=self.dtype)
tocentry[0] = (ext, *key)
fits[self.ext].append(tocentry)
self._toc[key] = ext

def __delitem__(self, key):
# fitsio does not support deletion of extensions
msg = "deleting FITS extensions is not supported"
raise NotImplementedError(msg)


class AlmFits(TocFits):
"""FITS-backed mapping for alms."""

tag = "ALM"
columns = {"NAME": "10A", "BIN": "I"}
reader = staticmethod(_read_complex)
writer = staticmethod(_write_complex)


class ClsFits(TocFits):
"""FITS-backed mapping for cls."""

tag = "CL"
columns = {"EXT": "10A", "NAME1": "10A", "NAME2": "10A", "BIN1": "I", "BIN2": "I"}
reader = staticmethod(_read_twopoint)
writer = staticmethod(_write_twopoint)


class MmsFits(TocFits):
"""FITS-backed mapping for mixing matrices."""

tag = "MM"
columns = {"EXT": "10A", "NAME": "10A", "BIN1": "I", "BIN2": "I"}
reader = staticmethod(_read_twopoint)
writer = staticmethod(_write_twopoint)
73 changes: 73 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,3 +375,76 @@ def test_read_mask_extra(
extra_mask_name=mock_writemask_extra,
)
assert (mask == maps[:, ibin] * maps_extra[:]).all()


def test_tocfits(tmp_path):
import fitsio
import numpy as np

from heracles.io import TocFits

class TestFits(TocFits):
tag = "test"
columns = {"col1": "I", "col2": "J"}

path = tmp_path / "test.fits"

assert not path.exists()

tocfits = TestFits(path, clobber=True)

assert path.exists()

with fitsio.FITS(path) as fits:
assert len(fits) == 2
toc = fits["TESTTOC"].read()
assert toc.dtype.names == ("EXT", "col1", "col2")
assert len(toc) == 0

assert len(tocfits) == 0
assert list(tocfits) == []
assert tocfits.toc == {}

data12 = np.zeros(5, dtype=[("X", float), ("Y", int)])
data22 = np.ones(5, dtype=[("X", float), ("Y", int)])

tocfits[1, 2] = data12

with fitsio.FITS(path) as fits:
assert len(fits) == 3
toc = fits["TESTTOC"].read()
assert len(toc) == 1
np.testing.assert_array_equal(fits["TEST0"].read(), data12)

assert len(tocfits) == 1
assert list(tocfits) == [(1, 2)]
assert tocfits.toc == {(1, 2): "TEST0"}
np.testing.assert_array_equal(tocfits[1, 2], data12)

tocfits[2, 2] = data22

with fitsio.FITS(path) as fits:
assert len(fits) == 4
toc = fits["TESTTOC"].read()
assert len(toc) == 2
np.testing.assert_array_equal(fits["TEST0"].read(), data12)
np.testing.assert_array_equal(fits["TEST1"].read(), data22)

assert len(tocfits) == 2
assert list(tocfits) == [(1, 2), (2, 2)]
assert tocfits.toc == {(1, 2): "TEST0", (2, 2): "TEST1"}
np.testing.assert_array_equal(tocfits[1, 2], data12)
np.testing.assert_array_equal(tocfits[2, 2], data22)

with pytest.raises(NotImplementedError):
del tocfits[1, 2]

del tocfits

tocfits2 = TestFits(path, clobber=False)

assert len(tocfits2) == 2
assert list(tocfits2) == [(1, 2), (2, 2)]
assert tocfits2.toc == {(1, 2): "TEST0", (2, 2): "TEST1"}
np.testing.assert_array_equal(tocfits2[1, 2], data12)
np.testing.assert_array_equal(tocfits2[2, 2], data22)

0 comments on commit 437848c

Please sign in to comment.