diff --git a/docs/changes/838.feature.rst b/docs/changes/838.feature.rst new file mode 100644 index 000000000..8a66178cd --- /dev/null +++ b/docs/changes/838.feature.rst @@ -0,0 +1 @@ +Add methods to stream timeseries data into chunks diff --git a/stingray/base.py b/stingray/base.py index e8727c5aa..5ab476657 100644 --- a/stingray/base.py +++ b/stingray/base.py @@ -15,7 +15,7 @@ from astropy.units import Quantity from stingray.loggingconfig import setup_logger -from .io import _can_save_longdouble, _can_serialize_meta +from .io import _can_save_longdouble, _can_serialize_meta, DEFAULT_FORMAT from .utils import ( sqsum, assign_value_if_none, @@ -36,6 +36,7 @@ get_total_gti_length, bin_intervals_from_gtis, time_intervals_from_gtis, + split_gtis_by_exposure, ) from typing import TYPE_CHECKING, Type, TypeVar, Union @@ -1075,7 +1076,7 @@ def __getitem__(self, index): for attr in self.meta_attrs(): setattr(new_ts, attr, copy.deepcopy(getattr(self, attr))) - for attr in self.array_attrs() + [self.main_array_attr]: + for attr in self.array_attrs() + self.internal_array_attrs() + [self.main_array_attr]: setattr(new_ts, attr, getattr(self, attr)[start:stop:step]) return new_ts @@ -1402,26 +1403,149 @@ def split_by_gti(self, gti=None, min_points=2): if gti is None: gti = self.gti - list_of_tss = [] + slices = [] - start_bins, stop_bins = gti_border_bins(gti, self.time, self.dt) - for i in range(len(start_bins)): - start = start_bins[i] - stop = stop_bins[i] - - if (stop - start) < min_points: + for s in self.stream_from_gti_lists([[g] for g in gti]): + if np.size(getattr(s, s.main_array_attr)) < min_points: continue + slices.append(s) + return slices + + def get_idx_from_time_range(self, start, stop): + lower_edge, upper_edge = np.searchsorted(self.time, [start, stop]) + # Searchsorted will find the first number above stop. We want the last number below stop! 
+ return lower_edge, upper_edge - 1 + + def stream_from_gti_lists( + self, new_gti_lists, root_file_name=None, fmt=DEFAULT_FORMAT, only_attrs=None + ): + """Split the event list into different files, each with a different GTI. + + Parameters + ---------- + new_gti_lists : list of lists + A list of lists of GTIs. Each sublist should contain a list of GTIs + for a new file. + + Other Parameters + ---------------- + root_file_name : str, default None + The root name of the output files. The file name will be appended with + "_00", "_01", etc. + If None, the chunks are yielded in memory instead of being written to files. + fmt : str + The format of the output files. Default is 'hdf5' ('pickle' if h5py is not installed). + + Yields + ------ + output : str, timeseries, or list of arrays + One file name per chunk, or the chunk data itself if ``root_file_name`` is None. + + """ + + if only_attrs is not None and root_file_name is not None: + raise ValueError("You can only use only_attrs with a generator.") + new_gti_lists = np.asanyarray(new_gti_lists) + if len(new_gti_lists[0]) == len(self.gti) and np.all( + np.abs(np.asanyarray(new_gti_lists[0]).flatten() - self.gti.flatten()) < 1e-3 + ): + logger.info("No change of GTI") + if only_attrs is not None: + yield [copy.deepcopy(getattr(self, attr)) for attr in only_attrs] + else: + ev = self[:] + if root_file_name is None: + yield ev + else: + output_file = root_file_name + f"_00." + fmt.lstrip(".") + ev.write(output_file, fmt=fmt) + yield output_file + else: + for i, gti in enumerate(new_gti_lists): + if len(gti) == 0: + continue - new_gti = np.array([gti[i]]) - mask = create_gti_mask(self.time, new_gti) + lower_edge, upper_edge = self.get_idx_from_time_range(gti[0, 0], gti[-1, 1]) - # Note: GTIs are consistent with default in this case! 
- new_ts = self.apply_mask(mask) - new_ts.gti = new_gti + if only_attrs is not None: + yield [ + copy.deepcopy(getattr(self, attr)[lower_edge : upper_edge + 1]) + for attr in only_attrs + ] + else: + ev = self[lower_edge : upper_edge + 1] + ev.gti = gti + + if root_file_name is not None: + new_file = root_file_name + f"_{i:002d}." + fmt.lstrip(".") + logger.info(f"Writing {new_file}") + ev.write(new_file, fmt=fmt) + yield new_file + else: + yield ev + + def stream_by_number_of_samples( + self, nsamples, root_file_name=None, fmt=DEFAULT_FORMAT, only_attrs=None + ): + """Split the event list into different files, each with approx. the given no. of photons. + + Parameters + ---------- + nsamples : int + The number of photons in each output file. - list_of_tss.append(new_ts) + Other Parameters + ---------------- + root_file_name : str, default None + The root name of the output files. The file name will be appended with + "_00", "_01", etc. + If None, the chunks are yielded in memory instead of being written to files. + fmt : str + The format of the output files. Default is 'hdf5' ('pickle' if h5py is not installed). + + Yields + ------ + output : str, timeseries, or list of arrays + One file name per chunk, or the chunk data itself if ``root_file_name`` is None. + """ + n_intervals = int(np.rint(self.n / nsamples)) + exposure_per_interval = self.exposure / n_intervals + new_gti_lists = split_gtis_by_exposure(self.gti, exposure_per_interval) + + return self.stream_from_gti_lists( + new_gti_lists, root_file_name=root_file_name, fmt=fmt, only_attrs=only_attrs + ) - return list_of_tss + def stream_from_time_intervals( + self, time_intervals, root_file_name=None, fmt=DEFAULT_FORMAT, only_attrs=None + ): + """Filter the event list at the given time intervals. + + Parameters + ---------- + time_intervals : 2-d float array + List of time intervals of the form ``[[time0_0, time0_1], [time1_0, time1_1], ...]`` + + Other Parameters + ---------------- + root_file_name : str, default None + The root name of the output files. The file name will be appended with + "_00", "_01", etc. 
+ If None, the chunks are yielded in memory instead of being written to files. + fmt : str + The format of the output files. Default is 'hdf5' ('pickle' if h5py is not installed). + + Yields + ------ + output : str, timeseries, or list of arrays + One file name per chunk, or the chunk data itself if ``root_file_name`` is None. + """ + if len(np.shape(time_intervals)) == 1: + time_intervals = [time_intervals] + new_gti = [cross_two_gtis(self.gti, [t_int]) for t_int in time_intervals] + return self.stream_from_gti_lists( + new_gti, root_file_name=root_file_name, fmt=fmt, only_attrs=only_attrs + ) def to_astropy_timeseries(self) -> TimeSeries: """Save the ``StingrayTimeseries`` to an ``Astropy`` timeseries. @@ -1707,9 +1831,13 @@ def __getitem__(self, index): """ new_ts = super().__getitem__(index) - step = 1 + if isinstance(index, slice): + if index.start is None and index.stop is None and index.step is None: + return copy.deepcopy(new_ts) step = assign_value_if_none(index.step, 1) + else: + step = 1 dt = self.dt if np.isscalar(dt): @@ -2579,6 +2707,36 @@ def estimate_segment_size(self, min_counts=None, min_samples=None, even_sampling return segment_size + def get_segment_borders(self, segment_size=None, fraction_step=1): + """Get the start and stop times of the segments for segment-by-segment analysis. + + Parameters + ---------- + segment_size : float + Length in seconds of the light curve segments. If None, the full GTIs are considered + instead as segments. + fraction_step : float + If the step is not a full ``segment_size`` but less (e.g. a moving window), + this indicates the ratio between the step and ``segment_size`` (e.g. + 0.5 means that the window shifts by half ``segment_size``) + + Returns + ------- + start_times : array + Lower time boundaries of all time segments. + stop_times : array + Upper time boundaries of all segments. 
+ + """ + if segment_size is None: + start_times = self.gti[:, 0] + stop_times = self.gti[:, 1] + else: + start_times, stop_times = time_intervals_from_gtis( + self.gti, segment_size, fraction_step=fraction_step + ) + return start_times, stop_times + def analyze_segments(self, func, segment_size, fraction_step=1, **kwargs): """Analyze segments of the light curve with any function. @@ -2631,40 +2789,19 @@ def analyze_segments(self, func, segment_size, fraction_step=1, **kwargs): >>> np.allclose(res, 10) True """ - - if segment_size is None: - start_times = self.gti[:, 0] - stop_times = self.gti[:, 1] - start = np.searchsorted(self.time, start_times) - stop = np.searchsorted(self.time, stop_times) - elif self.dt > 0: - start, stop = bin_intervals_from_gtis( - self.gti, segment_size, self.time, fraction_step=fraction_step, dt=self.dt - ) - start_times = self.time[start] - 0.5 * self.dt - # Remember that stop is one element above the last element, because - # it's defined to be used in intervals start:stop - stop_times = self.time[stop - 1] + self.dt * 1.5 - else: - start_times, stop_times = time_intervals_from_gtis( - self.gti, segment_size, fraction_step=fraction_step - ) - start = np.searchsorted(self.time, start_times) - stop = np.searchsorted(self.time, stop_times) + start_times, stop_times = self.get_segment_borders(segment_size, fraction_step) results = [] n_outs = 1 - for i, (st, sp, tst, tsp) in enumerate(zip(start, stop, start_times, stop_times)): - if sp - st <= 1: + for i, lc_filt in enumerate( + self.stream_from_time_intervals(list(zip(start_times, stop_times))) + ): + if lc_filt is None or len(lc_filt.time) <= 1: warnings.warn( - f"Segment {i} ({tst}--{tsp}) has one data point or less. Skipping it " + f"Segment {i} ({start_times[i]}--{stop_times[i]}) has one data point or less. 
Skipping it " ) - continue - lc_filt = self[st:sp] - lc_filt.gti = np.asanyarray([[tst, tsp]]) - res = func(lc_filt, **kwargs) results.append(res) if isinstance(res, Iterable) and not isinstance(res, str): diff --git a/stingray/fourier.py b/stingray/fourier.py index 613e687eb..dd6b00f66 100644 --- a/stingray/fourier.py +++ b/stingray/fourier.py @@ -1824,7 +1824,7 @@ def local_show_progress(a): ft1 = fft(flux1) ft2 = fft(flux2) - # Calculate the sum of each light curve, to calculate the mean + # Calculate the sum of each light curve chunk, to calculate the mean n_ph1 = flux1.sum() n_ph2 = flux2.sum() n_ph = np.sqrt(n_ph1 * n_ph2) diff --git a/stingray/gti.py b/stingray/gti.py index 6a0648f13..482699c12 100644 --- a/stingray/gti.py +++ b/stingray/gti.py @@ -1,3 +1,4 @@ +import os import re import numpy as np import warnings @@ -36,6 +37,9 @@ "gti_border_bins", "generate_indices_of_segment_boundaries_unbinned", "generate_indices_of_segment_boundaries_binned", + "split_gtis_by_exposure", + "split_gtis_at_index", + "find_large_bad_time_intervals", ] logger = setup_logger() @@ -257,6 +261,8 @@ def get_gti_from_all_extensions(lchdulist, accepted_gtistrings=["GTI"], det_numb Examples -------- + Prepare data: + >>> from astropy.io import fits >>> s1 = fits.Column(name='START', array=[0, 100, 200], format='D') >>> s2 = fits.Column(name='STOP', array=[50, 150, 250], format='D') @@ -264,12 +270,26 @@ def get_gti_from_all_extensions(lchdulist, accepted_gtistrings=["GTI"], det_numb >>> s1 = fits.Column(name='START', array=[200, 300], format='D') >>> s2 = fits.Column(name='STOP', array=[250, 350], format='D') >>> hdu2 = fits.TableHDU.from_columns([s1, s2], name='STDGTI05') - >>> lchdulist = fits.HDUList([hdu1, hdu2]) - >>> gti = get_gti_from_all_extensions( + >>> lchdulist = fits.HDUList([fits.PrimaryHDU(), hdu1, hdu2]) + >>> lchdulist.writeto("test_gti_ext.fits", overwrite=True) + + Now, try to load from the HDU list, and test the result is correct: + + >>> gti0 = 
get_gti_from_all_extensions( ... lchdulist, accepted_gtistrings=['GTI0', 'STDGTI'], ... det_numbers=[5]) - >>> assert np.allclose(gti, [[200, 250]]) + >>> assert np.allclose(gti0, [[200, 250]]) + + Do the same with an input file name: + + >>> gti1 = get_gti_from_all_extensions( + ... "test_gti_ext.fits", accepted_gtistrings=['GTI0', 'STDGTI'], + ... det_numbers=[5]) + >>> assert np.allclose(gti1, [[200, 250]]) + >>> os.unlink("test_gti_ext.fits") """ + if isinstance(lchdulist, str): + lchdulist = fits.open(lchdulist) acc_gti_strs = copy.deepcopy(accepted_gtistrings) if det_numbers is not None: for i in det_numbers: @@ -1687,3 +1707,159 @@ def generate_indices_of_segment_boundaries_binned(times, gti, segment_size, dt=N dt = 0 for idx0, idx1 in zip(startidx, stopidx): yield times[idx0] - dt / 2, times[min(idx1, times.size - 1)] - dt / 2, idx0, idx1 + + +def split_gtis_at_indices(gtis, index_list): + """Split a GTI list at the given indices, creating multiple GTI lists. + + Parameters + ---------- + gtis : 2-d float array + List of GTIs of the form ``[[gti0_0, gti0_1], [gti1_0, gti1_1], ...]`` + index_list : int or array-like + Index or list of indices at which to split the GTIs + + Returns + ------- + gti_lists : list of 2-d float arrays + List of GTI lists, split at the given indices + + Examples + -------- + >>> gtis = [[0, 30], [50, 60], [80, 90]] + >>> new_gtis = split_gtis_at_indices(gtis, 1) + >>> assert np.allclose(new_gtis[0], [[0, 30]]) + >>> assert np.allclose(new_gtis[1], [[50, 60], [80, 90]]) + """ + gtis = np.asanyarray(gtis) + if not isinstance(index_list, Iterable): + index_list = [index_list] + previous_idx = 0 + gti_lists = [] + if index_list[0] == 0: + index_list = index_list[1:] + for idx in index_list: + gti_lists.append(gtis[previous_idx:idx, :]) + previous_idx = idx + if index_list[-1] != -1 and index_list[-1] <= gtis[:, 0].size - 1: + gti_lists.append(gtis[previous_idx:, :]) + + return gti_lists + + +def find_large_bad_time_intervals(gtis, 
bti_length_limit=86400): + """Find the midpoints of large bad time intervals in a list of GTIs. + + Parameters + ---------- + gtis : 2-d float array + List of GTIs of the form ``[[gti0_0, gti0_1], [gti1_0, gti1_1], ...]`` + bti_length_limit : float + Maximum length of a bad time interval. If a BTI is longer than this, an edge will be + returned at the midpoint of the BTI. + + Returns + ------- + bad_interval_midpoints : list of float + List of midpoints of large bad time intervals + + Examples + -------- + >>> gtis = [[0, 30], [86450, 86460], [86480, 86490]] + >>> bad_interval_midpoints = find_large_bad_time_intervals(gtis) + >>> assert np.allclose(bad_interval_midpoints, [43240]) + """ + gtis = np.asanyarray(gtis) + bad_interval_midpoints = [] + # Check for compulsory edges + last_edge = gtis[0, 0] + for g in gtis: + if g[0] - last_edge > bti_length_limit: + logger.info(f"Detected large bad time interval between {g[0]} and {last_edge}") + bad_interval_midpoints.append((g[0] + last_edge) / 2) + last_edge = g[1] + + return bad_interval_midpoints + + +def split_gtis_by_exposure(gtis, exposure_per_chunk, new_interval_if_gti_sep=None): + """Split a list of GTIs into smaller GTI lists of a given total (approximate) exposure. + + Parameters + ---------- + gtis : 2-d float array + List of GTIs of the form ``[[gti0_0, gti0_1], [gti1_0, gti1_1], ...]`` + exposure_per_chunk : float + Total exposure of each chunk + + Other Parameters + ---------------- + new_interval_if_gti_sep : float + If the GTIs are separated by more than this time, split the observation in two. 
+ + Returns + ------- + gti_list : list of 2-d float arrays + List of GTI lists, split into chunks of the given exposure / separated by more + than the given limit separation + + Examples + -------- + >>> gtis = [[0, 30], [86450, 86460]] + >>> new_gtis = split_gtis_by_exposure(gtis, 400, new_interval_if_gti_sep=86400) + >>> assert np.allclose(new_gtis[0], [[0, 30]]) + >>> assert np.allclose(new_gtis[1], [[86450, 86460]]) + >>> gtis = [[0, 30], [40, 70], [90, 120], [130, 160]] + >>> new_gtis = split_gtis_by_exposure(gtis, 60) + >>> assert np.allclose(new_gtis[0], [[0, 30], [40, 70]]) + >>> assert np.allclose(new_gtis[1], [[90, 120], [130, 160]]) + + """ + gtis = np.asanyarray(gtis) + total_exposure = np.sum(np.diff(gtis, axis=1)) + compulsory_edges = [] + if new_interval_if_gti_sep is not None: + compulsory_edges = find_large_bad_time_intervals(gtis, new_interval_if_gti_sep) + + base_gti_list = split_gtis_at_indices(gtis, np.searchsorted(gtis[:, 1], compulsory_edges)) + final_gti_list = [] + for local_gtis in base_gti_list: + local_split_gtis = split_gtis_by_exposure(local_gtis, exposure_per_chunk) + final_gti_list.extend(local_split_gtis) + return final_gti_list + + n_intervals = int(np.rint(total_exposure / exposure_per_chunk)) + + if n_intervals <= 1: + return np.asarray([gtis]) + + if len(gtis) <= n_intervals: + new_gtis = [] + for g in gtis: + if g[1] - g[0] > exposure_per_chunk: + new_edges = np.arange(g[0], g[1], exposure_per_chunk) + if new_edges[-1] < g[1]: + new_edges = np.append(new_edges, g[1]) + + new_gtis.extend([[ed0, ed1] for ed0, ed1 in zip(new_edges[:-1], new_edges[1:])]) + else: + new_gtis.append(g) + gtis = np.asarray(new_gtis) + + exposure_edges = [] + last_exposure = 0 + for g in gtis: + exposure_edges.append(last_exposure) + last_exposure += g[1] - g[0] + + exposure_edges = np.asarray(exposure_edges) + + total_exposure = last_exposure + + exposure_per_interval = total_exposure / n_intervals + exposure_intervals = np.arange(0, total_exposure + 
exposure_per_interval, exposure_per_interval) + + index_list = np.searchsorted(exposure_edges, exposure_intervals) + + vals = split_gtis_at_indices(gtis, index_list) + return vals diff --git a/stingray/io.py b/stingray/io.py index e3bb10de9..ff4f720db 100644 --- a/stingray/io.py +++ b/stingray/io.py @@ -36,12 +36,13 @@ import pickle _H5PY_INSTALLED = True +DEFAULT_FORMAT = "hdf5" try: import h5py except ImportError: _H5PY_INSTALLED = False - + DEFAULT_FORMAT = "pickle" HAS_128 = True try: diff --git a/stingray/lightcurve.py b/stingray/lightcurve.py index 367082cfc..e39e6807e 100644 --- a/stingray/lightcurve.py +++ b/stingray/lightcurve.py @@ -741,35 +741,11 @@ def __getitem__(self, index): >>> assert np.isclose(lc[2], 33) >>> assert np.allclose(lc[:2].counts, [11, 22]) """ + if isinstance(index, (int, np.integer)): return self.counts[index] elif isinstance(index, slice): - start = assign_value_if_none(index.start, 0) - stop = assign_value_if_none(index.stop, len(self.counts)) - step = assign_value_if_none(index.step, 1) - - new_counts = self.counts[start:stop:step] - new_time = self.time[start:stop:step] - - new_gti = [[self.time[start] - 0.5 * self.dt, self.time[stop - 1] + 0.5 * self.dt]] - new_gti = np.asanyarray(new_gti) - if step > 1: - new_gt1 = np.array(list(zip(new_time - self.dt / 2, new_time + self.dt / 2))) - new_gti = cross_two_gtis(new_gti, new_gt1) - new_gti = cross_two_gtis(self.gti, new_gti) - - lc = Lightcurve( - new_time, - new_counts, - mjdref=self.mjdref, - gti=new_gti, - dt=self.dt, - skip_checks=True, - err_dist=self.err_dist, - ) - if self._counts_err is not None: - lc._counts_err = self._counts_err[start:stop:step] - return lc + return super().__getitem__(index) else: raise IndexError("The index must be either an integer or a slice " "object !") diff --git a/stingray/tests/test_base.py b/stingray/tests/test_base.py index d43f79116..a1437cf72 100644 --- a/stingray/tests/test_base.py +++ b/stingray/tests/test_base.py @@ -6,6 +6,7 @@ import 
matplotlib.pyplot as plt from astropy.table import Table from stingray.base import StingrayObject, StingrayTimeseries +from stingray.io import DEFAULT_FORMAT _HAS_XARRAY = importlib.util.find_spec("xarray") is not None _HAS_PANDAS = importlib.util.find_spec("pandas") is not None @@ -1521,3 +1522,145 @@ def func(x): assert np.allclose(results_as, results_ag) assert np.allclose(results_as, [6, 4]) + + +class TestStreaming(object): + @classmethod + def setup_class(cls): + curdir = os.path.abspath(os.path.dirname(__file__)) + cls.datadir = os.path.join(curdir, "data") + cls.fname = os.path.join(cls.datadir, "monol_testA.evt") + times = 80000000 + np.sort(np.random.uniform(0, 1024, 1000)) + energies = np.ones(1000) + energies[times < 80000512.5] = 0 + + cls.events = StingrayTimeseries( + time=times, + energy=energies, + gti=80000000 + np.asarray([[0, 1025]]), + ) + + def test_stream_timeseries(self): + assert np.all((self.events.time > 80000000) & (self.events.time < 80001024)) + + def test_stream_timeseries_by_gti_raises(self): + with pytest.raises(ValueError, match="You can only use only_attrs with a generator."): + list( + self.events.stream_from_gti_lists( + [[[80000100, 80001010]]], root_file_name="test", only_attrs=["time"] + ) + ) + + def test_stream_timeseries_by_gti(self): + # Full slice + outfnames = list( + self.events.stream_from_gti_lists([[[80000100, 80001010]]], root_file_name="test") + ) + assert len(outfnames) == 1 + ev0 = StingrayTimeseries.read(outfnames[0], fmt=DEFAULT_FORMAT) + assert np.all((ev0.time > 80000100) & (ev0.time < 80001010)) + assert np.all((ev0.gti == np.asarray([[80000100, 80001010]]))) + for fname in outfnames: + os.unlink(fname) + + def test_stream_timeseries_by_gti_no_change(self): + # Full slice + outfnames = list( + self.events.stream_from_gti_lists([self.events.gti], root_file_name="test") + ) + assert len(outfnames) == 1 + ev0 = StingrayTimeseries.read(outfnames[0], fmt=DEFAULT_FORMAT) + + assert np.allclose(ev0.time, 
self.events.time) + assert np.all(ev0.gti == self.events.gti) + + def test_stream_timeseries_by_gti_no_change_generator(self): + # Full slice + evs = list(self.events.stream_from_gti_lists([self.events.gti])) + + assert len(evs) == 1 + ev0 = evs[0] + assert np.allclose(ev0.time, self.events.time) + assert np.all(ev0.gti == self.events.gti) + + def test_stream_timeseries_by_gti_generator(self): + # Full slice + evs = list(self.events.stream_from_gti_lists([[[80000100, 80001010]]])) + assert len(evs) == 1 + ev0 = evs[0] + assert np.all((ev0.time > 80000100) & (ev0.time < 80001010)) + assert np.all((ev0.gti == np.asarray([[80000100, 80001010]]))) + + def test_stream_timeseries_by_gti_attrs(self): + # Full slice + evs = list( + self.events.stream_from_gti_lists( + [[[80000100, 80000200]]], only_attrs=["time", "energy"] + ) + ) + assert len(evs) == 1 + ev0_attr = evs[0] + assert np.all((ev0_attr[0] > 80000100) & (ev0_attr[0] < 80000200)) + assert np.all(ev0_attr[1] == 0) + + def test_stream_timeseries_by_time_intv(self): + # Full slice + outfnames = list( + self.events.stream_from_time_intervals([80000100, 80001010], root_file_name="test") + ) + assert len(outfnames) == 1 + ev0 = StingrayTimeseries.read(outfnames[0], fmt=DEFAULT_FORMAT) + assert np.all((ev0.time > 80000100) & (ev0.time < 80001010)) + assert np.all((ev0.gti == np.asarray([[80000100, 80001010]]))) + + for fname in outfnames: + os.unlink(fname) + + def test_stream_timeseries_by_time_intv_generator(self): + # Full slice + evs = list(self.events.stream_from_time_intervals([80000100, 80001010])) + assert len(evs) == 1 + ev0 = evs[0] + assert np.all((ev0.time > 80000100) & (ev0.time < 80001010)) + assert np.all((ev0.gti == np.asarray([[80000100, 80001010]]))) + + def test_stream_timeseries_by_time_intv_attrs(self): + # Full slice + evs = list( + self.events.stream_from_time_intervals( + [80000100, 80000200], only_attrs=["time", "energy"] + ) + ) + assert len(evs) == 1 + ev0_attr = evs[0] + assert 
np.all((ev0_attr[0] > 80000100) & (ev0_attr[0] < 80000200)) + assert np.all(ev0_attr[1] == 0) + + def test_stream_timeseries_by_nsamples(self): + # Full slice + outfnames = list(self.events.stream_by_number_of_samples(500, root_file_name="test")) + assert len(outfnames) == 2 + ev0 = StingrayTimeseries.read(outfnames[0], fmt=DEFAULT_FORMAT) + ev1 = StingrayTimeseries.read(outfnames[1], fmt=DEFAULT_FORMAT) + assert np.all(ev0.time < 80000512.5) + assert np.all(ev1.time > 80000512.5) + for fname in outfnames: + os.unlink(fname) + + def test_stream_timeseries_by_nsamples_generator(self): + # Full slice + ev0, ev1 = list(self.events.stream_by_number_of_samples(500)) + + assert np.all(ev0.time < 80000512.5) + assert np.all(ev1.time > 80000512.5) + + def test_stream_timeseries_by_nsamples_attrs(self): + # Full slice + ev0_attr, ev1_attr = list( + self.events.stream_by_number_of_samples(500, only_attrs=["time", "energy"]) + ) + + assert np.all(ev0_attr[0] < 80000512.5) + assert np.all(ev1_attr[0] > 80000512.5) + assert np.all(ev0_attr[1] == 0) + assert np.all(ev1_attr[1] == 1) diff --git a/stingray/tests/test_events.py b/stingray/tests/test_events.py index cfc272e22..70ba5e82a 100644 --- a/stingray/tests/test_events.py +++ b/stingray/tests/test_events.py @@ -7,6 +7,7 @@ from ..events import EventList from ..lightcurve import Lightcurve +from ..io import DEFAULT_FORMAT curdir = os.path.abspath(os.path.dirname(__file__)) datadir = os.path.join(curdir, "data") @@ -674,3 +675,76 @@ def test_intensity_no_segment(self): assert np.allclose(rate_errs, 0.003, atol=0.001) assert np.allclose(start, 0) assert np.allclose(stop, 100000) + + +class TestStreaming(object): + @classmethod + def setup_class(cls): + curdir = os.path.abspath(os.path.dirname(__file__)) + cls.datadir = os.path.join(curdir, "data") + cls.fname = os.path.join(cls.datadir, "monol_testA.evt") + cls.events = EventList.read(cls.fname, fmt="hea") + + def test_read_fits_timeseries(self): + assert 
np.all((self.events.time > 80000000) & (self.events.time < 80001024)) + + def test_read_fits_timeseries_by_nsamples(self): + # Full slice + outfnames = list(self.events.stream_by_number_of_samples(500, root_file_name="test")) + assert len(outfnames) == 2 + ev0 = EventList.read(outfnames[0], fmt=DEFAULT_FORMAT) + ev1 = EventList.read(outfnames[1], fmt=DEFAULT_FORMAT) + assert np.all(ev0.time < 80000512.5) + assert np.all(ev1.time > 80000512.5) + for fname in outfnames: + os.unlink(fname) + + def test_read_fits_timeseries_by_time_intv(self): + # Full slice + outfnames = list( + self.events.stream_from_time_intervals([80000100, 80001100], root_file_name="test") + ) + assert len(outfnames) == 1 + ev0 = EventList.read(outfnames[0], fmt=DEFAULT_FORMAT) + assert np.all((ev0.time > 80000100) & (ev0.time < 80001100)) + assert np.all((ev0.gti >= 80000100) & (ev0.gti < 80001100)) + for fname in outfnames: + os.unlink(fname) + + def test_read_fits_timeseries_by_nsamples_generator(self): + # Full slice + ev0, ev1 = list(self.events.stream_by_number_of_samples(500)) + + assert np.all(ev0.time < 80000512.5) + assert np.all(ev1.time > 80000512.5) + + def test_read_fits_timeseries_by_time_intv_generator(self): + # Full slice + evs = list(self.events.stream_from_time_intervals([80000100, 80001100])) + assert len(evs) == 1 + ev0 = evs[0] + assert np.all((ev0.time > 80000100) & (ev0.time < 80001100)) + assert np.all((ev0.gti >= 80000100) & (ev0.gti < 80001100)) + + def test_read_fits_timeseries_by_nsamples_attrs(self): + # Full slice + ev0_attr, ev1_attr = list( + self.events.stream_by_number_of_samples(500, only_attrs=["time", "energy"]) + ) + + assert np.all(ev0_attr[0] < 80000512.5) + assert np.all(ev1_attr[0] > 80000512.5) + assert ev0_attr[0].size == ev0_attr[1].size + assert ev1_attr[0].size == ev1_attr[1].size + + def test_read_fits_timeseries_by_time_intv_attrs(self): + # Full slice + evs = list( + self.events.stream_from_time_intervals( + [80000100, 80000200], 
only_attrs=["time", "energy"] + ) + ) + assert len(evs) == 1 + ev0_attr = evs[0] + assert np.all((ev0_attr[0] > 80000100) & (ev0_attr[0] < 80000200)) + assert ev0_attr[0].size == ev0_attr[1].size diff --git a/stingray/tests/test_gti.py b/stingray/tests/test_gti.py index 1cf29cda1..1848d08f4 100644 --- a/stingray/tests/test_gti.py +++ b/stingray/tests/test_gti.py @@ -7,6 +7,7 @@ from stingray.gti import create_gti_from_condition, gti_len, gti_border_bins from stingray.gti import time_intervals_from_gtis, bin_intervals_from_gtis from stingray.gti import create_gti_mask_complete, join_equal_gti_boundaries +from stingray.gti import split_gtis_at_indices, split_gtis_by_exposure from stingray import StingrayError curdir = os.path.abspath(os.path.dirname(__file__)) @@ -363,6 +364,38 @@ def test_join_boundaries(self): newg = join_equal_gti_boundaries(gti) assert np.allclose(newg, np.array([[1.16703354e08, 1.16703514e08]])) + def test_split_gtis_by_exposure_min_gti_sep(self): + gtis = [[0, 30], [86450, 86460]] + new_gtis = split_gtis_by_exposure(gtis, 400, new_interval_if_gti_sep=86400) + assert np.allclose(new_gtis[0], [[0, 30]]) + assert np.allclose(new_gtis[1], [[86450, 86460]]) + + def test_split_gtis_by_exposure_no_min_gti_sep(self): + gtis = [[0, 30], [86440, 86470], [86490, 86520], [86530, 86560]] + new_gtis = split_gtis_by_exposure(gtis, 60, new_interval_if_gti_sep=None) + assert np.allclose(new_gtis[0], [[0, 30], [86440, 86470]]) + assert np.allclose(new_gtis[1], [[86490, 86520], [86530, 86560]]) + + def test_split_gtis_by_exposure_small_exp(self): + gtis = [[0, 30], [86440, 86470], [86490, 86495], [86500, 86505]] + new_gtis = split_gtis_by_exposure(gtis, 15, new_interval_if_gti_sep=None) + assert np.allclose( + new_gtis[:4], + [ + [[0, 15]], + [[15, 30]], + [[86440, 86455]], + [[86455, 86470]], + ], + ) + assert np.allclose(new_gtis[4], [[86490, 86495], [86500, 86505]]) + + def test_split_gtis_at_indices(self): + gtis = [[0, 30], [50, 60], [80, 90]] + new_gtis = 
split_gtis_at_indices(gtis, 1) + assert np.allclose(new_gtis[0], [[0, 30]]) + assert np.allclose(new_gtis[1], [[50, 60], [80, 90]]) + _ALL_METHODS = ["intersection", "union", "infer", "append"] diff --git a/stingray/tests/test_lightcurve.py b/stingray/tests/test_lightcurve.py index 8dc528b2c..dbbc8c4e9 100644 --- a/stingray/tests/test_lightcurve.py +++ b/stingray/tests/test_lightcurve.py @@ -1849,7 +1849,7 @@ def test_split_lc_by_gtis_when_dt_is_array(self): frac_exp=frac_exp, ) - list_of_lcs = lc.split_by_gti() + list_of_lcs = lc.split_by_gti() lc0 = list_of_lcs[0] lc1 = list_of_lcs[1] assert np.allclose(lc0.time, [1, 2, 3, 5])