Skip to content

Commit

Permalink
Try to make write more robust to float128 errors in FITS
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobachetti committed Sep 19, 2023
1 parent 48a1a29 commit 501d375
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 18 deletions.
58 changes: 47 additions & 11 deletions stingray/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,21 +191,14 @@ def to_astropy_table(self, no_longdouble=False) -> Table:

for attr in array_attrs:
vals = np.asarray(getattr(self, attr))
if no_longdouble and isinstance(vals.flat[0], np.longdouble):
vals = vals.astype(float)
if no_longdouble:
vals = reduce_precision_if_extended(vals)
data[attr] = vals

ts = Table(data)
meta_dict = self.get_meta_dict()
if not no_longdouble:
for attr in meta_dict.keys():
vals = getattr(self, attr)
probe = vals
if isinstance(vals, Iterable) and len(vals) > 0 and not isinstance(vals, str):
probe = np.asarray(vals).flat[0]
if isinstance(probe, np.longdouble):
vals = vals.astype(float)
meta_dict[attr] = vals
meta_dict[attr] = reduce_precision_if_extended(vals)

ts.meta.update(meta_dict)

Expand Down Expand Up @@ -509,7 +502,6 @@ def write(self, filename: str, fmt: str = None) -> None:
Table({"a": [np.longdouble(3)]}).write(probe_file, format=fmt, overwrite=True)
CAN_SAVE_LONGD = True
os.unlink(probe_file)
raise ValueError("float128")
except ValueError as e:
if "float128" not in str(e): # pragma: no cover
raise
Expand Down Expand Up @@ -1414,3 +1406,47 @@ def interpret_times(time: TTime, mjdref: float = 0) -> tuple[npt.ArrayLike, floa
pass

raise ValueError(f"Unknown time format: {type(time)}")


def reduce_precision_if_extended(
x, probe_types=["float128", "float96", "float80"], destination=float
):
"""Reduce a number to a standard float if extended precision.
Ignore all non-float types.
Parameters
----------
x : float
The number to be reduced
Returns
-------
x_red : same type of input
The input, only reduce to ``float`` precision if ``np.float128``
Examples
--------
>>> x = 1.0
>>> val = reduce_precision_if_extended(x, probe_types=["float64"])
>>> val is x
True
>>> x = np.asanyarray(1.0).astype(int)
>>> val = reduce_precision_if_extended(x, probe_types=["float64"])
>>> val is x
True
>>> x = np.asanyarray([1.0]).astype(int)
>>> val = reduce_precision_if_extended(x, probe_types=["float64"])
>>> val is x
True
>>> x = np.asanyarray(1.0).astype(np.float64)
>>> reduce_precision_if_extended(x, probe_types=["float64"], destination=np.float32) is x
False
>>> x = np.asanyarray([1.0]).astype(np.float64)
>>> reduce_precision_if_extended(x, probe_types=["float64"], destination=np.float32) is x
False
"""
if any([t in str(np.obj2sctype(x)) for t in probe_types]):
x_ret = x.astype(destination)
return x_ret
return x
24 changes: 17 additions & 7 deletions stingray/lightcurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from astropy.time import TimeDelta, Time
from astropy import units as u

from stingray.base import StingrayTimeseries
from stingray.base import StingrayTimeseries, reduce_precision_if_extended
import stingray.utils as utils
from stingray.exceptions import StingrayError
from stingray.gti import (
Expand Down Expand Up @@ -1596,10 +1596,10 @@ def from_lightkurve(lk, skip_checks=True):
def to_astropy_timeseries(self):
return self._to_astropy_object(kind="timeseries")

def to_astropy_table(self):
return self._to_astropy_object(kind="table")
def to_astropy_table(self, **kwargs):
return self._to_astropy_object(kind="table", **kwargs)

def _to_astropy_object(self, kind="table"):
def _to_astropy_object(self, kind="table", no_longdouble=False):
data = {}

for attr in [
Expand All @@ -1611,15 +1611,22 @@ def _to_astropy_object(self, kind="table"):
"_bin_hi",
]:
if hasattr(self, attr) and getattr(self, attr) is not None:
data[attr.lstrip("_")] = np.asarray(getattr(self, attr))
vals = np.asarray(getattr(self, attr))
if no_longdouble:
vals = reduce_precision_if_extended(vals)
data[attr.lstrip("_")] = vals

time_array = self.time
if no_longdouble:
time_array = reduce_precision_if_extended(time_array)

if kind.lower() == "table":
data["time"] = self.time
data["time"] = time_array
ts = Table(data)
elif kind.lower() == "timeseries":
from astropy.timeseries import TimeSeries

ts = TimeSeries(data=data, time=TimeDelta(self.time * u.s))
ts = TimeSeries(data=data, time=TimeDelta(time_array * u.s))
else: # pragma: no cover
raise ValueError("Invalid kind (accepted: table or timeseries)")

Expand All @@ -1634,6 +1641,9 @@ def _to_astropy_object(self, kind="table"):
"err_dist",
]:
if hasattr(self, attr) and getattr(self, attr) is not None:
vals = getattr(self, attr)
if no_longdouble:
vals = reduce_precision_if_extended(vals)
ts.meta[attr.lstrip("_")] = getattr(self, attr)

return ts
Expand Down
6 changes: 6 additions & 0 deletions stingray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@
HAS_PYFFTW = False


try:
np.float128
HAS_128 = True
except AttributeError:
HAS_128 = False

Check warning on line 46 in stingray/utils.py

View check run for this annotation

Codecov / codecov/patch

stingray/utils.py#L45-L46

Added lines #L45 - L46 were not covered by tests

# If numba is installed, import jit. Otherwise, define an empty decorator with
# the same name.
try:
Expand Down

0 comments on commit 501d375

Please sign in to comment.