Skip to content

Commit

Permalink
Fix what is an array attr and what is not
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobachetti committed Sep 18, 2023
1 parent 93513f3 commit f2209a9
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 26 deletions.
39 changes: 18 additions & 21 deletions stingray/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class StingrayObject(object):
columns of the table/dataframe, otherwise as metadata.
"""

not_array_attr: list = []

def __init__(cls, *args, **kwargs) -> None:
if not hasattr(cls, "main_array_attr"):
raise RuntimeError(
Expand All @@ -71,8 +73,11 @@ def array_attrs(self) -> list[str]:
for attr in dir(self)
if (
isinstance(getattr(self, attr), Iterable)
and np.shape(getattr(self, attr)) == np.shape(main_attr)
and not attr == self.main_array_attr
and not attr in self.not_array_attr
and not isinstance(getattr(self, attr), str)
and not attr.startswith("_")
and np.shape(getattr(self, attr))[0] == np.shape(main_attr)[0]
)
]

Expand All @@ -92,8 +97,9 @@ def internal_array_attrs(self) -> list[str]:
for attr in dir(self)
if (
isinstance(getattr(self, attr), Iterable)
and np.shape(getattr(self, attr)) == np.shape(main_attr)
and not isinstance(getattr(self, attr), str)
and attr.startswith("_")
and np.shape(getattr(self, attr))[0] == np.shape(main_attr)[0]
)
]

Expand All @@ -103,7 +109,7 @@ def meta_attrs(self) -> list[str]:
By array attributes, we mean the ones with a different size and shape
than ``main_array_attr`` (e.g. ``time`` in ``EventList``)
"""
array_attrs = self.array_attrs()
array_attrs = self.array_attrs() + [self.main_array_attr]
return [
attr
for attr in dir(self)
Expand Down Expand Up @@ -153,7 +159,6 @@ def __eq__(self, other_ts):

def _default_operated_attrs(self):
operated_attrs = [attr for attr in self.array_attrs() if not attr.endswith("_err")]
operated_attrs.remove(self.main_array_attr)
return operated_attrs

def _default_error_attrs(self):
Expand All @@ -177,7 +182,7 @@ def to_astropy_table(self) -> Table:
(``mjdref``, ``gti``, etc.) are saved into the ``meta`` dictionary.
"""
data = {}
array_attrs = self.array_attrs()
array_attrs = self.array_attrs() + [self.main_array_attr]

for attr in array_attrs:
data[attr] = np.asarray(getattr(self, attr))
Expand Down Expand Up @@ -234,7 +239,7 @@ def to_xarray(self) -> Dataset:
from xarray import Dataset

data = {}
array_attrs = self.array_attrs()
array_attrs = self.array_attrs() + [self.main_array_attr]

for attr in array_attrs:
data[attr] = np.asarray(getattr(self, attr))
Expand Down Expand Up @@ -292,7 +297,7 @@ def to_pandas(self) -> DataFrame:
from pandas import DataFrame

data = {}
array_attrs = self.array_attrs()
array_attrs = self.array_attrs() + [self.main_array_attr]

for attr in array_attrs:
data[attr] = np.asarray(getattr(self, attr))
Expand Down Expand Up @@ -492,7 +497,7 @@ def apply_mask(self, mask: npt.ArrayLike, inplace: bool = False, filtered_attrs:
if filtered_attrs is None:
filtered_attrs = all_attrs
if self.main_array_attr not in filtered_attrs:
filtered_attrs.append(self.main_array_attrs)
filtered_attrs.append(self.main_array_attr)

if inplace:
new_ts = self
Expand Down Expand Up @@ -691,14 +696,15 @@ def __getitem__(self, index):
for attr in self.meta_attrs():
setattr(new_ts, attr, copy.deepcopy(getattr(self, attr)))

for attr in self.array_attrs():
for attr in self.array_attrs() + [self.main_array_attr]:
setattr(new_ts, attr, getattr(self, attr)[start:stop:step])

return new_ts


class StingrayTimeseries(StingrayObject):
main_array_attr = "time"
not_array_attr = "gti"

def __init__(
self,
Expand Down Expand Up @@ -730,28 +736,19 @@ def __init__(
self.time = np.asarray(time)
else:
self.time = np.asarray(time, dtype=np.longdouble)
self.ncounts = self.time.size
else:
self.time = None

for kw in other_kw:
setattr(self, kw, other_kw[kw])
for kw in array_attrs:
new_arr = np.asarray(array_attrs[kw])
if self.time.size != new_arr.size:
if self.time.shape[0] != new_arr.shape[0]:
raise ValueError(f"Lengths of time and {kw} must be equal.")
setattr(self, kw, new_arr)

@property
def gti(self):
if self._gti is None:
self._gti = np.asarray([[self.time[0] - 0.5 * self.dt, self.time[-1] + 0.5 * self.dt]])
return self._gti

@gti.setter
def gti(self, value):
value = np.asarray(value) if value is not None else None
self._gti = value
if gti is None and self.time is not None and np.size(self.time) > 0:
self.gti = np.asarray([[self.time[0] - 0.5 * self.dt, self.time[-1] + 0.5 * self.dt]])

def apply_gtis(self, new_gti=None, inplace: bool = True):
"""
Expand Down
4 changes: 2 additions & 2 deletions stingray/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def _get_all_array_attrs(objs):
all_gti_lists = []

for obj in all_objs:
if obj.gti is None and len(obj.time) > 0:
if obj.gti is None and obj.time is not None and len(obj.time) > 0:
obj.gti = assign_value_if_none(
obj.gti,
np.asarray([[obj.time[0] - obj.dt / 2, obj.time[-1] + obj.dt / 2]]),
Expand Down Expand Up @@ -716,7 +716,7 @@ def apply_mask(self, mask, inplace=False):
>>> evt is newev1
True
"""
array_attrs = self.array_attrs()
array_attrs = self.array_attrs() + ["time"]

if inplace:
new_ev = self
Expand Down
8 changes: 5 additions & 3 deletions stingray/lightcurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,11 @@ def mask(self):

@property
def n(self):
if self._n is None:
self._n = self.counts.shape[0]
return self._n
return self.time.shape[0]

@n.setter
def n(self, value):
pass

@property
def meanrate(self):
Expand Down
16 changes: 16 additions & 0 deletions stingray/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,22 @@ def test_apply_mask(self):
assert ts is newts1
assert ts is not newts0

def test_what_is_array_and_what_is_not(self):
"""Test that array_attrs are not confused with other attributes.
In particular, time, gti and panesapa have the same length. Verify that panesapa is considered
an array attribute, but not gti."""
ts = StingrayTimeseries(
[0, 3],
gti=[[0.5, 1.5], [2.5, 3.5]],
array_attrs=dict(panesapa=np.asarray([[41, 25], [98, 3]])),
dt=1,
)
array_attrs = ts.array_attrs()
assert "panesapa" in array_attrs
assert "gti" not in array_attrs
assert "time" not in array_attrs

def test_operations(self):
time = [5, 10, 15]
count1 = [300, 100, 400]
Expand Down

0 comments on commit f2209a9

Please sign in to comment.