diff --git a/docs/changes/782.feature.rst b/docs/changes/782.feature.rst new file mode 100644 index 000000000..1ccafdc8b --- /dev/null +++ b/docs/changes/782.feature.rst @@ -0,0 +1 @@ +Add function to randomize data in small bad time intervals diff --git a/docs/index.rst b/docs/index.rst index c6ae9caf1..1e93d4298 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -30,6 +30,7 @@ Current Capabilities * simulating a light curve from another light curve and a 1-d (time) or 2-d (time-energy) impulse response * simulating an event list from a given light curve _and_ with a given energy spectrum * Good Time Interval operations +* Filling gaps in light curves with statistically sound fake data 2. Fourier methods ~~~~~~~~~~~~~~~~~~ diff --git a/stingray/base.py b/stingray/base.py index 4e76cba6d..9b0cc62f5 100644 --- a/stingray/base.py +++ b/stingray/base.py @@ -15,8 +15,27 @@ from astropy.time import Time, TimeDelta from astropy.units import Quantity -from stingray.utils import sqsum from .io import _can_save_longdouble, _can_serialize_meta +from .utils import ( + sqsum, + assign_value_if_none, + make_nd_into_arrays, + make_1d_arrays_into_nd, + get_random_state, + find_nearest, + rebin_data, +) +from .gti import ( + create_gti_mask, + check_gtis, + cross_two_gtis, + join_gtis, + gti_border_bins, + get_btis, + merge_gtis, + check_separate, + append_gtis, +) from typing import TYPE_CHECKING, Type, TypeVar, Union @@ -478,7 +497,6 @@ def to_pandas(self) -> DataFrame: """ from pandas import DataFrame - from .utils import make_nd_into_arrays data = {} array_attrs = self.array_attrs() + [self.main_array_attr] + self.internal_array_attrs() @@ -515,7 +533,6 @@ def from_pandas(cls: Type[Tso], ts: DataFrame) -> Tso: """ import re - from .utils import make_1d_arrays_into_nd cls = cls() @@ -1023,7 +1040,6 @@ def __getitem__(self, index): ts_new : :class:`StingrayObject` object The new :class:`StingrayObject` object with the set of selected data. """ - from .utils import assign_value_if_none if isinstance(index, (int, np.integer)): start = index @@ -1077,7 +1093,7 @@ class StingrayTimeseries(StingrayObject): dt: float The time resolution of the time series. Can be a scalar or an array attribute (useful - for non-uniformly sampled data or events from different instruments) + for non-evenly sampled data or events from different instruments) mjdref : float The MJD used as a reference for the time array. @@ -1112,7 +1128,7 @@ class StingrayTimeseries(StingrayObject): dt: float The time resolution of the measurements. Can be a scalar or an array attribute (useful - for non-uniformly sampled data or events from different instruments) + for non-evenly sampled data or events from different instruments) mjdref : float The MJD used as a reference for the time array. 
@@ -1200,8 +1216,6 @@ def gti(self, value): @property def mask(self): - from .gti import create_gti_mask - if self._mask is None: self._mask = create_gti_mask(self.time, self.gti, dt=self.dt) return self._mask @@ -1292,7 +1306,6 @@ def apply_gtis(self, new_gti=None, inplace: bool = True): """ # I import here to avoid the risk of circular imports - from .gti import check_gtis, create_gti_mask if new_gti is None: new_gti = self.gti @@ -1324,7 +1337,6 @@ def split_by_gti(self, gti=None, min_points=2): list_of_tss : list A list of :class:`StingrayTimeseries` objects, one for each GTI segment """ - from .gti import gti_border_bins, create_gti_mask if gti is None: gti = self.gti @@ -1530,8 +1542,6 @@ def _operation_with_other_obj( other = other.change_mjdref(self.mjdref) if not np.array_equal(self.gti, other.gti): - from .gti import cross_two_gtis - warnings.warn( "The good time intervals in the two time series are different. Data outside the " "common GTIs will be discarded." ) @@ -1634,8 +1644,6 @@ def __getitem__(self, index): >>> assert np.allclose(ts[2].counts, [33]) >>> assert np.allclose(ts[:2].counts, [11, 22]) """ - from .utils import assign_value_if_none - from .gti import cross_two_gtis new_ts = super().__getitem__(index) step = 1 @@ -1721,7 +1729,6 @@ def truncate(self, start=0, stop=None, method="index"): def _truncate_by_index(self, start, stop): """Private method for truncation using index values.""" - from .gti import cross_two_gtis new_ts = self.apply_mask(slice(start, stop)) @@ -1835,7 +1842,6 @@ def _join_timeseries(self, others, strategy="intersection", ignore_meta=[]): `ts_new` : :class:`StingrayTimeseries` object The resulting :class:`StingrayTimeseries` object. """ - from .gti import check_separate, cross_gtis, append_gtis new_ts = type(self)() @@ -1869,8 +1875,6 @@ def _join_timeseries(self, others, strategy="intersection", ignore_meta=[]): all_objs = [self] + others - from .gti import merge_gtis - # Check if none of the GTIs was already initialized. all_gti = [obj._gti for obj in all_objs if obj._gti is not None] @@ -2039,7 +2043,6 @@ def rebin(self, dt_new=None, f=None, method="sum"): ts_new: :class:`StingrayTimeseries` object The :class:`StingrayTimeseries` object with the new, binned time series. """ - from .utils import rebin_data if f is None and dt_new is None: raise ValueError("You need to specify at least one between f and " "dt_new") @@ -2134,6 +2137,162 @@ def sort(self, reverse=False, inplace=False): mask = mask[::-1] return self.apply_mask(mask, inplace=inplace) + def fill_bad_time_intervals( + self, + max_length=None, + attrs_to_randomize=None, + buffer_size=None, + even_sampling=None, + seed=None, + ): + """Fill short bad time intervals with random data. + + .. warning:: + This method is only appropriate for *very short* bad time intervals. The simulated data + are basically white noise, so they can alter the statistical properties of + variable data. For very short gaps in the data, the effect of these small + injections of white noise should be negligible. How short is short enough depends on + the specific case; users are urged not to use this method as a black box, and to run + simulations to measure its effect. If you have long bad time intervals, you should use + more advanced techniques, such as Gaussian Processes, which are not currently + available in Stingray for this use case. In particular, please verify that the values + of ``max_length`` and ``buffer_size`` are adequate for your case. 
+ + To fill the gaps in all but the time points (i.e., flux measures, energies), we take the + ``buffer_size`` valid data points closest to the gap (by default, the larger of 100 and + the estimated number of samples in a ``max_length``-long gap) and resample them randomly, + following their empirical distribution. So, if the ``my_fancy_attr`` attribute, in the + 100 points of the buffer, contains the value 10 thirty times, 9 ten times, and 11 sixty + times, the simulated data will contain *on average* 30% of 10, 60% of 11, and 10% of 9. + + Times are treated differently depending on whether the time series is evenly sampled or + not. If it is not, the times are simulated from a uniform distribution with the + same count rate found in the buffer. Otherwise, times just follow the same grid used + inside GTIs. Whether the data are treated as evenly sampled is decided through the + ``even_sampling`` parameter. If left as ``None``, the time series is considered evenly + sampled if ``self.dt`` is greater than zero and the median separation between subsequent + times is within 1% of the time resolution. + + Other Parameters + ---------------- + max_length : float + Maximum length of a bad time interval to be filled. If ``None``, the criterion is bad + time intervals shorter than 1/100th of the longest good time interval. + attrs_to_randomize : list of str, default None + List of array_attrs to randomize. If ``None``, all array_attrs are randomized. + It should not include ``time`` and ``_mask``, which are treated separately. + buffer_size : int, default 100 + Number of good data points, on each side of the bad time interval, used to estimate + the empirical distribution of the simulated data. + even_sampling : bool, default None + Force the treatment of the data as evenly sampled or not. If ``None``, the data are + considered evenly sampled if ``self.dt`` is larger than zero and the median + separation between subsequent times is within 1% of ``self.dt``. + seed : int, default None + Random seed to use for the simulation. If ``None``, a random seed is generated. 
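+ + Examples + -------- + A minimal usage sketch; the event times, GTIs, and seed below are purely illustrative, + and the seed only serves to make the simulated points reproducible: + + >>> times = np.sort(np.random.uniform(0, 1000, 10000)) + >>> ts = StingrayTimeseries(times, gti=[[0, 490], [492, 1000]]) + >>> # the 2-s gap is shorter than the default ``max_length``, so it gets filled + >>> ts_filled = ts.fill_bad_time_intervals(seed=42) + >>> # the filled interval is absorbed into a single, merged GTI + >>> assert len(ts_filled.gti) == 1 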
+ + """ + + rs = get_random_state(seed) + + if attrs_to_randomize is None: + attrs_to_randomize = self.array_attrs() + self.internal_array_attrs() + for attr in ["time", "_mask"]: + if attr in attrs_to_randomize: + attrs_to_randomize.remove(attr) + + attrs_to_leave_alone = [ + a + for a in self.array_attrs() + self.internal_array_attrs() + if a not in attrs_to_randomize + ] + + if max_length is None: + max_length = np.max(self.gti[:, 1] - self.gti[:, 0]) / 100 + + btis = get_btis(self.gti, self.time[0], self.time[-1]) + if len(btis) == 0: + logging.info("No bad time intervals to fill") + return copy.deepcopy(self) + filtered_times = self.time[self.mask] + + new_times = [filtered_times.copy()] + new_attrs = {} + mean_data_separation = np.median(np.diff(filtered_times)) + if even_sampling is None: + # The time series is considered evenly sampled if the median separation between + # subsequent times is within 1% of the time resolution + even_sampling = False + if self.dt > 0 and np.isclose(mean_data_separation, self.dt, rtol=0.01): + even_sampling = True + logging.info(f"Data are {'not' if not even_sampling else ''} evenly sampled") + + if even_sampling: + est_samples_in_gap = int(max_length / self.dt) + else: + est_samples_in_gap = int(max_length / mean_data_separation) + + if buffer_size is None: + buffer_size = max(100, est_samples_in_gap) + + added_gtis = [] + + total_filled_time = 0 + for bti in btis: + length = bti[1] - bti[0] + if length > max_length: + continue + logging.info(f"Filling bad time interval {bti} ({length:.4f} s)") + epsilon = 1e-5 * length + added_gtis.append([bti[0] - epsilon, bti[1] + epsilon]) + filt_low_t, filt_low_idx = find_nearest(filtered_times, bti[0]) + filt_hig_t, filt_hig_idx = find_nearest(filtered_times, bti[1], side="right") + if even_sampling: + local_new_times = np.arange(bti[0] + self.dt / 2, bti[1], self.dt) + nevents = local_new_times.size + else: + low_time_arr = filtered_times[max(filt_low_idx - buffer_size, 0) : filt_low_idx] + high_time_arr = filtered_times[filt_hig_idx : buffer_size + filt_hig_idx] + + ctrate = ( + np.count_nonzero(low_time_arr) / (filt_low_t - low_time_arr[0]) + + np.count_nonzero(high_time_arr) / (high_time_arr[-1] - filt_hig_t) + ) / 2 + nevents = rs.poisson(ctrate * (bti[1] - bti[0])) + local_new_times = rs.uniform(bti[0], bti[1], nevents) + new_times.append(local_new_times) + + for attr in attrs_to_randomize: + low_arr = getattr(self, attr)[max(buffer_size - filt_low_idx, 0) : filt_low_idx] + high_arr = getattr(self, attr)[filt_hig_idx : buffer_size + filt_hig_idx] + if attr not in new_attrs: + new_attrs[attr] = [getattr(self, attr)[self.mask]] + new_attrs[attr].append(rs.choice(np.concatenate([low_arr, high_arr]), nevents)) + for attr in attrs_to_leave_alone: + if attr not in new_attrs: + new_attrs[attr] = [getattr(self, attr)[self.mask]] + if attr == "_mask": + new_attrs[attr].append(np.ones(nevents, dtype=bool)) + else: + new_attrs[attr].append(np.zeros(nevents) + np.nan) + total_filled_time += length + + logging.info(f"A total of {total_filled_time} s of data were simulated") + + new_gtis = join_gtis(self.gti, added_gtis) + new_times = np.concatenate(new_times) + order = np.argsort(new_times) + new_obj = type(self)() + new_obj.time = new_times[order] + + for attr in self.meta_attrs(): + setattr(new_obj, attr, getattr(self, attr)) + + for attr, values in new_attrs.items(): + setattr(new_obj, attr, np.concatenate(values)[order]) + new_obj.gti = new_gtis + return new_obj + def plot( self, attr, @@ -2145,6 +2304,7 @@ def 
plot( save=False, filename=None, plot_btis=True, + axis_limits=None, ): """ Plot the time series using ``matplotlib``. @@ -2168,6 +2328,10 @@ def plot( could be ``['Time (s)', 'Counts (s^-1)']`` ax : ``matplotlib.pyplot.axis`` object Axis to be used for plotting. Defaults to creating a new one. + axis_limits : list, tuple, string, default ``None`` + Parameter to set axis properties of the ``matplotlib`` figure. For example + it can be a list like ``[xmin, xmax, ymin, ymax]`` or any other + acceptable argument for the``matplotlib.pyplot.axis()`` method. title : str, default ``None`` The title of the plot. marker : str, default '-' @@ -2182,17 +2346,23 @@ def plot( Plot the bad time intervals as red areas on the plot """ import matplotlib.pyplot as plt - from .gti import get_btis if ax is None: plt.figure(attr) ax = plt.gca() - if labels is None: + valid_labels = (isinstance(labels, Iterable) and not isinstance(labels, str)) and len( + labels + ) == 2 + if labels is not None and not valid_labels: + warnings.warn("``labels`` must be an iterable with two labels for x and y axes.") + + if labels is None or not valid_labels: labels = ["Time (s)"] + [attr] - ylabel = labels[1] xlabel = labels[0] + ylabel = labels[1] + # Default values for labels ax.plot(self.time, getattr(self, attr), marker, ds="steps-mid", label=attr, zorder=10) @@ -2202,11 +2372,15 @@ def plot( getattr(self, attr), yerr=getattr(self, attr + "_err"), fmt="o", + zorder=10, ) ax.set_ylabel(ylabel) ax.set_xlabel(xlabel) + if axis_limits is not None: + ax.set_xlim(axis_limits[0], axis_limits[1]) + ax.set_ylim(axis_limits[2], axis_limits[3]) if title is not None: ax.set_title(title) @@ -2221,7 +2395,14 @@ def plot( tend = max(self.time[-1] + self.dt / 2, self.gti[-1, 1]) btis = get_btis(self.gti, tstart, tend) for bti in btis: - plt.axvspan(bti[0], bti[1], alpha=0.5, color="r", zorder=10) + plt.axvspan( + bti[0], + bti[1], + alpha=0.5, + facecolor="r", + zorder=1, + edgecolor="none", + ) return ax diff --git a/stingray/lightcurve.py b/stingray/lightcurve.py index 8602482ef..f53933387 100644 --- a/stingray/lightcurve.py +++ b/stingray/lightcurve.py @@ -1608,11 +1608,14 @@ def plot( self, witherrors=False, labels=None, - axis=None, + ax=None, title=None, marker="-", save=False, filename=None, + axis_limits=None, + axis=None, + plot_btis=True, ): """ Plot the light curve using ``matplotlib``. @@ -1629,11 +1632,14 @@ def plot( labels : iterable, default ``None`` A list of tuple with ``xlabel`` and ``ylabel`` as strings. - axis : list, tuple, string, default ``None`` + axis_limits : list, tuple, string, default ``None`` Parameter to set axis properties of the ``matplotlib`` figure. For example it can be a list like ``[xmin, xmax, ymin, ymax]`` or any other acceptable argument for the``matplotlib.pyplot.axis()`` method. + axis : list, tuple, string, default ``None`` + Deprecated in favor of ``axis_limits``, same functionality. + title : str, default ``None`` The title of the plot. @@ -1647,39 +1653,37 @@ def plot( filename : str File name of the image to save. Depends on the boolean ``save``. 
- """ - fig = plt.figure() - if witherrors: - fig = plt.errorbar(self.time, self.counts, yerr=self.counts_err, fmt=marker) - else: - fig = plt.plot(self.time, self.counts, marker) - - if labels is not None: - try: - plt.xlabel(labels[0]) - plt.ylabel(labels[1]) - except TypeError: - utils.simon("``labels`` must be either a list or tuple with " "x and y labels.") - raise - except IndexError: - utils.simon("``labels`` must have two labels for x and y " "axes.") - # Not raising here because in case of len(labels)==1, only - # x-axis will be labelled. + ax : ``matplotlib.pyplot.axis`` object + Axis to be used for plotting. Defaults to creating a new one. + plot_btis : bool + Plot the bad time intervals as red areas on the plot + """ if axis is not None: - plt.axis(axis) - - if title is not None: - plt.title(title) - - if save: - if filename is None: - plt.savefig("out.png") - else: - plt.savefig(filename) - else: - plt.show(block=False) + warnings.warn( + "The ``axis`` argument is deprecated in favor of ``axis_limits``. " + "Please use that instead.", + DeprecationWarning, + ) + axis_limits = axis + + flux_attr = "counts" + if not self.input_counts: + flux_attr = "countrate" + + return super().plot( + flux_attr, + witherrors=witherrors, + labels=labels, + ax=ax, + title=title, + marker=marker, + save=save, + filename=filename, + plot_btis=plot_btis, + axis_limits=axis_limits, + ) @classmethod def read( diff --git a/stingray/pulse/overlapandsave/ols.py b/stingray/pulse/overlapandsave/ols.py index 831372bd0..d6ac909b2 100644 --- a/stingray/pulse/overlapandsave/ols.py +++ b/stingray/pulse/overlapandsave/ols.py @@ -196,7 +196,7 @@ def prepareh(h, nfft: List[int], rfftn=None): The FFT-transformed, conjugate filter array """ rfftn = rfftn or np.fft.rfftn - return np.conj(rfftn(flip(np.conj(h)), nfft)) + return np.conj(rfftn(flip(np.conj(h)), nfft, axes=np.arange(len(nfft)))) def slice2range(s: slice): @@ -365,7 +365,9 @@ def olsStep( for (start, length, nh, border) in zip(starts, lengths, nh, border) ) xpart = padEdges(x, slices, mode=mode, **kwargs) - output = irfftn(rfftn(xpart, nfft) * hfftconj, nfft) + output = irfftn( + rfftn(xpart, nfft, axes=np.arange(len(nfft))) * hfftconj, nfft, axes=np.arange(len(nfft)) + ) return output[tuple(slice(0, s) for s in lengths)] diff --git a/stingray/tests/test_base.py b/stingray/tests/test_base.py index bbeb1ae04..56624dc5a 100644 --- a/stingray/tests/test_base.py +++ b/stingray/tests/test_base.py @@ -1291,3 +1291,133 @@ def test_join_ignore_attr(self): assert np.allclose(ts_new.time, [1, 2, 3, 4, 5, 7]) assert not hasattr(ts_new, "instr") assert ts_new.mission == (1, 2) + + +class TestFillBTI(object): + @classmethod + def setup_class(cls): + cls.rand_time = np.sort(np.random.uniform(0, 1000, 100000)) + cls.rand_ener = np.random.uniform(0, 100, 100000) + cls.gti = [[0, 900], [950, 1000]] + blablas = np.random.normal(0, 1, cls.rand_ener.size) + cls.ev_like = StingrayTimeseries( + cls.rand_time, energy=cls.rand_ener, blablas=blablas, gti=cls.gti + ) + time_edges = np.linspace(0, 1000, 1001) + counts = np.histogram(cls.rand_time, bins=time_edges)[0] + blablas = np.random.normal(0, 1, 1000) + cls.lc_like = StingrayTimeseries( + time=time_edges[:-1] + 0.5, counts=counts, blablas=blablas, gti=cls.gti, dt=1 + ) + + def test_no_btis_returns_copy(self): + ts = StingrayTimeseries([1, 2, 3], energy=[4, 6, 8], gti=[[0.5, 3.5]]) + ts_new = ts.fill_bad_time_intervals() + assert ts == ts_new + + def test_event_like(self): + ev_like_filt = copy.deepcopy(self.ev_like) + # I 
introduce a small gap in the GTIs + ev_like_filt.gti = np.asarray([[0, 498], [500, 520], [522, 700], [702, 900], [950, 1000]]) + ev_new = ev_like_filt.fill_bad_time_intervals() + + assert np.allclose(ev_new.gti, self.gti) + + # Now, I set the same GTIs as the original event list, and the data + # should be the same + ev_new.gti = ev_like_filt.gti + + new_masked, filt_masked = ev_new.apply_gtis(), ev_like_filt.apply_gtis() + for attr in ["time", "energy", "blablas"]: + assert np.allclose(getattr(new_masked, attr), getattr(filt_masked, attr)) + + def test_lc_like(self): + lc_like_filt = copy.deepcopy(self.lc_like) + # I introduce a small gap in the GTIs + lc_like_filt.gti = np.asarray([[0, 498], [500, 520], [522, 700], [702, 900], [950, 1000]]) + lc_new = lc_like_filt.fill_bad_time_intervals() + assert np.allclose(lc_new.gti, self.gti) + + lc_like_gtifilt = self.lc_like.apply_gtis(inplace=False) + # In this case, the time array should also be the same as the original + assert np.allclose(lc_new.time, lc_like_gtifilt.time) + + # Now, I set the same GTIs as the original event list, and the data + # should be the same + lc_new.gti = lc_like_filt.gti + + new_masked, filt_masked = lc_new.apply_gtis(), lc_like_filt.apply_gtis() + for attr in ["time", "counts", "blablas"]: + assert np.allclose(getattr(new_masked, attr), getattr(filt_masked, attr)) + + def test_ignore_attrs_ev_like(self): + ev_like_filt = copy.deepcopy(self.ev_like) + # I introduce a small gap in the GTIs + ev_like_filt.gti = np.asarray([[0, 498], [500, 900], [950, 1000]]) + ev_new0 = ev_like_filt.fill_bad_time_intervals(seed=1234) + ev_new1 = ev_like_filt.fill_bad_time_intervals(seed=1234, attrs_to_randomize=["energy"]) + assert np.allclose(ev_new0.gti, ev_new1.gti) + assert np.allclose(ev_new0.time, ev_new1.time) + + assert np.count_nonzero(np.isnan(ev_new0.blablas)) == 0 + assert np.count_nonzero(np.isnan(ev_new1.blablas)) > 0 + assert np.count_nonzero(np.isnan(ev_new1.energy)) == 0 + + def test_ignore_attrs_lc_like(self): + lc_like_filt = copy.deepcopy(self.lc_like) + # I introduce a small gap in the GTIs + lc_like_filt.gti = np.asarray([[0, 498], [500, 900], [950, 1000]]) + lc_new0 = lc_like_filt.fill_bad_time_intervals(seed=1234) + lc_new1 = lc_like_filt.fill_bad_time_intervals(seed=1234, attrs_to_randomize=["counts"]) + assert np.allclose(lc_new0.gti, lc_new1.gti) + assert np.allclose(lc_new0.time, lc_new1.time) + + assert np.count_nonzero(np.isnan(lc_new0.blablas)) == 0 + assert np.count_nonzero(np.isnan(lc_new1.blablas)) > 0 + assert np.count_nonzero(np.isnan(lc_new1.counts)) == 0 + + def test_forcing_non_uniform(self): + ev_like_filt = copy.deepcopy(self.ev_like) + # I introduce a small gap in the GTIs + ev_like_filt.gti = np.asarray([[0, 498], [500, 900], [950, 1000]]) + # Results should be exactly the same + ev_new0 = ev_like_filt.fill_bad_time_intervals(even_sampling=False, seed=201903) + ev_new1 = ev_like_filt.fill_bad_time_intervals(even_sampling=None, seed=201903) + for attr in ["time", "energy"]: + assert np.allclose(getattr(ev_new0, attr), getattr(ev_new1, attr)) + + def test_forcing_uniform(self): + lc_like_filt = copy.deepcopy(self.lc_like) + # I introduce a small gap in the GTIs + lc_like_filt.gti = np.asarray([[0, 498], [500, 900], [950, 1000]]) + # Results should be exactly the same + lc_new0 = lc_like_filt.fill_bad_time_intervals(even_sampling=True, seed=201903) + lc_new1 = lc_like_filt.fill_bad_time_intervals(even_sampling=None, seed=201903) + for attr in ["time", "counts", "blablas"]: + assert 
np.allclose(getattr(lc_new0, attr), getattr(lc_new1, attr)) + + def test_bti_close_to_edge_event_like(self): + ev_like_filt = copy.deepcopy(self.ev_like) + # I introduce a small gap in the GTIs + ev_like_filt.gti = np.asarray([[0, 0.5], [1, 900], [950, 1000]]) + ev_new = ev_like_filt.fill_bad_time_intervals() + assert np.allclose(ev_new.gti, self.gti) + + ev_like_filt = copy.deepcopy(self.ev_like) + # I introduce a small gap in the GTIs + ev_like_filt.gti = np.asarray([[0, 900], [950, 999], [999.5, 1000]]) + ev_new = ev_like_filt.fill_bad_time_intervals() + assert np.allclose(ev_new.gti, self.gti) + + def test_bti_close_to_edge_lc_like(self): + lc_like_filt = copy.deepcopy(self.lc_like) + # I introduce a small gap in the GTIs + lc_like_filt.gti = np.asarray([[0, 0.5], [1, 900], [950, 1000]]) + lc_new = lc_like_filt.fill_bad_time_intervals() + assert np.allclose(lc_new.gti, self.gti) + + lc_like_filt = copy.deepcopy(self.lc_like) + # I introduce a small gap in the GTIs + lc_like_filt.gti = np.asarray([[0, 900], [950, 999], [999.5, 1000]]) + lc_new = lc_like_filt.fill_bad_time_intervals() + assert np.allclose(lc_new.gti, self.gti) diff --git a/stingray/tests/test_lightcurve.py b/stingray/tests/test_lightcurve.py index 68e03d71d..6b770bbd3 100644 --- a/stingray/tests/test_lightcurve.py +++ b/stingray/tests/test_lightcurve.py @@ -1148,9 +1148,10 @@ def test_plot_simple(self): def test_plot_wrong_label_type(self): lc = Lightcurve(self.times, self.counts) - with pytest.raises(TypeError): - with pytest.warns(UserWarning, match="must be either a list or tuple") as w: - lc.plot(labels=123) + with pytest.warns( + UserWarning, match="``labels`` must be an iterable with two labels " + ) as w: + lc.plot(labels=123) plt.close("all") def test_plot_labels_index_error(self): @@ -1158,7 +1159,9 @@ def test_plot_labels_index_error(self): with pytest.warns(UserWarning) as w: lc.plot(labels=("x")) - assert np.any(["must have two labels" in str(wi.message) for wi in w]) + assert np.any( + ["``labels`` must be an iterable with two labels " in str(wi.message) for wi in w] + ) plt.close("all") def test_plot_default_filename(self): @@ -1175,14 +1178,19 @@ def test_plot_custom_filename(self): os.unlink("lc.png") plt.close("all") - def test_plot_axis(self): + def test_plot_axis_arg(self): lc = Lightcurve(self.times, self.counts) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=UserWarning) + with pytest.warns(DeprecationWarning, match="argument is deprecated in favor"): lc.plot(axis=[0, 1, 0, 100]) assert plt.fignum_exists(1) plt.close("all") + def test_plot_axis_limits_arg(self): + lc = Lightcurve(self.times, self.counts) + lc.plot(axis_limits=[0, 1, 0, 100]) + assert plt.fignum_exists(1) + plt.close("all") + def test_plot_title(self): lc = Lightcurve(self.times, self.counts) with warnings.catch_warnings(): diff --git a/stingray/utils.py b/stingray/utils.py index ce0adb298..7f0b16312 100644 --- a/stingray/utils.py +++ b/stingray/utils.py @@ -1299,7 +1299,7 @@ def nearest_power_of_two(x): return x_nearest -def find_nearest(array, value): +def find_nearest(array, value, side="left"): """ Return the array value that is closest to the input value (Abigail Stevens: Thanks StackOverflow!) @@ -1313,6 +1313,11 @@ def find_nearest(array, value): value : int or float The value you want to find the closest to in the array. + Other Parameters + ---------------- + side : str + Look at the ``numpy.searchsorted`` documentation for more information. 
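+ In practice, ``side`` only matters when ``value`` exactly matches one or more array + elements: with ``side="left"`` (the default) the index of the first matching element is + returned, with ``side="right"`` the index of the last. 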
+ Returns ------- array[idx] : int or float @@ -1322,7 +1327,7 @@ def find_nearest(array, value): The index of the array of the closest value. """ - idx = np.searchsorted(array, value, side="left") + idx = np.searchsorted(array, value, side=side) if idx == len(array) or np.fabs(value - array[idx - 1]) < np.fabs(value - array[idx]): return array[idx - 1], idx - 1 else: