diff --git a/heracles/io.py b/heracles/io.py index 2f88a40..d5a6e3d 100644 --- a/heracles/io.py +++ b/heracles/io.py @@ -61,6 +61,37 @@ def _read_metadata(hdu): return md +def _as_twopoint(arr, name): + """convert two-point data (i.e. one L column) to structured array""" + + arr = np.asanyarray(arr) + + # get the data into structured array if not already + if arr.dtype.names is None: + n, *dims = arr.shape + data = arr + + dt = np.dtype( + [ + ("L", float), + (name, arr.dtype.str, dims) if dims else (name, arr.dtype.str), + ("LMIN", float), + ("LMAX", float), + ("W", float), + ], + metadata=dict(arr.dtype.metadata or {}), + ) + + arr = np.empty(n, dt) + arr["L"] = np.arange(n) + arr[name] = data + arr["LMIN"] = arr["L"] + arr["LMAX"] = arr["L"] + 1 + arr["W"] = 1 + + return arr + + def read_mask(mask_name, nside=None, field=0, extra_mask_name=None): """read visibility map from a HEALPix map file""" mask = hp.read_map(mask_name, field=field) @@ -397,27 +428,10 @@ def write_cls(filename, cls, *, clobber=False, workdir=".", include=None, exclud ext = f"CL{cln}" cln += 1 - # get the data into the binned format if not already - if cl.dtype.names is None: - dt = np.dtype( - [ - ("L", float), - ("CL", float), - ("LMIN", float), - ("LMAX", float), - ("W", float), - ], - metadata=dict(cl.dtype.metadata or {}), - ) - cl_ = cl - cl = np.empty(len(cl_), dt) - cl["L"] = np.arange(len(cl_)) - cl["CL"] = cl_ - cl["LMIN"] = cl["L"] - cl["LMAX"] = cl["L"] + 1 - cl["W"] = 1 - - # write the data column + # get the data into structured format if not already + cl = _as_twopoint(cl, "CL") + + # write the data columns fits.write_table(cl, extname=ext) # write the metadata @@ -519,17 +533,11 @@ def write_mms(filename, mms, *, clobber=False, workdir=".", include=None, exclud ext = f"MM{mmn}" mmn += 1 - # write the mixing matrix as an image - fits.write_image(mm, extname=ext) + # get the data into structured format if not already + mm = _as_twopoint(mm, "MM") - # write the WCS - fits[ext].write_key("WCSAXES", 2) - fits[ext].write_key("CNAME1", "L_1") - fits[ext].write_key("CNAME2", "L_2") - fits[ext].write_key("CTYPE1", " ") - fits[ext].write_key("CTYPE2", " ") - fits[ext].write_key("CUNIT1", " ") - fits[ext].write_key("CUNIT2", " ") + # write the data columns + fits.write_table(mm, extname=ext) # write the metadata _write_metadata(fits[ext], mm.dtype.metadata) diff --git a/tests/test_io.py b/tests/test_io.py index d98d1c4..7939455 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -276,7 +276,11 @@ def test_write_read_mms(tmp_path): assert mms_.keys() == mms.keys() for key in mms: - np.testing.assert_array_equal(mms_[key], mms[key]) + assert key in mms_ + mm = mms_[key] + assert mm.dtype.names == ("L", "MM", "LMIN", "LMAX", "W") + np.testing.assert_array_equal(mm["L"], np.arange(len(mms[key]))) + np.testing.assert_array_equal(mm["MM"], mms[key]) def test_write_read_cov(mock_cls, tmp_path):