Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobachetti committed Dec 3, 2024
1 parent e8c459b commit 6828129
Show file tree
Hide file tree
Showing 2 changed files with 218 additions and 0 deletions.
35 changes: 35 additions & 0 deletions stingray/crossspectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -1225,6 +1225,7 @@ def from_time_array(
silent=False,
fullspec=False,
use_common_mean=True,
channels_overlap=False,
):
"""Calculate AveragedCrossspectrum from two arrays of event times.
Expand Down Expand Up @@ -1266,6 +1267,11 @@ def from_time_array(
power_type : str, default 'all'
If 'all', give complex powers. If 'abs', the absolute value; if 'real',
the real part
channels_overlap: bool, default False
If True, the reference band contains all the photons of the subject band.
This happens, for example, when calculating covariance spectra (see, e.g.,
the docs for ``CovarianceSpectrum``). This will generally be false in a
single cross spectrum.
"""

return crossspectrum_from_time_array(
Expand All @@ -1279,6 +1285,7 @@ def from_time_array(
silent=silent,
fullspec=fullspec,
use_common_mean=use_common_mean,
channels_overlap=channels_overlap,
)

@staticmethod
Expand All @@ -1293,6 +1300,7 @@ def from_events(
fullspec=False,
use_common_mean=True,
gti=None,
channels_overlap=False,
):
"""Calculate AveragedCrossspectrum from two event lists
Expand Down Expand Up @@ -1334,6 +1342,11 @@ def from_events(
input object GTIs! If you're getting errors regarding your GTIs,
don't use this and only give GTIs to the input objects before
making the cross spectrum.
channels_overlap: bool, default False
If True, the reference band contains all the photons of the subject band.
This happens, for example, when calculating covariance spectra (see, e.g.,
the docs for ``CovarianceSpectrum``). This will generally be false in a
single cross spectrum.
"""

return crossspectrum_from_events(
Expand All @@ -1347,6 +1360,7 @@ def from_events(
fullspec=fullspec,
use_common_mean=use_common_mean,
gti=gti,
channels_overlap=channels_overlap,
)

@staticmethod
Expand All @@ -1360,6 +1374,7 @@ def from_lightcurve(
fullspec=False,
use_common_mean=True,
gti=None,
channels_overlap=False,
):
"""Calculate AveragedCrossspectrum from two light curves
Expand Down Expand Up @@ -1398,6 +1413,11 @@ def from_lightcurve(
input object GTIs! If you're getting errors regarding your GTIs,
don't use this and only give GTIs to the input objects before
making the cross spectrum.
channels_overlap: bool, default False
If True, the reference band contains all the photons of the subject band.
This happens, for example, when calculating covariance spectra (see, e.g.,
the docs for ``CovarianceSpectrum``). This will generally be false in a
single cross spectrum.
"""
return crossspectrum_from_lightcurve(
lc1,
Expand All @@ -1409,6 +1429,7 @@ def from_lightcurve(
fullspec=fullspec,
use_common_mean=use_common_mean,
gti=gti,
channels_overlap=channels_overlap,
)

@staticmethod
Expand All @@ -1424,6 +1445,7 @@ def from_stingray_timeseries(
fullspec=False,
use_common_mean=True,
gti=None,
channels_overlap=False,
):
"""Calculate AveragedCrossspectrum from two light curves
Expand Down Expand Up @@ -1466,6 +1488,11 @@ def from_stingray_timeseries(
input object GTIs! If you're getting errors regarding your GTIs,
don't use this and only give GTIs to the input objects before
making the cross spectrum.
channels_overlap: bool, default False
If True, the reference band contains all the photons of the subject band.
This happens, for example, when calculating covariance spectra (see, e.g.,
the docs for ``CovarianceSpectrum``). This will generally be false in a
single cross spectrum.
"""
return crossspectrum_from_timeseries(
ts1,
Expand All @@ -1479,6 +1506,7 @@ def from_stingray_timeseries(
fullspec=fullspec,
use_common_mean=use_common_mean,
gti=gti,
channels_overlap=channels_overlap,
)

@staticmethod
Expand All @@ -1493,6 +1521,7 @@ def from_lc_iterable(
fullspec=False,
use_common_mean=True,
gti=None,
channels_overlap=False,
):
"""Calculate AveragedCrossspectrum from two light curves
Expand Down Expand Up @@ -1537,6 +1566,11 @@ def from_lc_iterable(
save_all : bool, default False
If True, save the cross spectrum of each segment in the ``cs_all``
attribute of the output :class:`Crossspectrum` object.
channels_overlap: bool, default False
If True, the reference band contains all the photons of the subject band.
This happens, for example, when calculating covariance spectra (see, e.g.,
the docs for ``CovarianceSpectrum``). This will generally be false in a
single cross spectrum.
"""

return crossspectrum_from_lc_iterable(
Expand All @@ -1550,6 +1584,7 @@ def from_lc_iterable(
fullspec=fullspec,
use_common_mean=use_common_mean,
gti=gti,
channels_overlap=channels_overlap,
)

def _initialize_from_any_input(
Expand Down
183 changes: 183 additions & 0 deletions stingray/tests/test_crossspectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -1678,3 +1678,186 @@ def test_shift_and_add(self):
assert np.array_equal(output.m, [2, 3, 3, 3, 2])
assert np.array_equal(output.power, [2.0, 2.0, 5.0, 2.0, 1.5])
assert np.allclose(output.freq, [0.05, 0.15, 0.25, 0.35, 0.45])


class TestAveragedCrossspectrumOverlap(object):
def setup_class(self):
tstart = 0.0
tend = 1.0
dt = np.longdouble(0.0001)

time = np.arange(tstart + 0.5 * dt, tend + 0.5 * dt, dt)

counts1 = np.random.poisson(1, size=time.shape[0])
counts2 = np.random.poisson(1, size=time.shape[0]) + counts1

self.lc1 = Lightcurve(time, counts1, gti=[[tstart, tend]], dt=dt)
self.lc2 = Lightcurve(time, counts2, gti=[[tstart, tend]], dt=dt)

self.cs = AveragedCrossspectrum(
self.lc1, self.lc2, segment_size=1, save_all=True, channels_overlap=True
)

def test_save_all(self):
cs = AveragedCrossspectrum(
self.lc1, self.lc2, segment_size=1, save_all=True, channels_overlap=True
)
assert hasattr(self.cs, "cs_all")

def test_rebin_with_valid_type_attribute(self):
new_df = 2
aps = AveragedCrossspectrum(
self.lc1, self.lc2, segment_size=1, norm="leahy", channels_overlap=True
)

assert aps.rebin(df=new_df)

@pytest.mark.parametrize("err_dist", ["poisson", "gauss"])
def test_with_iterable_of_lightcurves(self, err_dist):
def iter_lc(lc, n):
"Generator of n parts of lc."
t0 = int(len(lc) / n)
t = t0
i = 0
while True:
lc_seg = lc[i:t]
yield lc_seg
if t + t0 > len(lc):
break
else:
i, t = t, t + t0

lc1 = copy.deepcopy(self.lc1)
lc2 = copy.deepcopy(self.lc2)
lc1.err_dist = lc2.err_dist = err_dist
with pytest.warns(UserWarning) as record:
cs = AveragedCrossspectrum(
iter_lc(self.lc1, 1), iter_lc(self.lc2, 1), segment_size=1, channels_overlap=True
)
message = "The averaged Cross spectrum from a generator "

assert np.any([message in r.message.args[0] for r in record])

def test_with_multiple_lightcurves_variable_length(self):
gti = [[0, 0.05], [0.05, 0.5], [0.555, 1.0]]
lc1 = copy.deepcopy(self.lc1)
lc1.gti = gti
lc2 = copy.deepcopy(self.lc2)
lc2.gti = gti

lc1_split = lc1.split_by_gti()
lc2_split = lc2.split_by_gti()

cs = AveragedCrossspectrum(
lc1_split,
lc2_split,
segment_size=0.05,
norm="leahy",
silent=True,
channels_overlap=True,
)

def test_coherence(self):
with pytest.warns(UserWarning) as w:
coh = self.cs.coherence()

assert len(coh[0]) == 4999
assert len(coh[1]) == 4999
assert issubclass(w[-1].category, UserWarning)

def test_normalize_crossspectrum(self):
cs1 = Crossspectrum(self.lc1, self.lc2, norm="leahy", channels_overlap=True)
cs2 = Crossspectrum(
self.lc1, self.lc2, norm="leahy", power_type="all", channels_overlap=True
)
cs3 = Crossspectrum(
self.lc1, self.lc2, norm="leahy", power_type="real", channels_overlap=True
)
cs4 = Crossspectrum(
self.lc1, self.lc2, norm="leahy", power_type="absolute", channels_overlap=True
)
assert np.allclose(cs1.power.real, cs3.power)
assert np.all(np.isclose(np.abs(cs2.power), cs4.power, atol=0.0001))

def test_normalize_crossspectrum_with_method_inplace(self):
cs1 = AveragedCrossspectrum.from_lightcurve(
self.lc1, self.lc2, segment_size=1, norm="abs", channels_overlap=True
)
cs2 = cs1.to_norm("leahy", inplace=True)
cs3 = cs1.to_norm("leahy", inplace=False)
assert cs3 is not cs1
assert cs2 is cs1

@pytest.mark.parametrize("norm1", ["leahy", "abs", "frac", "none"])
@pytest.mark.parametrize("norm2", ["leahy", "abs", "frac", "none"])
def test_normalize_crossspectrum_with_method(self, norm1, norm2):
cs1 = AveragedCrossspectrum.from_lightcurve(
self.lc1, self.lc2, segment_size=1, norm=norm1, channels_overlap=True
)
cs2 = AveragedCrossspectrum.from_lightcurve(
self.lc1, self.lc2, segment_size=1, norm=norm2, channels_overlap=True
)
cs3 = cs2.to_norm(norm1)
for attr in ["power", "power_err", "unnorm_power", "unnorm_power_err"]:
assert np.allclose(getattr(cs1, attr), getattr(cs3, attr))
assert np.allclose(getattr(cs1.pds1, attr), getattr(cs3.pds1, attr))
assert np.allclose(getattr(cs1.pds2, attr), getattr(cs3.pds2, attr))

@pytest.mark.parametrize("f", [None, 1.5])
@pytest.mark.parametrize("norm", ["leahy", "abs", "frac", "none"])
def test_rebin_factor_rebins_all_attrs(self, f, norm):
cs1 = AveragedCrossspectrum.from_lightcurve(
self.lc1, self.lc2, segment_size=1, norm=norm, channels_overlap=True
)
# N.B.: if f is not None, df gets ignored.
new_cs = cs1.rebin(df=1.5, f=f)
N = new_cs.freq.size
for attr in ["power", "power_err", "unnorm_power", "unnorm_power_err"]:
assert hasattr(new_cs, attr) and getattr(new_cs, attr).size == N
assert hasattr(new_cs.pds1, attr) and getattr(new_cs.pds1, attr).size == N
assert hasattr(new_cs.pds2, attr) and getattr(new_cs.pds2, attr).size == N

for attr in cs1.meta_attrs():
if attr not in ["df", "gti", "m"]:
assert getattr(cs1, attr) == getattr(new_cs, attr)

@pytest.mark.parametrize("norm", ["leahy", "abs", "frac", "none"])
def test_rebin_factor_log_rebins_all_attrs(self, norm):
cs1 = AveragedCrossspectrum.from_lightcurve(
self.lc1, self.lc2, segment_size=1, norm=norm, channels_overlap=True
)
new_cs = cs1.rebin_log(0.03)
N = new_cs.freq.size
for attr in ["power", "power_err", "unnorm_power", "unnorm_power_err"]:
assert hasattr(new_cs, attr) and getattr(new_cs, attr).size == N
assert hasattr(new_cs.pds1, attr) and getattr(new_cs.pds1, attr).size == N
assert hasattr(new_cs.pds2, attr) and getattr(new_cs.pds2, attr).size == N

for attr in cs1.meta_attrs():
if attr not in ["df", "gti", "m", "k"]:
assert np.all(getattr(cs1, attr) == getattr(new_cs, attr))

def test_rebin(self):
new_cs = self.cs.rebin(df=1.5)
assert hasattr(new_cs, "dt") and new_cs.dt is not None
assert new_cs.df == 1.5
new_cs.time_lag()

def test_rebin_factor(self):
new_cs = self.cs.rebin(f=1.5)
assert hasattr(new_cs, "dt") and new_cs.dt is not None
assert new_cs.df == self.cs.df * 1.5
new_cs.time_lag()

def test_rebin_log(self):
# For now, just verify that it doesn't crash
new_cs = self.cs.rebin_log(f=0.1)
assert hasattr(new_cs, "dt") and new_cs.dt is not None
assert isinstance(new_cs, type(self.cs))
new_cs.time_lag()

def test_rebin_log_returns_complex_values_and_errors(self):
# For now, just verify that it doesn't crash
new_cs = self.cs.rebin_log(f=0.1)
assert np.iscomplexobj(new_cs.power[0])
assert np.iscomplexobj(new_cs.power_err[0])

0 comments on commit 6828129

Please sign in to comment.