From 08584f5550dc7e833af72d1aafc19639d79e632b Mon Sep 17 00:00:00 2001 From: Matteo Bachetti Date: Sun, 8 Sep 2024 22:16:01 +0200 Subject: [PATCH] Better tests --- stingray/tests/test_base.py | 119 +++++++++++++++++++++++++++--------- 1 file changed, 90 insertions(+), 29 deletions(-) diff --git a/stingray/tests/test_base.py b/stingray/tests/test_base.py index 9ca6b5f15..a1437cf72 100644 --- a/stingray/tests/test_base.py +++ b/stingray/tests/test_base.py @@ -1540,59 +1540,91 @@ def setup_class(cls): gti=80000000 + np.asarray([[0, 1025]]), ) - def test_read_fits_timeseries(self): + def test_stream_timeseries(self): assert np.all((self.events.time > 80000000) & (self.events.time < 80001024)) - def test_read_fits_timeseries_by_nsamples(self): + def test_stream_timeseries_by_gti_raises(self): + with pytest.raises(ValueError, match="You can only use only_attrs with a generator."): + list( + self.events.stream_from_gti_lists( + [[[80000100, 80001010]]], root_file_name="test", only_attrs=["time"] + ) + ) + + def test_stream_timeseries_by_gti(self): # Full slice - outfnames = list(self.events.stream_by_number_of_samples(500, root_file_name="test")) - assert len(outfnames) == 2 + outfnames = list( + self.events.stream_from_gti_lists([[[80000100, 80001010]]], root_file_name="test") + ) + assert len(outfnames) == 1 ev0 = StingrayTimeseries.read(outfnames[0], fmt=DEFAULT_FORMAT) - ev1 = StingrayTimeseries.read(outfnames[1], fmt=DEFAULT_FORMAT) - assert np.all(ev0.time < 80000512.5) - assert np.all(ev1.time > 80000512.5) + assert np.all((ev0.time > 80000100) & (ev0.time < 80001010)) + assert np.all((ev0.gti == np.asarray([[80000100, 80001010]]))) for fname in outfnames: os.unlink(fname) - def test_read_fits_timeseries_by_time_intv(self): + def test_stream_timeseries_by_gti_no_change(self): # Full slice outfnames = list( - self.events.stream_from_time_intervals([80000100, 80001100], root_file_name="test") + self.events.stream_from_gti_lists([self.events.gti], root_file_name="test") ) assert len(outfnames) == 1 ev0 = StingrayTimeseries.read(outfnames[0], fmt=DEFAULT_FORMAT) - assert np.all((ev0.time > 80000100) & (ev0.time < 80001100)) - assert np.all((ev0.gti >= 80000100) & (ev0.gti < 80001100)) - for fname in outfnames: - os.unlink(fname) - def test_read_fits_timeseries_by_nsamples_generator(self): + assert np.allclose(ev0.time, self.events.time) + assert np.all(ev0.gti == self.events.gti) + + def test_stream_timeseries_by_gti_no_change_generator(self): # Full slice - ev0, ev1 = list(self.events.stream_by_number_of_samples(500)) + evs = list(self.events.stream_from_gti_lists([self.events.gti])) - assert np.all(ev0.time < 80000512.5) - assert np.all(ev1.time > 80000512.5) + assert len(evs) == 1 + ev0 = evs[0] + assert np.allclose(ev0.time, self.events.time) + assert np.all(ev0.gti == self.events.gti) - def test_read_fits_timeseries_by_time_intv_generator(self): + def test_stream_timeseries_by_gti_generator(self): # Full slice - evs = list(self.events.stream_from_time_intervals([80000100, 80001100])) + evs = list(self.events.stream_from_gti_lists([[[80000100, 80001010]]])) assert len(evs) == 1 ev0 = evs[0] - assert np.all((ev0.time > 80000100) & (ev0.time < 80001100)) - assert np.all((ev0.gti >= 80000100) & (ev0.gti < 80001100)) + assert np.all((ev0.time > 80000100) & (ev0.time < 80001010)) + assert np.all((ev0.gti == np.asarray([[80000100, 80001010]]))) - def test_read_fits_timeseries_by_nsamples_attrs(self): + def test_stream_timeseries_by_gti_attrs(self): # Full slice - ev0_attr, ev1_attr = list( - self.events.stream_by_number_of_samples(500, only_attrs=["time", "energy"]) + evs = list( + self.events.stream_from_gti_lists( + [[[80000100, 80000200]]], only_attrs=["time", "energy"] + ) ) - - assert np.all(ev0_attr[0] < 80000512.5) - assert np.all(ev1_attr[0] > 80000512.5) + assert len(evs) == 1 + ev0_attr = evs[0] + assert np.all((ev0_attr[0] > 80000100) & (ev0_attr[0] < 80000200)) assert np.all(ev0_attr[1] == 0) - assert np.all(ev1_attr[1] == 1) - def test_read_fits_timeseries_by_time_intv_attrs(self): + def test_stream_timeseries_by_time_intv(self): + # Full slice + outfnames = list( + self.events.stream_from_time_intervals([80000100, 80001010], root_file_name="test") + ) + assert len(outfnames) == 1 + ev0 = StingrayTimeseries.read(outfnames[0], fmt=DEFAULT_FORMAT) + assert np.all((ev0.time > 80000100) & (ev0.time < 80001010)) + assert np.all((ev0.gti == np.asarray([[80000100, 80001010]]))) + + for fname in outfnames: + os.unlink(fname) + + def test_stream_timeseries_by_time_intv_generator(self): + # Full slice + evs = list(self.events.stream_from_time_intervals([80000100, 80001010])) + assert len(evs) == 1 + ev0 = evs[0] + assert np.all((ev0.time > 80000100) & (ev0.time < 80001010)) + assert np.all((ev0.gti == np.asarray([[80000100, 80001010]]))) + + def test_stream_timeseries_by_time_intv_attrs(self): # Full slice evs = list( self.events.stream_from_time_intervals( @@ -1603,3 +1635,32 @@ def test_read_fits_timeseries_by_time_intv_attrs(self): ev0_attr = evs[0] assert np.all((ev0_attr[0] > 80000100) & (ev0_attr[0] < 80000200)) assert np.all(ev0_attr[1] == 0) + + def test_stream_timeseries_by_nsamples(self): + # Full slice + outfnames = list(self.events.stream_by_number_of_samples(500, root_file_name="test")) + assert len(outfnames) == 2 + ev0 = StingrayTimeseries.read(outfnames[0], fmt=DEFAULT_FORMAT) + ev1 = StingrayTimeseries.read(outfnames[1], fmt=DEFAULT_FORMAT) + assert np.all(ev0.time < 80000512.5) + assert np.all(ev1.time > 80000512.5) + for fname in outfnames: + os.unlink(fname) + + def test_stream_timeseries_by_nsamples_generator(self): + # Full slice + ev0, ev1 = list(self.events.stream_by_number_of_samples(500)) + + assert np.all(ev0.time < 80000512.5) + assert np.all(ev1.time > 80000512.5) + + def test_stream_timeseries_by_nsamples_attrs(self): + # Full slice + ev0_attr, ev1_attr = list( + self.events.stream_by_number_of_samples(500, only_attrs=["time", "energy"]) + ) + + assert np.all(ev0_attr[0] < 80000512.5) + assert np.all(ev1_attr[0] > 80000512.5) + assert np.all(ev0_attr[1] == 0) + assert np.all(ev1_attr[1] == 1)