From f9afbee761c4ac1f4f03d958bfe9d4dff1c9bc02 Mon Sep 17 00:00:00 2001 From: Matteo Bachetti Date: Mon, 25 Sep 2023 12:48:51 +0200 Subject: [PATCH] Allow for complex values in flux iterables --- stingray/fourier.py | 8 ++++++-- stingray/tests/test_fourier.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/stingray/fourier.py b/stingray/fourier.py index 19dddc97e..a0f09a84b 100644 --- a/stingray/fourier.py +++ b/stingray/fourier.py @@ -1077,6 +1077,10 @@ def get_flux_iterable_from_segments(times, gti, segment_size, n_bin=None, fluxes binned = fluxes is not None if binned: dt = np.median(np.diff(times[:100])) + cast_kind = float + fluxes = np.asarray(fluxes) + if np.iscomplexobj(fluxes): + cast_kind = complex fun = _which_segment_idx_fun(binned, dt) @@ -1093,9 +1097,9 @@ def get_flux_iterable_from_segments(times, gti, segment_size, n_bin=None, fluxes ).astype(float) cts = np.array(cts) else: - cts = fluxes[idx0:idx1].astype(float) + cts = fluxes[idx0:idx1].astype(cast_kind) if errors is not None: - cts = cts, errors[idx0:idx1] + cts = cts, errors[idx0:idx1].astype(cast_kind) yield cts diff --git a/stingray/tests/test_fourier.py b/stingray/tests/test_fourier.py index 480c9c2ba..3c4ca5e70 100644 --- a/stingray/tests/test_fourier.py +++ b/stingray/tests/test_fourier.py @@ -48,6 +48,23 @@ def test_norm(): assert np.isclose(pdsfrac[good].mean(), pois_frac, rtol=0.01) +@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.complex64, np.complex128]) +def test_flux_iterables(dtype): + times = np.arange(4) + fluxes = np.ones(4).astype(dtype) + errors = np.ones(4).astype(dtype) * np.sqrt(2) + gti = np.asarray([[-0.5, 3.5]]) + iter = get_flux_iterable_from_segments(times, gti, 2, n_bin=None, fluxes=fluxes, errors=errors) + cast_kind = float + if np.iscomplexobj(fluxes): + cast_kind = complex + for it, er in iter: + assert np.allclose(it, 1, rtol=0.01) + assert np.allclose(er, np.sqrt(2), rtol=0.01) + assert isinstance(it[0], cast_kind) + assert isinstance(er[0], cast_kind) + + class TestCoherence(object): @classmethod def setup_class(cls):