Skip to content

Commit

Permalink
Better tests
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobachetti committed Sep 8, 2024
1 parent 9458e42 commit 08584f5
Showing 1 changed file with 90 additions and 29 deletions.
119 changes: 90 additions & 29 deletions stingray/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,59 +1540,91 @@ def setup_class(cls):
gti=80000000 + np.asarray([[0, 1025]]),
)

def test_read_fits_timeseries(self):
def test_stream_timeseries(self):
assert np.all((self.events.time > 80000000) & (self.events.time < 80001024))

def test_read_fits_timeseries_by_nsamples(self):
def test_stream_timeseries_by_gti_raises(self):
with pytest.raises(ValueError, match="You can only use only_attrs with a generator."):
list(
self.events.stream_from_gti_lists(
[[[80000100, 80001010]]], root_file_name="test", only_attrs=["time"]
)
)

def test_stream_timeseries_by_gti(self):
# Full slice
outfnames = list(self.events.stream_by_number_of_samples(500, root_file_name="test"))
assert len(outfnames) == 2
outfnames = list(
self.events.stream_from_gti_lists([[[80000100, 80001010]]], root_file_name="test")
)
assert len(outfnames) == 1
ev0 = StingrayTimeseries.read(outfnames[0], fmt=DEFAULT_FORMAT)
ev1 = StingrayTimeseries.read(outfnames[1], fmt=DEFAULT_FORMAT)
assert np.all(ev0.time < 80000512.5)
assert np.all(ev1.time > 80000512.5)
assert np.all((ev0.time > 80000100) & (ev0.time < 80001010))
assert np.all((ev0.gti == np.asarray([[80000100, 80001010]])))
for fname in outfnames:
os.unlink(fname)

def test_read_fits_timeseries_by_time_intv(self):
def test_stream_timeseries_by_gti_no_change(self):
# Full slice
outfnames = list(
self.events.stream_from_time_intervals([80000100, 80001100], root_file_name="test")
self.events.stream_from_gti_lists([self.events.gti], root_file_name="test")
)
assert len(outfnames) == 1
ev0 = StingrayTimeseries.read(outfnames[0], fmt=DEFAULT_FORMAT)
assert np.all((ev0.time > 80000100) & (ev0.time < 80001100))
assert np.all((ev0.gti >= 80000100) & (ev0.gti < 80001100))
for fname in outfnames:
os.unlink(fname)

def test_read_fits_timeseries_by_nsamples_generator(self):
assert np.allclose(ev0.time, self.events.time)
assert np.all(ev0.gti == self.events.gti)

def test_stream_timeseries_by_gti_no_change_generator(self):
# Full slice
ev0, ev1 = list(self.events.stream_by_number_of_samples(500))
evs = list(self.events.stream_from_gti_lists([self.events.gti]))

assert np.all(ev0.time < 80000512.5)
assert np.all(ev1.time > 80000512.5)
assert len(evs) == 1
ev0 = evs[0]
assert np.allclose(ev0.time, self.events.time)
assert np.all(ev0.gti == self.events.gti)

def test_read_fits_timeseries_by_time_intv_generator(self):
def test_stream_timeseries_by_gti_generator(self):
# Full slice
evs = list(self.events.stream_from_time_intervals([80000100, 80001100]))
evs = list(self.events.stream_from_gti_lists([[[80000100, 80001010]]]))
assert len(evs) == 1
ev0 = evs[0]
assert np.all((ev0.time > 80000100) & (ev0.time < 80001100))
assert np.all((ev0.gti >= 80000100) & (ev0.gti < 80001100))
assert np.all((ev0.time > 80000100) & (ev0.time < 80001010))
assert np.all((ev0.gti == np.asarray([[80000100, 80001010]])))

def test_read_fits_timeseries_by_nsamples_attrs(self):
def test_stream_timeseries_by_gti_attrs(self):
# Full slice
ev0_attr, ev1_attr = list(
self.events.stream_by_number_of_samples(500, only_attrs=["time", "energy"])
evs = list(
self.events.stream_from_gti_lists(
[[[80000100, 80000200]]], only_attrs=["time", "energy"]
)
)

assert np.all(ev0_attr[0] < 80000512.5)
assert np.all(ev1_attr[0] > 80000512.5)
assert len(evs) == 1
ev0_attr = evs[0]
assert np.all((ev0_attr[0] > 80000100) & (ev0_attr[0] < 80000200))
assert np.all(ev0_attr[1] == 0)
assert np.all(ev1_attr[1] == 1)

def test_read_fits_timeseries_by_time_intv_attrs(self):
def test_stream_timeseries_by_time_intv(self):
# Full slice
outfnames = list(
self.events.stream_from_time_intervals([80000100, 80001010], root_file_name="test")
)
assert len(outfnames) == 1
ev0 = StingrayTimeseries.read(outfnames[0], fmt=DEFAULT_FORMAT)
assert np.all((ev0.time > 80000100) & (ev0.time < 80001010))
assert np.all((ev0.gti == np.asarray([[80000100, 80001010]])))

for fname in outfnames:
os.unlink(fname)

def test_stream_timeseries_by_time_intv_generator(self):
# Full slice
evs = list(self.events.stream_from_time_intervals([80000100, 80001010]))
assert len(evs) == 1
ev0 = evs[0]
assert np.all((ev0.time > 80000100) & (ev0.time < 80001010))
assert np.all((ev0.gti == np.asarray([[80000100, 80001010]])))

def test_stream_timeseries_by_time_intv_attrs(self):
# Full slice
evs = list(
self.events.stream_from_time_intervals(
Expand All @@ -1603,3 +1635,32 @@ def test_read_fits_timeseries_by_time_intv_attrs(self):
ev0_attr = evs[0]
assert np.all((ev0_attr[0] > 80000100) & (ev0_attr[0] < 80000200))
assert np.all(ev0_attr[1] == 0)

def test_stream_timeseries_by_nsamples(self):
# Full slice
outfnames = list(self.events.stream_by_number_of_samples(500, root_file_name="test"))
assert len(outfnames) == 2
ev0 = StingrayTimeseries.read(outfnames[0], fmt=DEFAULT_FORMAT)
ev1 = StingrayTimeseries.read(outfnames[1], fmt=DEFAULT_FORMAT)
assert np.all(ev0.time < 80000512.5)
assert np.all(ev1.time > 80000512.5)
for fname in outfnames:
os.unlink(fname)

def test_stream_timeseries_by_nsamples_generator(self):
# Full slice
ev0, ev1 = list(self.events.stream_by_number_of_samples(500))

assert np.all(ev0.time < 80000512.5)
assert np.all(ev1.time > 80000512.5)

def test_stream_timeseries_by_nsamples_attrs(self):
# Full slice
ev0_attr, ev1_attr = list(
self.events.stream_by_number_of_samples(500, only_attrs=["time", "energy"])
)

assert np.all(ev0_attr[0] < 80000512.5)
assert np.all(ev1_attr[0] > 80000512.5)
assert np.all(ev0_attr[1] == 0)
assert np.all(ev1_attr[1] == 1)

0 comments on commit 08584f5

Please sign in to comment.