Skip to content

Commit

Permalink
API(io): output mixing matrices as structured arrays (#41)
Browse files Browse the repository at this point in the history
Changes `write_mms()` to output mixing matrices as structured two-point
data. This preserves any binning information if `binned_mms()` is used.

Closes: #39
  • Loading branch information
ntessore authored Sep 19, 2023
1 parent bc1fa51 commit 5501836
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 32 deletions.
70 changes: 39 additions & 31 deletions heracles/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,37 @@ def _read_metadata(hdu):
return md


def _as_twopoint(arr, name):
"""convert two-point data (i.e. one L column) to structured array"""

arr = np.asanyarray(arr)

# get the data into structured array if not already
if arr.dtype.names is None:
n, *dims = arr.shape
data = arr

dt = np.dtype(
[
("L", float),
(name, arr.dtype.str, dims) if dims else (name, arr.dtype.str),
("LMIN", float),
("LMAX", float),
("W", float),
],
metadata=dict(arr.dtype.metadata or {}),
)

arr = np.empty(n, dt)
arr["L"] = np.arange(n)
arr[name] = data
arr["LMIN"] = arr["L"]
arr["LMAX"] = arr["L"] + 1
arr["W"] = 1

return arr


def read_mask(mask_name, nside=None, field=0, extra_mask_name=None):
"""read visibility map from a HEALPix map file"""
mask = hp.read_map(mask_name, field=field)
Expand Down Expand Up @@ -397,27 +428,10 @@ def write_cls(filename, cls, *, clobber=False, workdir=".", include=None, exclud
ext = f"CL{cln}"
cln += 1

# get the data into the binned format if not already
if cl.dtype.names is None:
dt = np.dtype(
[
("L", float),
("CL", float),
("LMIN", float),
("LMAX", float),
("W", float),
],
metadata=dict(cl.dtype.metadata or {}),
)
cl_ = cl
cl = np.empty(len(cl_), dt)
cl["L"] = np.arange(len(cl_))
cl["CL"] = cl_
cl["LMIN"] = cl["L"]
cl["LMAX"] = cl["L"] + 1
cl["W"] = 1

# write the data column
# get the data into structured format if not already
cl = _as_twopoint(cl, "CL")

# write the data columns
fits.write_table(cl, extname=ext)

# write the metadata
Expand Down Expand Up @@ -519,17 +533,11 @@ def write_mms(filename, mms, *, clobber=False, workdir=".", include=None, exclud
ext = f"MM{mmn}"
mmn += 1

# write the mixing matrix as an image
fits.write_image(mm, extname=ext)
# get the data into structured format if not already
mm = _as_twopoint(mm, "MM")

# write the WCS
fits[ext].write_key("WCSAXES", 2)
fits[ext].write_key("CNAME1", "L_1")
fits[ext].write_key("CNAME2", "L_2")
fits[ext].write_key("CTYPE1", " ")
fits[ext].write_key("CTYPE2", " ")
fits[ext].write_key("CUNIT1", " ")
fits[ext].write_key("CUNIT2", " ")
# write the data columns
fits.write_table(mm, extname=ext)

# write the metadata
_write_metadata(fits[ext], mm.dtype.metadata)
Expand Down
6 changes: 5 additions & 1 deletion tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,11 @@ def test_write_read_mms(tmp_path):

assert mms_.keys() == mms.keys()
for key in mms:
np.testing.assert_array_equal(mms_[key], mms[key])
assert key in mms_
mm = mms_[key]
assert mm.dtype.names == ("L", "MM", "LMIN", "LMAX", "W")
np.testing.assert_array_equal(mm["L"], np.arange(len(mms[key])))
np.testing.assert_array_equal(mm["MM"], mms[key])


def test_write_read_cov(mock_cls, tmp_path):
Expand Down

0 comments on commit 5501836

Please sign in to comment.