Skip to content

Commit

Permalink
Merge pull request #649 from parkma99/main
Browse files Browse the repository at this point in the history
Add type hints to base.py #544
  • Loading branch information
matteobachetti authored Mar 16, 2022
2 parents 26d5446 + 692949a commit a6dec24
Showing 1 changed file with 46 additions and 34 deletions.
80 changes: 46 additions & 34 deletions stingray/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Base classes"""
from __future__ import annotations

from collections.abc import Iterable
import pickle
import warnings
Expand All @@ -9,6 +11,16 @@
from astropy.time import Time, TimeDelta
from astropy.units import Quantity

from typing import TYPE_CHECKING, Type, TypeVar, Union
if TYPE_CHECKING:
from xarray import Dataset
from pandas import DataFrame
from astropy.timeseries import TimeSeries
from astropy.time import TimeDelta
import numpy.typing as npt
TTime = Union[Time, TimeDelta, Quantity, npt.ArrayLike]
Tso = TypeVar("Tso", bound="StingrayObject")


class StingrayObject(object):
"""This base class defines some general-purpose utilities.
Expand All @@ -30,13 +42,13 @@ class StingrayObject(object):
columns of the table/dataframe, otherwise as metadata.
"""

def __init__(cls, *args, **kwargs):
def __init__(cls, *args, **kwargs) -> None:
if not hasattr(cls, "main_array_attr"):
raise RuntimeError(
"A StingrayObject needs to have the main_array_attr attribute specified"
)

def array_attrs(self):
def array_attrs(self) -> list[str]:
"""List the names of the array attributes of the Stingray Object.
By array attributes, we mean the ones with the same size and shape as
Expand All @@ -56,7 +68,7 @@ def array_attrs(self):
)
]

def meta_attrs(self):
def meta_attrs(self) -> list[str]:
"""List the names of the meta attributes of the Stingray Object.
By array attributes, we mean the ones with a different size and shape
Expand All @@ -78,7 +90,7 @@ def meta_attrs(self):
)
]

def get_meta_dict(self):
def get_meta_dict(self) -> dict:
"""Give a dictionary with all non-None meta attrs of the object."""
meta_attrs = self.meta_attrs()
meta_dict = {}
Expand All @@ -88,7 +100,7 @@ def get_meta_dict(self):
meta_dict[key] = val
return meta_dict

def to_astropy_table(self):
def to_astropy_table(self) -> Table:
"""Create an Astropy Table from a ``StingrayObject``
Array attributes (e.g. ``time``, ``pi``, ``energy``, etc. for
Expand All @@ -108,7 +120,7 @@ def to_astropy_table(self):
return ts

@classmethod
def from_astropy_table(cls, ts):
def from_astropy_table(cls: Type[Tso], ts: Table) -> Tso:
"""Create a Stingray Object object from data in an Astropy Table.
The table MUST contain at least a column named like the
Expand All @@ -130,11 +142,11 @@ def from_astropy_table(cls, ts):
array_attrs = ts.colnames

# Set the main attribute first
mainarray = np.array(ts[cls.main_array_attr])
setattr(cls, cls.main_array_attr, mainarray)
mainarray = np.array(ts[cls.main_array_attr]) # type: ignore
setattr(cls, cls.main_array_attr, mainarray) # type: ignore

for attr in array_attrs:
if attr == cls.main_array_attr:
if attr == cls.main_array_attr: # type: ignore
continue
setattr(cls, attr.lower(), np.array(ts[attr]))

Expand All @@ -143,7 +155,7 @@ def from_astropy_table(cls, ts):

return cls

def to_xarray(self):
def to_xarray(self) -> Dataset:
"""Create an ``xarray`` Dataset from a `StingrayObject`.
Array attributes (e.g. ``time``, ``pi``, ``energy``, etc. for
Expand All @@ -165,7 +177,7 @@ def to_xarray(self):
return ts

@classmethod
def from_xarray(cls, ts):
def from_xarray(cls: Type[Tso], ts: Dataset) -> Tso:
"""Create a `StingrayObject` from data in an xarray Dataset.
The dataset MUST contain at least a column named like the
Expand All @@ -180,18 +192,18 @@ def from_xarray(cls, ts):
"""
cls = cls()

if len(ts[cls.main_array_attr]) == 0:
if len(ts[cls.main_array_attr]) == 0: # type: ignore
# return an empty object
return cls

array_attrs = ts.coords

# Set the main attribute first
mainarray = np.array(ts[cls.main_array_attr])
setattr(cls, cls.main_array_attr, mainarray)
mainarray = np.array(ts[cls.main_array_attr]) # type: ignore
setattr(cls, cls.main_array_attr, mainarray) # type: ignore

for attr in array_attrs:
if attr == cls.main_array_attr:
if attr == cls.main_array_attr: # type: ignore
continue
setattr(cls, attr, np.array(ts[attr]))

Expand All @@ -201,7 +213,7 @@ def from_xarray(cls, ts):

return cls

def to_pandas(self):
def to_pandas(self) -> DataFrame:
"""Create a pandas ``DataFrame`` from a :class:`StingrayObject`.
Array attributes (e.g. ``time``, ``pi``, ``energy``, etc. for
Expand All @@ -223,7 +235,7 @@ def to_pandas(self):
return ts

@classmethod
def from_pandas(cls, ts):
def from_pandas(cls: Type[Tso], ts: DataFrame) -> Tso:
"""Create an `StingrayObject` object from data in a pandas DataFrame.
The dataframe MUST contain at least a column named like the
Expand All @@ -246,11 +258,11 @@ def from_pandas(cls, ts):
array_attrs = ts.columns

# Set the main attribute first
mainarray = np.array(ts[cls.main_array_attr])
setattr(cls, cls.main_array_attr, mainarray)
mainarray = np.array(ts[cls.main_array_attr]) # type: ignore
setattr(cls, cls.main_array_attr, mainarray) # type: ignore

for attr in array_attrs:
if attr == cls.main_array_attr:
if attr == cls.main_array_attr: # type: ignore
continue
setattr(cls, attr, np.array(ts[attr]))

Expand All @@ -261,7 +273,7 @@ def from_pandas(cls, ts):
return cls

@classmethod
def read(cls, filename, fmt=None, format_=None):
def read(cls: Type[Tso], filename: str, fmt: str = None, format_=None) -> Tso:
r"""Generic reader for :class`StingrayObject`
Currently supported formats are
Expand Down Expand Up @@ -349,7 +361,7 @@ def read(cls, filename, fmt=None, format_=None):

return cls.from_astropy_table(ts)

def write(self, filename, fmt=None, format_=None):
def write(self, filename: str, fmt: str = None, format_=None) -> None:
"""Generic writer for :class`StingrayObject`
Currently supported formats are
Expand Down Expand Up @@ -405,7 +417,7 @@ def write(self, filename, fmt=None, format_=None):


class StingrayTimeseries(StingrayObject):
def to_astropy_timeseries(self):
def to_astropy_timeseries(self) -> TimeSeries:
"""Save the ``StingrayTimeseries`` to an ``Astropy`` timeseries.
Array attributes (time, pi, energy, etc.) are converted
Expand Down Expand Up @@ -433,8 +445,8 @@ def to_astropy_timeseries(self):
if data == {}:
data = None

if self.time is not None and np.size(self.time) > 0:
times = TimeDelta(self.time * u.s)
if self.time is not None and np.size(self.time) > 0: # type: ignore
times = TimeDelta(self.time * u.s) # type: ignore
ts = TimeSeries(data=data, time=times)
else:
ts = TimeSeries()
Expand All @@ -444,7 +456,7 @@ def to_astropy_timeseries(self):
return ts

@classmethod
def from_astropy_timeseries(cls, ts):
def from_astropy_timeseries(cls, ts: TimeSeries) -> StingrayTimeseries:
"""Create a `StingrayTimeseries` from data in an Astropy TimeSeries
The timeseries has to define at least a column called time,
Expand Down Expand Up @@ -473,7 +485,7 @@ def from_astropy_timeseries(cls, ts):
mjdref = ts.meta["mjdref"]

time, mjdref = interpret_times(time, mjdref)
cls.time = np.asarray(time)
cls.time = np.asarray(time) # type: ignore

array_attrs = ts.colnames
for key, val in ts.meta.items():
Expand All @@ -486,7 +498,7 @@ def from_astropy_timeseries(cls, ts):

return cls

def change_mjdref(self, new_mjdref):
def change_mjdref(self, new_mjdref: float) -> StingrayTimeseries:
"""Change the MJD reference time (MJDREF) of the time series
The times of the time series will be shifted in order to be referred to
Expand All @@ -502,13 +514,13 @@ def change_mjdref(self, new_mjdref):
new_lc : :class:`StingrayTimeseries` object
The new time series, shifted by MJDREF
"""
time_shift = (self.mjdref - new_mjdref) * 86400
time_shift = (self.mjdref - new_mjdref) * 86400 # type: ignore

ts = self.shift(time_shift)
ts.mjdref = new_mjdref
ts.mjdref = new_mjdref # type: ignore
return ts

def shift(self, time_shift):
def shift(self, time_shift: float) -> StingrayTimeseries:
"""Shift the time and the GTIs by the same amount
Parameters
Expand All @@ -524,14 +536,14 @@ def shift(self, time_shift):
"""
ts = copy.deepcopy(self)
ts.time = np.asarray(ts.time) + time_shift
ts.time = np.asarray(ts.time) + time_shift # type: ignore
if hasattr(ts, "gti"):
ts.gti = np.asarray(ts.gti) + time_shift
ts.gti = np.asarray(ts.gti) + time_shift # type: ignore

return ts


def interpret_times(time, mjdref=0):
def interpret_times(time: TTime, mjdref: float = 0) -> tuple[npt.ArrayLike, float]:
"""Understand the format of input times, and return seconds from MJDREF
Parameters
Expand Down

0 comments on commit a6dec24

Please sign in to comment.