diff --git a/stingray/base.py b/stingray/base.py index ce0b5c04e..5c6035579 100644 --- a/stingray/base.py +++ b/stingray/base.py @@ -191,21 +191,14 @@ def to_astropy_table(self, no_longdouble=False) -> Table: for attr in array_attrs: vals = np.asarray(getattr(self, attr)) - if no_longdouble and isinstance(vals.flat[0], np.longdouble): - vals = vals.astype(float) + if no_longdouble: + vals = reduce_precision_if_extended(vals) data[attr] = vals ts = Table(data) meta_dict = self.get_meta_dict() if not no_longdouble: - for attr in meta_dict.keys(): - vals = getattr(self, attr) - probe = vals - if isinstance(vals, Iterable) and len(vals) > 0 and not isinstance(vals, str): - probe = np.asarray(vals).flat[0] - if isinstance(probe, np.longdouble): - vals = vals.astype(float) - meta_dict[attr] = vals + meta_dict[attr] = reduce_precision_if_extended(vals) ts.meta.update(meta_dict) @@ -509,7 +502,6 @@ def write(self, filename: str, fmt: str = None) -> None: Table({"a": [np.longdouble(3)]}).write(probe_file, format=fmt, overwrite=True) CAN_SAVE_LONGD = True os.unlink(probe_file) - raise ValueError("float128") except ValueError as e: if "float128" not in str(e): # pragma: no cover raise @@ -1414,3 +1406,47 @@ def interpret_times(time: TTime, mjdref: float = 0) -> tuple[npt.ArrayLike, floa pass raise ValueError(f"Unknown time format: {type(time)}") + + +def reduce_precision_if_extended( + x, probe_types=["float128", "float96", "float80"], destination=float +): + """Reduce a number to a standard float if extended precision. + + Ignore all non-float types. + + Parameters + ---------- + x : float + The number to be reduced + + Returns + ------- + x_red : same type of input + The input, only reduce to ``float`` precision if ``np.float128`` + + Examples + -------- + >>> x = 1.0 + >>> val = reduce_precision_if_extended(x, probe_types=["float64"]) + >>> val is x + True + >>> x = np.asanyarray(1.0).astype(int) + >>> val = reduce_precision_if_extended(x, probe_types=["float64"]) + >>> val is x + True + >>> x = np.asanyarray([1.0]).astype(int) + >>> val = reduce_precision_if_extended(x, probe_types=["float64"]) + >>> val is x + True + >>> x = np.asanyarray(1.0).astype(np.float64) + >>> reduce_precision_if_extended(x, probe_types=["float64"], destination=np.float32) is x + False + >>> x = np.asanyarray([1.0]).astype(np.float64) + >>> reduce_precision_if_extended(x, probe_types=["float64"], destination=np.float32) is x + False + """ + if any([t in str(np.obj2sctype(x)) for t in probe_types]): + x_ret = x.astype(destination) + return x_ret + return x diff --git a/stingray/lightcurve.py b/stingray/lightcurve.py index 089addd7d..be4a70d82 100644 --- a/stingray/lightcurve.py +++ b/stingray/lightcurve.py @@ -17,7 +17,7 @@ from astropy.time import TimeDelta, Time from astropy import units as u -from stingray.base import StingrayTimeseries +from stingray.base import StingrayTimeseries, reduce_precision_if_extended import stingray.utils as utils from stingray.exceptions import StingrayError from stingray.gti import ( @@ -1596,10 +1596,10 @@ def from_lightkurve(lk, skip_checks=True): def to_astropy_timeseries(self): return self._to_astropy_object(kind="timeseries") - def to_astropy_table(self): - return self._to_astropy_object(kind="table") + def to_astropy_table(self, **kwargs): + return self._to_astropy_object(kind="table", **kwargs) - def _to_astropy_object(self, kind="table"): + def _to_astropy_object(self, kind="table", no_longdouble=False): data = {} for attr in [ @@ -1611,15 +1611,22 @@ def _to_astropy_object(self, kind="table"): "_bin_hi", ]: if hasattr(self, attr) and getattr(self, attr) is not None: - data[attr.lstrip("_")] = np.asarray(getattr(self, attr)) + vals = np.asarray(getattr(self, attr)) + if no_longdouble: + vals = reduce_precision_if_extended(vals) + data[attr.lstrip("_")] = vals + + time_array = self.time + if no_longdouble: + time_array = reduce_precision_if_extended(time_array) if kind.lower() == "table": - data["time"] = self.time + data["time"] = time_array ts = Table(data) elif kind.lower() == "timeseries": from astropy.timeseries import TimeSeries - ts = TimeSeries(data=data, time=TimeDelta(self.time * u.s)) + ts = TimeSeries(data=data, time=TimeDelta(time_array * u.s)) else: # pragma: no cover raise ValueError("Invalid kind (accepted: table or timeseries)") @@ -1634,6 +1641,9 @@ def _to_astropy_object(self, kind="table"): "err_dist", ]: if hasattr(self, attr) and getattr(self, attr) is not None: + vals = getattr(self, attr) + if no_longdouble: + vals = reduce_precision_if_extended(vals) ts.meta[attr.lstrip("_")] = getattr(self, attr) return ts diff --git a/stingray/utils.py b/stingray/utils.py index 329eaf5b6..e71eac3a5 100644 --- a/stingray/utils.py +++ b/stingray/utils.py @@ -39,6 +39,12 @@ HAS_PYFFTW = False +try: + np.float128 + HAS_128 = True +except AttributeError: + HAS_128 = False + # If numba is installed, import jit. Otherwise, define an empty decorator with # the same name. try: