From c9d97fce9ef7895cfe8811dfacd832e29b05360a Mon Sep 17 00:00:00 2001 From: parkma99 Date: Tue, 8 Mar 2022 22:41:38 +0800 Subject: [PATCH 1/5] Add type hints to base.py --- stingray/base.py | 44 +++++++++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/stingray/base.py b/stingray/base.py index 7bf8d9347..23743bc47 100644 --- a/stingray/base.py +++ b/stingray/base.py @@ -1,4 +1,6 @@ """Base classes""" +from __future__ import annotations + from collections.abc import Iterable import pickle import warnings @@ -9,6 +11,14 @@ from astropy.time import Time, TimeDelta from astropy.units import Quantity +from typing import TYPE_CHECKING, Union +if TYPE_CHECKING: + from xarray import Dataset + from pandas import DataFrame + from astropy.timeseries import TimeSeries + from astropy.time import TimeDelta + TTime = Union[Time,TimeDelta,Quantity,np.array] + class StingrayObject(object): """This base class defines some general-purpose utilities. @@ -30,13 +40,13 @@ class StingrayObject(object): columns of the table/dataframe, otherwise as metadata. """ - def __init__(cls, *args, **kwargs): + def __init__(cls, *args, **kwargs) -> None: if not hasattr(cls, "main_array_attr"): raise RuntimeError( "A StingrayObject needs to have the main_array_attr attribute specified" ) - def array_attrs(self): + def array_attrs(self) -> list[str]: """List the names of the array attributes of the Stingray Object. By array attributes, we mean the ones with the same size and shape as @@ -56,7 +66,7 @@ def array_attrs(self): ) ] - def meta_attrs(self): + def meta_attrs(self) -> list[str]: """List the names of the meta attributes of the Stingray Object. By array attributes, we mean the ones with a different size and shape @@ -78,7 +88,7 @@ def meta_attrs(self): ) ] - def get_meta_dict(self): + def get_meta_dict(self) -> dict: """Give a dictionary with all non-None meta attrs of the object.""" meta_attrs = self.meta_attrs() meta_dict = {} @@ -88,7 +98,7 @@ def get_meta_dict(self): meta_dict[key] = val return meta_dict - def to_astropy_table(self): + def to_astropy_table(self) -> Table: """Create an Astropy Table from a ``StingrayObject`` Array attributes (e.g. ``time``, ``pi``, ``energy``, etc. for @@ -108,7 +118,7 @@ def to_astropy_table(self): return ts @classmethod - def from_astropy_table(cls, ts): + def from_astropy_table(cls, ts: Table) -> StingrayObject: """Create a Stingray Object object from data in an Astropy Table. The table MUST contain at least a column named like the @@ -143,7 +153,7 @@ def from_astropy_table(cls, ts): return cls - def to_xarray(self): + def to_xarray(self) -> Dataset: """Create an ``xarray`` Dataset from a `StingrayObject`. Array attributes (e.g. ``time``, ``pi``, ``energy``, etc. for @@ -165,7 +175,7 @@ def to_xarray(self): return ts @classmethod - def from_xarray(cls, ts): + def from_xarray(cls, ts: Dataset) -> StingrayObject: """Create a `StingrayObject` from data in an xarray Dataset. The dataset MUST contain at least a column named like the @@ -201,7 +211,7 @@ def from_xarray(cls, ts): return cls - def to_pandas(self): + def to_pandas(self) -> DataFrame: """Create a pandas ``DataFrame`` from a :class:`StingrayObject`. Array attributes (e.g. ``time``, ``pi``, ``energy``, etc. for @@ -223,7 +233,7 @@ def to_pandas(self): return ts @classmethod - def from_pandas(cls, ts): + def from_pandas(cls, ts: DataFrame) -> StingrayObject: """Create an `StingrayObject` object from data in a pandas DataFrame. The dataframe MUST contain at least a column named like the @@ -261,7 +271,7 @@ def from_pandas(cls, ts): return cls @classmethod - def read(cls, filename, fmt=None, format_=None): + def read(cls, filename: str, fmt: str=None, format_=None) -> StingrayObject: r"""Generic reader for :class`StingrayObject` Currently supported formats are @@ -349,7 +359,7 @@ def read(cls, filename, fmt=None, format_=None): return cls.from_astropy_table(ts) - def write(self, filename, fmt=None, format_=None): + def write(self, filename: str, fmt: str=None, format_=None) -> None: """Generic writer for :class`StingrayObject` Currently supported formats are @@ -405,7 +415,7 @@ def write(self, filename, fmt=None, format_=None): class StingrayTimeseries(StingrayObject): - def to_astropy_timeseries(self): + def to_astropy_timeseries(self) -> StingrayTimeseries: """Save the ``StingrayTimeseries`` to an ``Astropy`` timeseries. Array attributes (time, pi, energy, etc.) are converted @@ -444,7 +454,7 @@ def to_astropy_timeseries(self): return ts @classmethod - def from_astropy_timeseries(cls, ts): + def from_astropy_timeseries(cls, ts: TimeSeries) -> StingrayTimeseries: """Create a `StingrayTimeseries` from data in an Astropy TimeSeries The timeseries has to define at least a column called time, @@ -486,7 +496,7 @@ def from_astropy_timeseries(cls, ts): return cls - def change_mjdref(self, new_mjdref): + def change_mjdref(self, new_mjdref: float) -> StingrayTimeseries: """Change the MJD reference time (MJDREF) of the time series The times of the time series will be shifted in order to be referred to @@ -508,7 +518,7 @@ def change_mjdref(self, new_mjdref): ts.mjdref = new_mjdref return ts - def shift(self, time_shift): + def shift(self, time_shift: float) -> StingrayTimeseries: """Shift the time and the GTIs by the same amount Parameters @@ -531,7 +541,7 @@ def shift(self, time_shift): return ts -def interpret_times(time, mjdref=0): +def interpret_times(time: TTime, mjdref: float=0) -> tuple[np.array,float]: """Understand the format of input times, and return seconds from MJDREF Parameters From 51d9ecc056ddaebc6b66b88e7d537e752f78c678 Mon Sep 17 00:00:00 2001 From: parkma99 Date: Tue, 8 Mar 2022 23:03:50 +0800 Subject: [PATCH 2/5] PEP8 check --- stingray/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/stingray/base.py b/stingray/base.py index 23743bc47..2bc6e6808 100644 --- a/stingray/base.py +++ b/stingray/base.py @@ -17,7 +17,7 @@ from pandas import DataFrame from astropy.timeseries import TimeSeries from astropy.time import TimeDelta - TTime = Union[Time,TimeDelta,Quantity,np.array] + TTime = Union[Time, TimeDelta, Quantity, np.array] class StingrayObject(object): @@ -271,7 +271,7 @@ def from_pandas(cls, ts: DataFrame) -> StingrayObject: return cls @classmethod - def read(cls, filename: str, fmt: str=None, format_=None) -> StingrayObject: + def read(cls, filename: str, fmt: str = None, format_=None) -> StingrayObject: r"""Generic reader for :class`StingrayObject` Currently supported formats are @@ -359,7 +359,7 @@ def read(cls, filename: str, fmt: str=None, format_=None) -> StingrayObject: return cls.from_astropy_table(ts) - def write(self, filename: str, fmt: str=None, format_=None) -> None: + def write(self, filename: str, fmt: str = None, format_=None) -> None: """Generic writer for :class`StingrayObject` Currently supported formats are @@ -541,7 +541,7 @@ def shift(self, time_shift: float) -> StingrayTimeseries: return ts -def interpret_times(time: TTime, mjdref: float=0) -> tuple[np.array,float]: +def interpret_times(time: TTime, mjdref: float = 0) -> tuple[np.array, float]: """Understand the format of input times, and return seconds from MJDREF Parameters From 608af2adef256655d907bfc195071a81e5bfd6ef Mon Sep 17 00:00:00 2001 From: parkma99 Date: Wed, 9 Mar 2022 19:15:15 +0800 Subject: [PATCH 3/5] fix some type error --- stingray/base.py | 52 +++++++++++++++++++++++++----------------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/stingray/base.py b/stingray/base.py index 2bc6e6808..53ea5e7be 100644 --- a/stingray/base.py +++ b/stingray/base.py @@ -11,13 +11,15 @@ from astropy.time import Time, TimeDelta from astropy.units import Quantity -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Type, TypeVar, Union if TYPE_CHECKING: from xarray import Dataset from pandas import DataFrame from astropy.timeseries import TimeSeries from astropy.time import TimeDelta - TTime = Union[Time, TimeDelta, Quantity, np.array] + import numpy.typing as npt + TTime = Union[Time, TimeDelta, Quantity, npt.ArrayLike] + Tso = TypeVar("Tso", bound = "StingrayObject") class StingrayObject(object): @@ -118,7 +120,7 @@ def to_astropy_table(self) -> Table: return ts @classmethod - def from_astropy_table(cls, ts: Table) -> StingrayObject: + def from_astropy_table(cls: Type[Tso], ts: Table) -> Tso: """Create a Stingray Object object from data in an Astropy Table. The table MUST contain at least a column named like the @@ -140,11 +142,11 @@ def from_astropy_table(cls, ts: Table) -> StingrayObject: array_attrs = ts.colnames # Set the main attribute first - mainarray = np.array(ts[cls.main_array_attr]) - setattr(cls, cls.main_array_attr, mainarray) + mainarray = np.array(ts[cls.main_array_attr]) # type: ignore + setattr(cls, cls.main_array_attr, mainarray) # type: ignore for attr in array_attrs: - if attr == cls.main_array_attr: + if attr == cls.main_array_attr: # type: ignore continue setattr(cls, attr.lower(), np.array(ts[attr])) @@ -175,7 +177,7 @@ def to_xarray(self) -> Dataset: return ts @classmethod - def from_xarray(cls, ts: Dataset) -> StingrayObject: + def from_xarray(cls: Type[Tso], ts: Dataset) -> Tso: """Create a `StingrayObject` from data in an xarray Dataset. The dataset MUST contain at least a column named like the @@ -190,18 +192,18 @@ def from_xarray(cls, ts: Dataset) -> StingrayObject: """ cls = cls() - if len(ts[cls.main_array_attr]) == 0: + if len(ts[cls.main_array_attr]) == 0: # type: ignore # return an empty object return cls array_attrs = ts.coords # Set the main attribute first - mainarray = np.array(ts[cls.main_array_attr]) - setattr(cls, cls.main_array_attr, mainarray) + mainarray = np.array(ts[cls.main_array_attr]) # type: ignore + setattr(cls, cls.main_array_attr, mainarray) # type: ignore for attr in array_attrs: - if attr == cls.main_array_attr: + if attr == cls.main_array_attr: # type: ignore continue setattr(cls, attr, np.array(ts[attr])) @@ -233,7 +235,7 @@ def to_pandas(self) -> DataFrame: return ts @classmethod - def from_pandas(cls, ts: DataFrame) -> StingrayObject: + def from_pandas(cls: Type[Tso], ts: DataFrame) -> Tso: """Create an `StingrayObject` object from data in a pandas DataFrame. The dataframe MUST contain at least a column named like the @@ -256,11 +258,11 @@ def from_pandas(cls, ts: DataFrame) -> StingrayObject: array_attrs = ts.columns # Set the main attribute first - mainarray = np.array(ts[cls.main_array_attr]) - setattr(cls, cls.main_array_attr, mainarray) + mainarray = np.array(ts[cls.main_array_attr]) # type: ignore + setattr(cls, cls.main_array_attr, mainarray) # type: ignore for attr in array_attrs: - if attr == cls.main_array_attr: + if attr == cls.main_array_attr: # type: ignore continue setattr(cls, attr, np.array(ts[attr])) @@ -271,7 +273,7 @@ def from_pandas(cls, ts: DataFrame) -> StingrayObject: return cls @classmethod - def read(cls, filename: str, fmt: str = None, format_=None) -> StingrayObject: + def read(cls: Type[Tso], filename: str, fmt: str = None, format_=None) -> Tso: r"""Generic reader for :class`StingrayObject` Currently supported formats are @@ -415,7 +417,7 @@ def write(self, filename: str, fmt: str = None, format_=None) -> None: class StingrayTimeseries(StingrayObject): - def to_astropy_timeseries(self) -> StingrayTimeseries: + def to_astropy_timeseries(self) -> TimeSeries: """Save the ``StingrayTimeseries`` to an ``Astropy`` timeseries. Array attributes (time, pi, energy, etc.) are converted @@ -443,8 +445,8 @@ def to_astropy_timeseries(self) -> StingrayTimeseries: if data == {}: data = None - if self.time is not None and np.size(self.time) > 0: - times = TimeDelta(self.time * u.s) + if self.time is not None and np.size(self.time) > 0: # type: ignore + times = TimeDelta(self.time * u.s) # type: ignore ts = TimeSeries(data=data, time=times) else: ts = TimeSeries() @@ -483,7 +485,7 @@ def from_astropy_timeseries(cls, ts: TimeSeries) -> StingrayTimeseries: mjdref = ts.meta["mjdref"] time, mjdref = interpret_times(time, mjdref) - cls.time = np.asarray(time) + cls.time = np.asarray(time) # type: ignore array_attrs = ts.colnames for key, val in ts.meta.items(): @@ -512,10 +514,10 @@ def change_mjdref(self, new_mjdref: float) -> StingrayTimeseries: new_lc : :class:`StingrayTimeseries` object The new time series, shifted by MJDREF """ - time_shift = (self.mjdref - new_mjdref) * 86400 + time_shift = (self.mjdref - new_mjdref) * 86400 # type: ignore ts = self.shift(time_shift) - ts.mjdref = new_mjdref + ts.mjdref = new_mjdref # type: ignore return ts def shift(self, time_shift: float) -> StingrayTimeseries: @@ -534,14 +536,14 @@ def shift(self, time_shift: float) -> StingrayTimeseries: """ ts = copy.deepcopy(self) - ts.time = np.asarray(ts.time) + time_shift + ts.time = np.asarray(ts.time) + time_shift # type: ignore if hasattr(ts, "gti"): - ts.gti = np.asarray(ts.gti) + time_shift + ts.gti = np.asarray(ts.gti) + time_shift # type: ignore return ts -def interpret_times(time: TTime, mjdref: float = 0) -> tuple[np.array, float]: +def interpret_times(time: TTime, mjdref: float = 0) -> tuple[npt.ArrayLike, float]: """Understand the format of input times, and return seconds from MJDREF Parameters From 212fd2b44e170de41e300ab34aca83cfcbed7fba Mon Sep 17 00:00:00 2001 From: parkma99 Date: Wed, 9 Mar 2022 19:31:33 +0800 Subject: [PATCH 4/5] fix pep8 --- stingray/base.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/stingray/base.py b/stingray/base.py index 53ea5e7be..9476b5fec 100644 --- a/stingray/base.py +++ b/stingray/base.py @@ -19,7 +19,7 @@ from astropy.time import TimeDelta import numpy.typing as npt TTime = Union[Time, TimeDelta, Quantity, npt.ArrayLike] - Tso = TypeVar("Tso", bound = "StingrayObject") + Tso=TypeVar("Tso", bound = "StingrayObject") class StingrayObject(object): @@ -142,11 +142,11 @@ def from_astropy_table(cls: Type[Tso], ts: Table) -> Tso: array_attrs = ts.colnames # Set the main attribute first - mainarray = np.array(ts[cls.main_array_attr]) # type: ignore - setattr(cls, cls.main_array_attr, mainarray) # type: ignore + mainarray = np.array(ts[cls.main_array_attr]) # type: ignore + setattr(cls, cls.main_array_attr, mainarray) # type: ignore for attr in array_attrs: - if attr == cls.main_array_attr: # type: ignore + if attr == cls.main_array_attr: # type: ignore continue setattr(cls, attr.lower(), np.array(ts[attr])) @@ -192,18 +192,18 @@ def from_xarray(cls: Type[Tso], ts: Dataset) -> Tso: """ cls = cls() - if len(ts[cls.main_array_attr]) == 0: # type: ignore + if len(ts[cls.main_array_attr]) == 0: # type: ignore # return an empty object return cls array_attrs = ts.coords # Set the main attribute first - mainarray = np.array(ts[cls.main_array_attr]) # type: ignore - setattr(cls, cls.main_array_attr, mainarray) # type: ignore + mainarray = np.array(ts[cls.main_array_attr]) # type: ignore + setattr(cls, cls.main_array_attr, mainarray) # type: ignore for attr in array_attrs: - if attr == cls.main_array_attr: # type: ignore + if attr == cls.main_array_attr: # type: ignore continue setattr(cls, attr, np.array(ts[attr])) @@ -258,11 +258,11 @@ def from_pandas(cls: Type[Tso], ts: DataFrame) -> Tso: array_attrs = ts.columns # Set the main attribute first - mainarray = np.array(ts[cls.main_array_attr]) # type: ignore - setattr(cls, cls.main_array_attr, mainarray) # type: ignore + mainarray = np.array(ts[cls.main_array_attr]) # type: ignore + setattr(cls, cls.main_array_attr, mainarray) # type: ignore for attr in array_attrs: - if attr == cls.main_array_attr: # type: ignore + if attr == cls.main_array_attr: # type: ignore continue setattr(cls, attr, np.array(ts[attr])) @@ -445,8 +445,8 @@ def to_astropy_timeseries(self) -> TimeSeries: if data == {}: data = None - if self.time is not None and np.size(self.time) > 0: # type: ignore - times = TimeDelta(self.time * u.s) # type: ignore + if self.time is not None and np.size(self.time) > 0: # type: ignore + times = TimeDelta(self.time * u.s) # type: ignore ts = TimeSeries(data=data, time=times) else: ts = TimeSeries() @@ -485,7 +485,7 @@ def from_astropy_timeseries(cls, ts: TimeSeries) -> StingrayTimeseries: mjdref = ts.meta["mjdref"] time, mjdref = interpret_times(time, mjdref) - cls.time = np.asarray(time) # type: ignore + cls.time = np.asarray(time) # type: ignore array_attrs = ts.colnames for key, val in ts.meta.items(): @@ -514,10 +514,10 @@ def change_mjdref(self, new_mjdref: float) -> StingrayTimeseries: new_lc : :class:`StingrayTimeseries` object The new time series, shifted by MJDREF """ - time_shift = (self.mjdref - new_mjdref) * 86400 # type: ignore + time_shift = (self.mjdref - new_mjdref) * 86400 # type: ignore ts = self.shift(time_shift) - ts.mjdref = new_mjdref # type: ignore + ts.mjdref = new_mjdref # type: ignore return ts def shift(self, time_shift: float) -> StingrayTimeseries: @@ -536,9 +536,9 @@ def shift(self, time_shift: float) -> StingrayTimeseries: """ ts = copy.deepcopy(self) - ts.time = np.asarray(ts.time) + time_shift # type: ignore + ts.time = np.asarray(ts.time) + time_shift # type: ignore if hasattr(ts, "gti"): - ts.gti = np.asarray(ts.gti) + time_shift # type: ignore + ts.gti = np.asarray(ts.gti) + time_shift # type: ignore return ts From 692949ae7b5146380f89486c0902c8ddf76d3297 Mon Sep 17 00:00:00 2001 From: parkma99 Date: Wed, 9 Mar 2022 19:33:34 +0800 Subject: [PATCH 5/5] fix pep8 --- stingray/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stingray/base.py b/stingray/base.py index 9476b5fec..4947b5d37 100644 --- a/stingray/base.py +++ b/stingray/base.py @@ -19,7 +19,7 @@ from astropy.time import TimeDelta import numpy.typing as npt TTime = Union[Time, TimeDelta, Quantity, npt.ArrayLike] - Tso=TypeVar("Tso", bound = "StingrayObject") + Tso = TypeVar("Tso", bound="StingrayObject") class StingrayObject(object):