Skip to content

Commit

Permalink
Condense test cases in one
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobachetti committed Nov 9, 2024
1 parent 15dd360 commit 0da6b2a
Showing 1 changed file with 25 additions and 13 deletions.
38 changes: 25 additions & 13 deletions stingray/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,20 +236,32 @@ def test_read_fits_timeseries(self):
all_ev = reader[:]
assert np.all((all_ev.time > 80000000) & (all_ev.time < 80001024))

def test_read_apply_single_gti_lists(self):
@pytest.mark.parametrize("root_file_name", [None, "test"])
@pytest.mark.parametrize("gti_kind", ["same", "one", "multiple"])
def test_read_apply_gti_lists(self, root_file_name, gti_kind):
reader = FITSTimeseriesReader(self.fname, output_class=EventList)

evs = list(reader.apply_gti_lists([[[80000000, 80001024]]]))
assert len(evs) == 1
assert np.allclose(evs[0].gti, [[80000000, 80001024]])

def test_read_apply_multiple_gti_lists(self):
reader = FITSTimeseriesReader(self.fname, output_class=EventList)

evs = list(reader.apply_gti_lists([[[80000000, 80000512]], [[80000513, 80001024]]]))
assert len(evs) == 2
assert np.allclose(evs[0].gti, [[80000000, 80000512]])
assert np.allclose(evs[1].gti, [[80000513, 80001024]])
if gti_kind == "same":
gti_list = [reader.gti]
elif gti_kind == "one":
gti_list = [[[80000000, 80001024]]]
elif gti_kind == "multiple":
gti_list = [[[80000000, 80000512]], [[80000513, 80001024]]]

evs = list(reader.apply_gti_lists(gti_list, root_file_name=root_file_name))

# Check that the number of event lists is the same as the number of GTI lists we input
assert len(evs) == len(gti_list)

# If the root_file_name is not None, read the event lists and delete the file(s)
if root_file_name is not None:
ev_str = evs
evs = [EventList.read(ev) for ev in ev_str]
for ev in ev_str:
os.unlink(ev)

for i, ev in enumerate(evs):
# Check that the gtis of the output event lists are the same we input
assert np.allclose(ev.gti, gti_list[i])

def test_read_fits_timeseries_by_nsamples(self):
reader = FITSTimeseriesReader(self.fname, output_class=EventList)
Expand Down

0 comments on commit 0da6b2a

Please sign in to comment.