diff --git a/stingray/crosscorrelation.py b/stingray/crosscorrelation.py index bb42cc03d..8d5b196da 100644 --- a/stingray/crosscorrelation.py +++ b/stingray/crosscorrelation.py @@ -76,7 +76,6 @@ class CrossCorrelation(object): References ---------- .. [scipy-docs] https://docs.scipy.org/doc/scipy-0.19.0/reference/generated/scipy.signal.correlate.html - """ def __init__(self, lc1=None, lc2=None, cross=None, mode='same', norm="none"): @@ -211,6 +210,12 @@ def cal_timeshift(self, dt=1.0): """ Calculate the cross correlation against all possible time lags, both positive and negative. + The method signal.correlation_lags() uses SciPy versions >= 1.6.1 ([scipy-docs-lag]_) + + References + ---------- + .. [scipy-docs-lag] https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.correlation_lags.html + Parameters ---------- dt: float, optional, default ``1.0`` @@ -244,9 +249,20 @@ def cal_timeshift(self, dt=1.0): self._make_corr(self.lc1, self.lc2) self.n = len(self.corr) - dur = int(self.n / 2) - # Correlation against all possible lags, positive as well as negative lags are stored - x_lags = np.linspace(-dur, dur, self.n) + + if self.cross is not None: + # Obtains correlation lags if a cross spectrum object is given + # Correlation against all possible lags, positive as well as negative lags are stored + # signal.correlation_lags() method uses SciPy versions >= 1.6.1 + x_lags = signal.correlation_lags(self.n, self.n, self.mode) + + else: + # Obtains correlation lags if two light curves are porvided + # Correlation against all possible lags, positive as well as negative lags are stored + # signal.correlation_lags() method uses SciPy versions >= 1.6.1 + x_lags = \ + signal.correlation_lags(np.size(self.lc1.counts), np.size(self.lc2.counts), self.mode) + self.time_lags = x_lags * self.dt # time_shift is the time lag for max. correlation self.time_shift = self.time_lags[np.argmax(self.corr)] diff --git a/stingray/tests/test_crosscorrelation.py b/stingray/tests/test_crosscorrelation.py index 4625c2bcf..f9a1c4ef4 100644 --- a/stingray/tests/test_crosscorrelation.py +++ b/stingray/tests/test_crosscorrelation.py @@ -47,6 +47,10 @@ def setup_class(cls): cls.lc_s = Lightcurve([1, 2, 3], [5, 3, 2]) # lc with different time resolution cls.lc_u = Lightcurve([1, 3, 5, 7, 9], [4, 8, 1, 9, 11]) + # Light curve with odd number of data points + cls.lc_odd = Lightcurve([1, 2, 3, 4, 5], [2, 3, 2, 4, 1]) + # Light curve with even number of data points + cls.lc_even = Lightcurve([1, 2, 3, 4, 5, 6], [2, 3, 2, 4, 1, 3]) def test_empty_cross_correlation(self): cr = CrossCorrelation() @@ -117,15 +121,15 @@ def test_crossparam_input(self): def test_cross_correlation_with_unequal_lc(self): result = np.array([-0.66666667, -0.33333333, -1., 0.66666667, 3.13333333]) - lags_result = np.array([-2, -1, 0, 1, 2]) + lags_result = np.array([-1., 0., 1., 2., 3.]) cr = CrossCorrelation(self.lc1, self.lc_s) assert np.allclose(cr.lc1, self.lc1) assert np.allclose(cr.lc2, self.lc_s) assert np.allclose(cr.corr, result) assert np.isclose(cr.dt, self.lc1.dt) assert cr.n == 5 + assert np.isclose(cr.time_shift, 3.0) assert np.allclose(cr.time_lags, lags_result) - assert np.isclose(cr.time_shift, 2.0) assert cr.mode == 'same' assert cr.auto is False @@ -252,3 +256,29 @@ def test_auto_correlation_with_full_mode(self): assert np.isclose(ac.time_shift, 0.0) assert ac.mode == 'full' assert ac.auto is True + + def test_cross_correlation_with_identical_lc_oddlength(self): + result = np.array([ 1.68, -3.36, 5.2, -3.36, 1.68]) + lags_result = np.array([-2, -1, 0, 1, 2]) + cr = CrossCorrelation(self.lc_odd,self.lc_odd) + assert np.allclose(cr.lc1, cr.lc2) + assert np.allclose(cr.corr, result) + assert np.isclose(cr.dt, self.lc_odd.dt) + assert cr.n == 5 + assert np.allclose(cr.time_lags, lags_result) + assert np.isclose(cr.time_shift,0.0) + assert cr.mode == 'same' + assert cr.auto is False + + def test_cross_correlation_with_identical_lc_evenlength(self): + result = np.array([-1.75, 2.5, -4.25, 5.5, -4.25, 2.5]) + lags_result = np.array([-3, -2, -1, 0, 1, 2]) + cr = CrossCorrelation(self.lc_even,self.lc_even) + assert np.allclose(cr.lc1, cr.lc2) + assert np.allclose(cr.corr, result) + assert np.isclose(cr.dt, self.lc_even.dt) + assert cr.n == 6 + assert np.allclose(cr.time_lags, lags_result) + assert np.isclose(cr.time_shift,0.0) + assert cr.mode == 'same' + assert cr.auto is False