Skip to content

Commit

Permalink
Merge pull request #824 from StingraySoftware/searchsorted_on_memmap
Browse files Browse the repository at this point in the history
Avoid copy in memory of memmaps
  • Loading branch information
matteobachetti authored May 7, 2024
2 parents 342f38e + 0e96476 commit 379ed3a
Show file tree
Hide file tree
Showing 28 changed files with 229 additions and 190 deletions.
1 change: 1 addition & 0 deletions docs/changes/824.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Substitute np.asarray with np.asanyarray everywhere, to avoid copying memory maps into memory if possible
55 changes: 29 additions & 26 deletions stingray/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __init__(cls, *args, **kwargs) -> None:
def main_array_length(self):
if getattr(self, self.main_array_attr, None) is None:
return 0
return np.shape(np.asarray(getattr(self, self.main_array_attr)))[0]
return np.shape(np.asanyarray(getattr(self, self.main_array_attr)))[0]

def data_attributes(self) -> list[str]:
"""Clean up the list of attributes, only giving out those pointing to data.
Expand All @@ -130,7 +130,7 @@ def data_attributes(self) -> list[str]:
and not isinstance(getattr(self.__class__, attr, None), property)
and not callable(value := getattr(self, attr))
and not isinstance(value, StingrayObject)
and not np.asarray(value).dtype == "O"
and not np.asanyarray(value).dtype == "O"
)
]

Expand Down Expand Up @@ -368,7 +368,7 @@ def to_astropy_table(self, no_longdouble=False) -> Table:
array_attrs = self.array_attrs() + [self.main_array_attr] + self.internal_array_attrs()

for attr in array_attrs:
vals = np.asarray(getattr(self, attr))
vals = np.asanyarray(getattr(self, attr))
if no_longdouble:
vals = reduce_precision_if_extended(vals)
data[attr] = vals
Expand Down Expand Up @@ -455,7 +455,7 @@ def to_xarray(self) -> Dataset:
array_attrs = self.array_attrs() + [self.main_array_attr] + self.internal_array_attrs()

for attr in array_attrs:
new_data = np.asarray(getattr(self, attr))
new_data = np.asanyarray(getattr(self, attr))
ndim = len(np.shape(new_data))
if ndim > 1:
new_data = ([attr + f"_dim{i}" for i in range(ndim)], new_data)
Expand Down Expand Up @@ -520,7 +520,7 @@ def to_pandas(self) -> DataFrame:
array_attrs = self.array_attrs() + [self.main_array_attr] + self.internal_array_attrs()

for attr in array_attrs:
values = np.asarray(getattr(self, attr))
values = np.asanyarray(getattr(self, attr))
ndim = len(np.shape(values))
if ndim > 1:
local_data = make_nd_into_arrays(values, attr)
Expand Down Expand Up @@ -758,21 +758,21 @@ def apply_mask(self, mask: npt.ArrayLike, inplace: bool = False, filtered_attrs:
setattr(
new_ts,
"_" + self.main_array_attr,
copy.deepcopy(np.asarray(getattr(self, self.main_array_attr))[mask]),
copy.deepcopy(np.asanyarray(getattr(self, self.main_array_attr))[mask]),
)
else:
setattr(
new_ts,
self.main_array_attr,
copy.deepcopy(np.asarray(getattr(self, self.main_array_attr))[mask]),
copy.deepcopy(np.asanyarray(getattr(self, self.main_array_attr))[mask]),
)

for attr in all_attrs:
if attr not in filtered_attrs:
# Eliminate all unfiltered attributes
setattr(new_ts, attr, None)
else:
setattr(new_ts, attr, copy.deepcopy(np.asarray(getattr(self, attr))[mask]))
setattr(new_ts, attr, copy.deepcopy(np.asanyarray(getattr(self, attr))[mask]))
return new_ts

def _operation_with_other_obj(
Expand Down Expand Up @@ -1030,7 +1030,7 @@ def __neg__(self):

ts_new = copy.deepcopy(self)
for attr in self._default_operated_attrs():
setattr(ts_new, attr, -np.asarray(getattr(self, attr)))
setattr(ts_new, attr, -np.asanyarray(getattr(self, attr)))

return ts_new

Expand Down Expand Up @@ -1215,7 +1215,7 @@ def __init__(
for kw in other_kw:
setattr(self, kw, other_kw[kw])
for kw in array_attrs:
new_arr = np.asarray(array_attrs[kw])
new_arr = np.asanyarray(array_attrs[kw])
if self.time.shape[0] != new_arr.shape[0]:
raise ValueError(f"Lengths of time and {kw} must be equal.")
setattr(self, kw, new_arr)
Expand Down Expand Up @@ -1246,15 +1246,15 @@ def gti(self):
dt1 = self.dt[-1]
else:
dt0 = dt1 = self.dt
self._gti = np.asarray([[self._time[0] - dt0 / 2, self._time[-1] + dt1 / 2]])
self._gti = np.asanyarray([[self._time[0] - dt0 / 2, self._time[-1] + dt1 / 2]])
return self._gti

@gti.setter
def gti(self, value):
if value is None:
self._gti = None
return
value = np.asarray(value)
value = np.asanyarray(value)
self._gti = value
self._mask = None

Expand All @@ -1278,15 +1278,15 @@ def _set_times(self, time, high_precision=False):
return
time, _ = interpret_times(time, self.mjdref)
if not high_precision:
self._time = np.asarray(time)
self._time = np.asanyarray(time)
else:
self._time = np.asarray(time, dtype=np.longdouble)
self._time = np.asanyarray(time, dtype=np.longdouble)

def __str__(self) -> str:
"""Return a string representation of the object."""
return self.pretty_print(
attrs_to_apply=["gti", "time", "tstart", "tseg", "tstop"],
func_to_apply=lambda x: (np.asarray(x) / 86400 + self.mjdref, "MJD"),
func_to_apply=lambda x: (np.asanyarray(x) / 86400 + self.mjdref, "MJD"),
attrs_to_discard=["_mask", "header"],
)

Expand Down Expand Up @@ -1318,7 +1318,7 @@ def _validate_and_format(self, value, attr_name, compare_to_attr):
"""
if value is None:
return None
value = np.asarray(value)
value = np.asanyarray(value)
if len(value.shape) < 1:
raise ValueError(f"{attr_name} array must be at least 1D")
# If the attribute we compare it with is the same and it is currently None, we assign it
Expand Down Expand Up @@ -1446,7 +1446,7 @@ def to_astropy_timeseries(self) -> TimeSeries:
for attr in array_attrs:
if attr == "time":
continue
data[attr] = np.asarray(getattr(self, attr))
data[attr] = np.asanyarray(getattr(self, attr))

if data == {}:
data = None
Expand Down Expand Up @@ -1489,7 +1489,7 @@ def from_astropy_timeseries(cls, ts: TimeSeries) -> StingrayTimeseries:

new_cls = cls()
time, mjdref = interpret_times(time, mjdref)
new_cls.time = np.asarray(time) # type: ignore
new_cls.time = np.asanyarray(time) # type: ignore

array_attrs = ts.colnames
for key, val in ts.meta.items():
Expand All @@ -1498,7 +1498,7 @@ def from_astropy_timeseries(cls, ts: TimeSeries) -> StingrayTimeseries:
for attr in array_attrs:
if attr == "time":
continue
setattr(new_cls, attr, np.asarray(ts[attr]))
setattr(new_cls, attr, np.asanyarray(ts[attr]))

return new_cls

Expand Down Expand Up @@ -1553,11 +1553,11 @@ def shift(self, time_shift: float, inplace=False) -> StingrayTimeseries:
ts = self
else:
ts = copy.deepcopy(self)
ts.time = np.asarray(ts.time) + time_shift # type: ignore
ts.time = np.asanyarray(ts.time) + time_shift # type: ignore
# Pay attention here: if the GTIs are created dynamically while we
# access the property,
if ts._gti is not None:
ts._gti = np.asarray(ts._gti) + time_shift # type: ignore
ts._gti = np.asanyarray(ts._gti) + time_shift # type: ignore

return ts

Expand Down Expand Up @@ -1718,7 +1718,9 @@ def __getitem__(self, index):
delta_gti_start = new_ts.dt[0] * 0.5
delta_gti_stop = new_ts.dt[-1] * 0.5

new_gti = np.asarray([[new_ts.time[0] - delta_gti_start, new_ts.time[-1] + delta_gti_stop]])
new_gti = np.asanyarray(
[[new_ts.time[0] - delta_gti_start, new_ts.time[-1] + delta_gti_stop]]
)
if step > 1 and delta_gti_start > 0:
new_gt1 = np.array(list(zip(new_ts.time - new_ts.dt / 2, new_ts.time + new_ts.dt / 2)))
new_gti = cross_two_gtis(new_gti, new_gt1)
Expand Down Expand Up @@ -1797,7 +1799,8 @@ def _truncate_by_index(self, start, stop):
dtstop = self.dt[-1]

gti = cross_two_gtis(
self.gti, np.asarray([[new_ts.time[0] - 0.5 * dtstart, new_ts.time[-1] + 0.5 * dtstop]])
self.gti,
np.asanyarray([[new_ts.time[0] - 0.5 * dtstart, new_ts.time[-1] + 0.5 * dtstop]]),
)

new_ts.gti = gti
Expand Down Expand Up @@ -2108,7 +2111,7 @@ def rebin(self, dt_new=None, f=None, method="sum"):
elif f is not None:
dt_new = f * self.dt

if np.any(dt_new < np.asarray(self.dt)):
if np.any(dt_new < np.asanyarray(self.dt)):
raise ValueError("The new time resolution must be larger than the old one!")

gti_new = []
Expand Down Expand Up @@ -2150,7 +2153,7 @@ def rebin(self, dt_new=None, f=None, method="sum"):

if len(gti_new) == 0:
raise ValueError("No valid GTIs after rebin.")
new_ts.gti = np.asarray(gti_new)
new_ts.gti = np.asanyarray(gti_new)

for attr in self.meta_attrs():
if attr == "dt":
Expand Down Expand Up @@ -2654,7 +2657,7 @@ def analyze_segments(self, func, segment_size, fraction_step=1, **kwargs):
res = np.nan
else:
lc_filt = self[st:sp]
lc_filt.gti = np.asarray([[tst, tsp]])
lc_filt.gti = np.asanyarray([[tst, tsp]])

res = func(lc_filt, **kwargs)
results.append(res)
Expand Down
12 changes: 6 additions & 6 deletions stingray/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,12 @@ def __init__(
StingrayTimeseries.__init__(
self,
time=time,
energy=None if energy is None else np.asarray(energy),
energy=None if energy is None else np.asanyarray(energy),
mjdref=mjdref,
dt=dt,
notes=notes,
gti=np.asarray(gti) if gti is not None else None,
pi=None if pi is None else np.asarray(pi),
gti=np.asanyarray(gti) if gti is not None else None,
pi=None if pi is None else np.asanyarray(pi),
high_precision=high_precision,
mission=mission,
instr=instr,
Expand Down Expand Up @@ -367,7 +367,7 @@ def to_lc_iter(self, dt, segment_size=None):
self.time[idx_st : idx_end + 1],
dt,
tstart=st,
gti=np.asarray([[st, end]]),
gti=np.asanyarray([[st, end]]),
tseg=tseg,
mjdref=self.mjdref,
use_hist=True,
Expand Down Expand Up @@ -474,8 +474,8 @@ def simulate_energies(self, spectrum, use_spline=False):
return

if isinstance(spectrum, list) or isinstance(spectrum, np.ndarray):
energy = np.asarray(spectrum)[0]
fluxes = np.asarray(spectrum)[1]
energy = np.asanyarray(spectrum)[0]
fluxes = np.asanyarray(spectrum)[1]

if not isinstance(energy, np.ndarray):
raise IndexError("Spectrum must be a 2-d array or list")
Expand Down
8 changes: 4 additions & 4 deletions stingray/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def integrate_power_in_frequency_range(
if power_err is None:
power_err_to_integrate = powers_to_integrate / np.sqrt(m)
else:
power_err_to_integrate = np.asarray(power_err)[frequency_mask]
power_err_to_integrate = np.asanyarray(power_err)[frequency_mask]

power_integrated = np.sum((powers_to_integrate - poisson_power) * dfs_to_integrate)
power_err_integrated = np.sqrt(np.sum((power_err_to_integrate * dfs_to_integrate) ** 2))
Expand Down Expand Up @@ -1250,7 +1250,7 @@ def get_average_ctrate(times, gti, segment_size, counts=None):
Examples
--------
>>> times = np.sort(np.random.uniform(0, 1000, 1000))
>>> gti = np.asarray([[0, 1000]])
>>> gti = np.asanyarray([[0, 1000]])
>>> counts, _ = np.histogram(times, bins=np.linspace(0, 1000, 11))
>>> bin_times = np.arange(50, 1000, 100)
>>> assert get_average_ctrate(bin_times, gti, 1000, counts=counts) == 1.0
Expand Down Expand Up @@ -1326,7 +1326,7 @@ def get_flux_iterable_from_segments(
dt = np.median(np.diff(times[:100]))

if binned:
fluxes = np.asarray(fluxes)
fluxes = np.asanyarray(fluxes)
if np.iscomplexobj(fluxes):
cast_kind = complex

Expand Down Expand Up @@ -2399,7 +2399,7 @@ def lsft_slow(
An array of Fourier transformed data.
"""
y_ = y - np.mean(y)
freqs = np.asarray(freqs[np.asarray(freqs) >= 0])
freqs = np.asanyarray(freqs[np.asanyarray(freqs) >= 0])

ft_real = np.zeros_like(freqs)
ft_imag = np.zeros_like(freqs)
Expand Down
Loading

0 comments on commit 379ed3a

Please sign in to comment.