Skip to content

Commit

Permalink
Allow the save_all option
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobachetti committed Sep 15, 2023
1 parent e2101c3 commit 770f71b
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 25 deletions.
100 changes: 75 additions & 25 deletions stingray/powerspectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def __init__(self, data=None, norm="frac", gti=None, dt=None, lc=None, skip_chec
if data is None:
data = lc

good_input = True
if not skip_checks:
good_input = data is not None
if good_input and not skip_checks:
good_input = self.initial_checks(
data1=data, data2=data, norm=norm, gti=gti, lc1=lc, lc2=lc, dt=dt
)
Expand All @@ -122,7 +122,7 @@ def __init__(self, data=None, norm="frac", gti=None, dt=None, lc=None, skip_chec
if not good_input:
return self._initialize_empty()

if not data is not None:
if data is not None:
return self._initialize_from_any_input(data, dt=dt, norm=norm)

raise ValueError("No valid data provided!")
Expand Down Expand Up @@ -664,6 +664,7 @@ def _initialize_from_any_input(
norm="frac",
silent=False,
use_common_mean=True,
save_all=False,
):
"""
Initialize the class, trying to understand the input types.
Expand All @@ -682,6 +683,7 @@ def _initialize_from_any_input(
silent=silent,
use_common_mean=use_common_mean,
gti=gti,
save_all=save_all,
)
elif isinstance(data, Lightcurve):
spec = powerspectrum_from_lightcurve(
Expand All @@ -691,6 +693,7 @@ def _initialize_from_any_input(
silent=silent,
use_common_mean=use_common_mean,
gti=gti,
save_all=save_all,
)
spec.lc1 = data
elif isinstance(data, (tuple, list)):
Expand All @@ -706,6 +709,7 @@ def _initialize_from_any_input(
silent=silent,
use_common_mean=use_common_mean,
gti=gti,
save_all=save_all,
)
else: # pragma: no cover
raise TypeError(f"Bad inputs to Powerspectrum: {type(data)}")
Expand Down Expand Up @@ -835,8 +839,8 @@ def __init__(
if data is None:
data = lc

good_input = True
if not skip_checks:
good_input = data is not None
if good_input and not skip_checks:
good_input = self.initial_checks(
data1=data,
data2=data,
Expand Down Expand Up @@ -870,14 +874,6 @@ def __init__(
)
data = list(data)

# The large_data option requires the legacy interface.
if save_all:
warnings.warn(
"The save_all options are"
" only available with the legacy interface"
" (legacy=True)."
)

if data is not None:
return self._initialize_from_any_input(
data,
Expand All @@ -886,6 +882,7 @@ def __init__(
norm=norm,
silent=silent,
use_common_mean=use_common_mean,
save_all=save_all,
)
raise ValueError("No valid input data!")

Expand Down Expand Up @@ -985,13 +982,13 @@ def _make_matrix(self, lc):
The :class:`Lightcurve` object from which to generate the dynamical
power spectrum.
"""
ps_all, _ = AveragedPowerspectrum._make_segment_spectrum(self, lc, self.segment_size)
self.dyn_ps = np.array([ps.power for ps in ps_all]).T
avg = AveragedPowerspectrum(
lc, segment_size=self.segment_size, norm=self.norm, gti=self.gti, save_all=True
)
self.dyn_ps = np.array(avg.cs_all).T

self.freq = ps_all[0].freq
current_gti = lc.gti
if self.gti is not None:
current_gti = cross_two_gtis(self.gti, current_gti)
self.freq = avg.freq
current_gti = avg.gti

start_inds, end_inds = bin_intervals_from_gtis(
current_gti, self.segment_size, lc.time, dt=lc.dt
Expand Down Expand Up @@ -1125,7 +1122,14 @@ def rebin_time(self, dt_new, method="sum"):


def powerspectrum_from_time_array(
times, dt, segment_size=None, gti=None, norm="frac", silent=False, use_common_mean=True
times,
dt,
segment_size=None,
gti=None,
norm="frac",
silent=False,
use_common_mean=True,
save_all=False,
):
"""
Calculate a power spectrum from an array of event times.
Expand Down Expand Up @@ -1160,6 +1164,10 @@ def powerspectrum_from_time_array(
to calculate it on a per-segment basis.
silent : bool, default False
Silence the progress bars.
save_all : bool, default False
Save all intermediate PDSs used for the final average. Use with care.
This is likely to fill up your RAM on medium-sized datasets, and to
slow down the computation when rebinning.
Returns
-------
Expand All @@ -1170,14 +1178,28 @@ def powerspectrum_from_time_array(
# Suppress progress bar for single periodogram
silent = silent or (segment_size is None)
table = avg_pds_from_events(
times, gti, segment_size, dt, norm=norm, use_common_mean=use_common_mean, silent=silent
times,
gti,
segment_size,
dt,
norm=norm,
use_common_mean=use_common_mean,
silent=silent,
return_subcs=save_all,
)

return _create_powerspectrum_from_result_table(table, force_averaged=force_averaged)


def powerspectrum_from_events(
events, dt, segment_size=None, gti=None, norm="frac", silent=False, use_common_mean=True
events,
dt,
segment_size=None,
gti=None,
norm="frac",
silent=False,
use_common_mean=True,
save_all=False,
):
"""
Calculate a power spectrum from an event list.
Expand Down Expand Up @@ -1212,6 +1234,10 @@ def powerspectrum_from_events(
to calculate it on a per-segment basis.
silent : bool, default False
Silence the progress bars.
save_all : bool, default False
Save all intermediate PDSs used for the final average. Use with care.
This is likely to fill up your RAM on medium-sized datasets, and to
slow down the computation when rebinning.
Returns
-------
Expand All @@ -1228,11 +1254,12 @@ def powerspectrum_from_events(
norm=norm,
silent=silent,
use_common_mean=use_common_mean,
save_all=save_all,
)


def powerspectrum_from_lightcurve(
lc, segment_size=None, gti=None, norm="frac", silent=False, use_common_mean=True
lc, segment_size=None, gti=None, norm="frac", silent=False, use_common_mean=True, save_all=False
):
"""
Calculate a power spectrum from a light curve
Expand Down Expand Up @@ -1267,6 +1294,10 @@ def powerspectrum_from_lightcurve(
to calculate it on a per-segment basis.
silent : bool, default False
Silence the progress bars.
save_all : bool, default False
Save all intermediate PDSs used for the final average. Use with care.
This is likely to fill up your RAM on medium-sized datasets, and to
slow down the computation when rebinning.
Returns
-------
Expand All @@ -1292,13 +1323,21 @@ def powerspectrum_from_lightcurve(
silent=silent,
fluxes=lc.counts,
errors=err,
return_subcs=save_all,
)

return _create_powerspectrum_from_result_table(table, force_averaged=force_averaged)


def powerspectrum_from_lc_iterable(
iter_lc, dt, segment_size=None, gti=None, norm="frac", silent=False, use_common_mean=True
iter_lc,
dt,
segment_size=None,
gti=None,
norm="frac",
silent=False,
use_common_mean=True,
save_all=False,
):
"""
Calculate an average power spectrum from an iterable collection of light
Expand Down Expand Up @@ -1335,6 +1374,8 @@ def powerspectrum_from_lc_iterable(
to calculate it on a per-segment basis.
silent : bool, default False
Silence the progress bars.
save_all : bool, default False
Save all intermediate PDSs used for the final average. Use with care.
Returns
-------
Expand Down Expand Up @@ -1373,7 +1414,12 @@ def iterate_lc_counts(iter_lc):
)

table = avg_pds_from_iterable(
iterate_lc_counts(iter_lc), dt, norm=norm, use_common_mean=use_common_mean, silent=silent
iterate_lc_counts(iter_lc),
dt,
norm=norm,
use_common_mean=use_common_mean,
silent=silent,
return_subcs=save_all,
)
return _create_powerspectrum_from_result_table(table, force_averaged=force_averaged)

Expand Down Expand Up @@ -1415,6 +1461,10 @@ def _create_powerspectrum_from_result_table(table, force_averaged=False):
for attr, val in table.meta.items():
setattr(cs, attr, val)

if "subcs" in table.meta:
cs.cs_all = np.array(table.meta["subcs"])
cs.unnorm_cs_all = np.array(table.meta["unnorm_subcs"])

cs.err_dist = "poisson"
if hasattr(cs, "variance") and cs.variance is not None:
cs.err_dist = "gauss"
Expand Down
5 changes: 5 additions & 0 deletions stingray/tests/test_powerspectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def setup_class(cls):

cls.leahy_pds_sng = Powerspectrum(cls.lc, dt=cls.dt, norm="leahy")

def test_save_all(self):
cs = AveragedPowerspectrum(self.lc, dt=self.dt, segment_size=1, save_all=True)
assert hasattr(cs, "cs_all")

@pytest.mark.parametrize("norm", ["leahy", "frac", "abs", "none"])
def test_common_mean_gives_comparable_scatter(self, norm):
acs = AveragedPowerspectrum(
Expand Down Expand Up @@ -1045,6 +1049,7 @@ def test_rebin_time_default_method(self):
rebin_time = np.array([2.0, 6.0, 10.0])
rebin_dps = np.array([[0.7962963, 1.16402116, 0.28571429]])
dps = DynamicalPowerspectrum(self.lc_test, segment_size=segment_size)
print(dps.dyn_ps)
new_dps = dps.rebin_time(dt_new=dt_new)
assert np.allclose(new_dps.time, rebin_time)
assert np.allclose(new_dps.dyn_ps, rebin_dps)
Expand Down

0 comments on commit 770f71b

Please sign in to comment.