Skip to content

Commit

Permalink
Try to make write more robust to float128 errors in FITS
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobachetti committed Sep 19, 2023
1 parent 081d53e commit e700ebb
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 20 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ filterwarnings =
ignore:.*is a deprecated alias for:DeprecationWarning
ignore:.*HIERARCH card will be created.*:
ignore:.*FigureCanvasAgg is non-interactive.*:UserWarning
ignore:.*Converting to lower precision.*:UserWarning

;addopts = --disable-warnings

Expand Down
129 changes: 116 additions & 13 deletions stingray/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pickle
import warnings
import copy
import os

import numpy as np
from astropy.table import Table
Expand All @@ -23,6 +24,52 @@
TTime = Union[Time, TimeDelta, Quantity, npt.ArrayLike]
Tso = TypeVar("Tso", bound="StingrayObject")

try:
np.float128
HAS_128 = True
except AttributeError:
HAS_128 = False

Check warning on line 31 in stingray/base.py

View check run for this annotation

Codecov / codecov/patch

stingray/base.py#L30-L31

Added lines #L30 - L31 were not covered by tests


def _can_save_longdouble(probe_file: str, fmt: str) -> bool:
"""Check if a given file format can save tables with longdoubles."""
if not HAS_128:
# There are no known issues with saving longdoubles where numpy.float128 is not defined
return True

Check warning on line 38 in stingray/base.py

View check run for this annotation

Codecov / codecov/patch

stingray/base.py#L38

Added line #L38 was not covered by tests

try:
Table({"a": np.arange(0, 3, 1.212314).astype(np.float128)}).write(
probe_file, format=fmt, overwrite=True
)
yes_it_can = True
os.unlink(probe_file)
except ValueError as e:
if "float128" not in str(e): # pragma: no cover
raise
warnings.warn(
f"{fmt} output does not allow saving metadata at maximum precision. "
"Converting to lower precision"
)
yes_it_can = False
return yes_it_can


def _can_serialize_meta(probe_file: str, fmt: str) -> bool:
try:
Table({"a": [3]}).write(probe_file, overwrite=True, format=fmt, serialize_meta=True)

os.unlink(probe_file)
yes_it_can = True
except TypeError as e:
if "serialize_meta" not in str(e):
raise

Check warning on line 65 in stingray/base.py

View check run for this annotation

Codecov / codecov/patch

stingray/base.py#L65

Added line #L65 was not covered by tests
warnings.warn(
f"{fmt} output does not serialize the metadata at the moment. "
"Some attributes will be lost."
)
yes_it_can = False
return yes_it_can


def sqsum(array1, array2):
"""Return the square root of the sum of the squares of two arrays."""
Expand Down Expand Up @@ -146,12 +193,13 @@ def __eq__(self, other_ts):
raise ValueError(f"{type(self)} can only be compared with a {type(self)} Object")

for attr in self.meta_attrs():
if isinstance(getattr(self, attr), np.ndarray):
if not np.array_equal(getattr(self, attr), getattr(other_ts, attr)):
if np.isscalar(getattr(self, attr)):
if not getattr(self, attr) == getattr(other_ts, attr):
return False
else:
if not getattr(self, attr) == getattr(other_ts, attr):
if not np.array_equal(getattr(self, attr), getattr(other_ts, attr)):
return False

for attr in self.array_attrs():
if not np.array_equal(getattr(self, attr), getattr(other_ts, attr)):
return False
Expand All @@ -178,7 +226,7 @@ def get_meta_dict(self) -> dict:
meta_dict[key] = val
return meta_dict

def to_astropy_table(self) -> Table:
def to_astropy_table(self, no_longdouble=False) -> Table:
"""Create an Astropy Table from a ``StingrayObject``
Array attributes (e.g. ``time``, ``pi``, ``energy``, etc. for
Expand All @@ -189,11 +237,18 @@ def to_astropy_table(self) -> Table:
array_attrs = self.array_attrs() + [self.main_array_attr]

for attr in array_attrs:
data[attr] = np.asarray(getattr(self, attr))
vals = np.asarray(getattr(self, attr))
if no_longdouble:
vals = reduce_precision_if_extended(vals)
data[attr] = vals

ts = Table(data)
meta_dict = self.get_meta_dict()
for attr in meta_dict.keys():
if no_longdouble:
meta_dict[attr] = reduce_precision_if_extended(meta_dict[attr])

ts.meta.update(self.get_meta_dict())
ts.meta.update(meta_dict)

return ts

Expand Down Expand Up @@ -489,21 +544,25 @@ def write(self, filename: str, fmt: str = None) -> None:
elif fmt.lower() == "ascii":
fmt = "ascii.ecsv"

ts = self.to_astropy_table()
probe_file = "probe.bu.bu." + filename[-7:]

CAN_SAVE_LONGD = _can_save_longdouble(probe_file, fmt)
CAN_SERIALIZE_META = _can_serialize_meta(probe_file, fmt)

to_be_saved = self

ts = to_be_saved.to_astropy_table(no_longdouble=not CAN_SAVE_LONGD)

if fmt is None or "ascii" in fmt:
for col in ts.colnames:
if np.iscomplex(ts[col].flatten()[0]):
ts[f"{col}.real"] = ts[col].real
ts[f"{col}.imag"] = ts[col].imag
ts.remove_column(col)

try:
if CAN_SERIALIZE_META:
ts.write(filename, format=fmt, overwrite=True, serialize_meta=True)
except TypeError as e:
warnings.warn(
f"{fmt} output does not serialize the metadata at the moment. "
"Some attributes will be lost."
)
else:
ts.write(filename, format=fmt, overwrite=True)

def apply_mask(self, mask: npt.ArrayLike, inplace: bool = False, filtered_attrs: list = None):
Expand Down Expand Up @@ -1371,3 +1430,47 @@ def interpret_times(time: TTime, mjdref: float = 0) -> tuple[npt.ArrayLike, floa
pass

raise ValueError(f"Unknown time format: {type(time)}")


def reduce_precision_if_extended(
x, probe_types=["float128", "float96", "float80", "longdouble"], destination=float
):
"""Reduce a number to a standard float if extended precision.
Ignore all non-float types.
Parameters
----------
x : float
The number to be reduced
Returns
-------
x_red : same type of input
The input, only reduce to ``float`` precision if ``np.float128``
Examples
--------
>>> x = 1.0
>>> val = reduce_precision_if_extended(x, probe_types=["float64"])
>>> val is x
True
>>> x = np.asanyarray(1.0).astype(int)
>>> val = reduce_precision_if_extended(x, probe_types=["float64"])
>>> val is x
True
>>> x = np.asanyarray([1.0]).astype(int)
>>> val = reduce_precision_if_extended(x, probe_types=["float64"])
>>> val is x
True
>>> x = np.asanyarray(1.0).astype(np.float64)
>>> reduce_precision_if_extended(x, probe_types=["float64"], destination=np.float32) is x
False
>>> x = np.asanyarray([1.0]).astype(np.float64)
>>> reduce_precision_if_extended(x, probe_types=["float64"], destination=np.float32) is x
False
"""
if any([t in str(np.obj2sctype(x)) for t in probe_types]):
x_ret = x.astype(destination)
return x_ret
return x
24 changes: 17 additions & 7 deletions stingray/lightcurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from astropy.time import TimeDelta, Time
from astropy import units as u

from stingray.base import StingrayTimeseries
from stingray.base import StingrayTimeseries, reduce_precision_if_extended
import stingray.utils as utils
from stingray.exceptions import StingrayError
from stingray.gti import (
Expand Down Expand Up @@ -1596,10 +1596,10 @@ def from_lightkurve(lk, skip_checks=True):
def to_astropy_timeseries(self):
return self._to_astropy_object(kind="timeseries")

def to_astropy_table(self):
return self._to_astropy_object(kind="table")
def to_astropy_table(self, **kwargs):
return self._to_astropy_object(kind="table", **kwargs)

def _to_astropy_object(self, kind="table"):
def _to_astropy_object(self, kind="table", no_longdouble=False):
data = {}

for attr in [
Expand All @@ -1611,15 +1611,22 @@ def _to_astropy_object(self, kind="table"):
"_bin_hi",
]:
if hasattr(self, attr) and getattr(self, attr) is not None:
data[attr.lstrip("_")] = np.asarray(getattr(self, attr))
vals = np.asarray(getattr(self, attr))
if no_longdouble:
vals = reduce_precision_if_extended(vals)
data[attr.lstrip("_")] = vals

time_array = self.time
if no_longdouble:
time_array = reduce_precision_if_extended(time_array)

if kind.lower() == "table":
data["time"] = self.time
data["time"] = time_array
ts = Table(data)
elif kind.lower() == "timeseries":
from astropy.timeseries import TimeSeries

ts = TimeSeries(data=data, time=TimeDelta(self.time * u.s))
ts = TimeSeries(data=data, time=TimeDelta(time_array * u.s))
else: # pragma: no cover
raise ValueError("Invalid kind (accepted: table or timeseries)")

Expand All @@ -1634,6 +1641,9 @@ def _to_astropy_object(self, kind="table"):
"err_dist",
]:
if hasattr(self, attr) and getattr(self, attr) is not None:
vals = getattr(self, attr)
if no_longdouble:
vals = reduce_precision_if_extended(vals)
ts.meta[attr.lstrip("_")] = getattr(self, attr)

return ts
Expand Down

0 comments on commit e700ebb

Please sign in to comment.