diff --git a/stingray/base.py b/stingray/base.py index 7bf8d9347..4947b5d37 100644 --- a/stingray/base.py +++ b/stingray/base.py @@ -1,4 +1,6 @@ """Base classes""" +from __future__ import annotations + from collections.abc import Iterable import pickle import warnings @@ -9,6 +11,16 @@ from astropy.time import Time, TimeDelta from astropy.units import Quantity +from typing import TYPE_CHECKING, Type, TypeVar, Union +if TYPE_CHECKING: + from xarray import Dataset + from pandas import DataFrame + from astropy.timeseries import TimeSeries + from astropy.time import TimeDelta + import numpy.typing as npt + TTime = Union[Time, TimeDelta, Quantity, npt.ArrayLike] + Tso = TypeVar("Tso", bound="StingrayObject") + class StingrayObject(object): """This base class defines some general-purpose utilities. @@ -30,13 +42,13 @@ class StingrayObject(object): columns of the table/dataframe, otherwise as metadata. """ - def __init__(cls, *args, **kwargs): + def __init__(cls, *args, **kwargs) -> None: if not hasattr(cls, "main_array_attr"): raise RuntimeError( "A StingrayObject needs to have the main_array_attr attribute specified" ) - def array_attrs(self): + def array_attrs(self) -> list[str]: """List the names of the array attributes of the Stingray Object. By array attributes, we mean the ones with the same size and shape as @@ -56,7 +68,7 @@ def array_attrs(self): ) ] - def meta_attrs(self): + def meta_attrs(self) -> list[str]: """List the names of the meta attributes of the Stingray Object. By array attributes, we mean the ones with a different size and shape @@ -78,7 +90,7 @@ def meta_attrs(self): ) ] - def get_meta_dict(self): + def get_meta_dict(self) -> dict: """Give a dictionary with all non-None meta attrs of the object.""" meta_attrs = self.meta_attrs() meta_dict = {} @@ -88,7 +100,7 @@ def get_meta_dict(self): meta_dict[key] = val return meta_dict - def to_astropy_table(self): + def to_astropy_table(self) -> Table: """Create an Astropy Table from a ``StingrayObject`` Array attributes (e.g. ``time``, ``pi``, ``energy``, etc. for @@ -108,7 +120,7 @@ def to_astropy_table(self): return ts @classmethod - def from_astropy_table(cls, ts): + def from_astropy_table(cls: Type[Tso], ts: Table) -> Tso: """Create a Stingray Object object from data in an Astropy Table. The table MUST contain at least a column named like the @@ -130,11 +142,11 @@ def from_astropy_table(cls, ts): array_attrs = ts.colnames # Set the main attribute first - mainarray = np.array(ts[cls.main_array_attr]) - setattr(cls, cls.main_array_attr, mainarray) + mainarray = np.array(ts[cls.main_array_attr]) # type: ignore + setattr(cls, cls.main_array_attr, mainarray) # type: ignore for attr in array_attrs: - if attr == cls.main_array_attr: + if attr == cls.main_array_attr: # type: ignore continue setattr(cls, attr.lower(), np.array(ts[attr])) @@ -143,7 +155,7 @@ def from_astropy_table(cls, ts): return cls - def to_xarray(self): + def to_xarray(self) -> Dataset: """Create an ``xarray`` Dataset from a `StingrayObject`. Array attributes (e.g. ``time``, ``pi``, ``energy``, etc. for @@ -165,7 +177,7 @@ def to_xarray(self): return ts @classmethod - def from_xarray(cls, ts): + def from_xarray(cls: Type[Tso], ts: Dataset) -> Tso: """Create a `StingrayObject` from data in an xarray Dataset. The dataset MUST contain at least a column named like the @@ -180,18 +192,18 @@ def from_xarray(cls, ts): """ cls = cls() - if len(ts[cls.main_array_attr]) == 0: + if len(ts[cls.main_array_attr]) == 0: # type: ignore # return an empty object return cls array_attrs = ts.coords # Set the main attribute first - mainarray = np.array(ts[cls.main_array_attr]) - setattr(cls, cls.main_array_attr, mainarray) + mainarray = np.array(ts[cls.main_array_attr]) # type: ignore + setattr(cls, cls.main_array_attr, mainarray) # type: ignore for attr in array_attrs: - if attr == cls.main_array_attr: + if attr == cls.main_array_attr: # type: ignore continue setattr(cls, attr, np.array(ts[attr])) @@ -201,7 +213,7 @@ def from_xarray(cls, ts): return cls - def to_pandas(self): + def to_pandas(self) -> DataFrame: """Create a pandas ``DataFrame`` from a :class:`StingrayObject`. Array attributes (e.g. ``time``, ``pi``, ``energy``, etc. for @@ -223,7 +235,7 @@ def to_pandas(self): return ts @classmethod - def from_pandas(cls, ts): + def from_pandas(cls: Type[Tso], ts: DataFrame) -> Tso: """Create an `StingrayObject` object from data in a pandas DataFrame. The dataframe MUST contain at least a column named like the @@ -246,11 +258,11 @@ def from_pandas(cls, ts): array_attrs = ts.columns # Set the main attribute first - mainarray = np.array(ts[cls.main_array_attr]) - setattr(cls, cls.main_array_attr, mainarray) + mainarray = np.array(ts[cls.main_array_attr]) # type: ignore + setattr(cls, cls.main_array_attr, mainarray) # type: ignore for attr in array_attrs: - if attr == cls.main_array_attr: + if attr == cls.main_array_attr: # type: ignore continue setattr(cls, attr, np.array(ts[attr])) @@ -261,7 +273,7 @@ def from_pandas(cls, ts): return cls @classmethod - def read(cls, filename, fmt=None, format_=None): + def read(cls: Type[Tso], filename: str, fmt: str = None, format_=None) -> Tso: r"""Generic reader for :class`StingrayObject` Currently supported formats are @@ -349,7 +361,7 @@ def read(cls, filename, fmt=None, format_=None): return cls.from_astropy_table(ts) - def write(self, filename, fmt=None, format_=None): + def write(self, filename: str, fmt: str = None, format_=None) -> None: """Generic writer for :class`StingrayObject` Currently supported formats are @@ -405,7 +417,7 @@ def write(self, filename, fmt=None, format_=None): class StingrayTimeseries(StingrayObject): - def to_astropy_timeseries(self): + def to_astropy_timeseries(self) -> TimeSeries: """Save the ``StingrayTimeseries`` to an ``Astropy`` timeseries. Array attributes (time, pi, energy, etc.) are converted @@ -433,8 +445,8 @@ def to_astropy_timeseries(self): if data == {}: data = None - if self.time is not None and np.size(self.time) > 0: - times = TimeDelta(self.time * u.s) + if self.time is not None and np.size(self.time) > 0: # type: ignore + times = TimeDelta(self.time * u.s) # type: ignore ts = TimeSeries(data=data, time=times) else: ts = TimeSeries() @@ -444,7 +456,7 @@ def to_astropy_timeseries(self): return ts @classmethod - def from_astropy_timeseries(cls, ts): + def from_astropy_timeseries(cls, ts: TimeSeries) -> StingrayTimeseries: """Create a `StingrayTimeseries` from data in an Astropy TimeSeries The timeseries has to define at least a column called time, @@ -473,7 +485,7 @@ def from_astropy_timeseries(cls, ts): mjdref = ts.meta["mjdref"] time, mjdref = interpret_times(time, mjdref) - cls.time = np.asarray(time) + cls.time = np.asarray(time) # type: ignore array_attrs = ts.colnames for key, val in ts.meta.items(): @@ -486,7 +498,7 @@ def from_astropy_timeseries(cls, ts): return cls - def change_mjdref(self, new_mjdref): + def change_mjdref(self, new_mjdref: float) -> StingrayTimeseries: """Change the MJD reference time (MJDREF) of the time series The times of the time series will be shifted in order to be referred to @@ -502,13 +514,13 @@ def change_mjdref(self, new_mjdref): new_lc : :class:`StingrayTimeseries` object The new time series, shifted by MJDREF """ - time_shift = (self.mjdref - new_mjdref) * 86400 + time_shift = (self.mjdref - new_mjdref) * 86400 # type: ignore ts = self.shift(time_shift) - ts.mjdref = new_mjdref + ts.mjdref = new_mjdref # type: ignore return ts - def shift(self, time_shift): + def shift(self, time_shift: float) -> StingrayTimeseries: """Shift the time and the GTIs by the same amount Parameters @@ -524,14 +536,14 @@ def shift(self, time_shift): """ ts = copy.deepcopy(self) - ts.time = np.asarray(ts.time) + time_shift + ts.time = np.asarray(ts.time) + time_shift # type: ignore if hasattr(ts, "gti"): - ts.gti = np.asarray(ts.gti) + time_shift + ts.gti = np.asarray(ts.gti) + time_shift # type: ignore return ts -def interpret_times(time, mjdref=0): +def interpret_times(time: TTime, mjdref: float = 0) -> tuple[npt.ArrayLike, float]: """Understand the format of input times, and return seconds from MJDREF Parameters