Skip to content

Commit

Permalink
Merge pull request #647 from mihirtripathi97/bugfix-for-issue592
Browse files Browse the repository at this point in the history
Bugfix for issue- CrossCorrelation Gives Strange Results (#592)
  • Loading branch information
matteobachetti authored Mar 18, 2022
2 parents a6dec24 + 33dee76 commit cde201a
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 6 deletions.
24 changes: 20 additions & 4 deletions stingray/crosscorrelation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ class CrossCorrelation(object):
References
----------
.. [scipy-docs] https://docs.scipy.org/doc/scipy-0.19.0/reference/generated/scipy.signal.correlate.html
"""

def __init__(self, lc1=None, lc2=None, cross=None, mode='same', norm="none"):
Expand Down Expand Up @@ -211,6 +210,12 @@ def cal_timeshift(self, dt=1.0):
"""
Calculate the cross correlation against all possible time lags, both positive and negative.
The method signal.correlation_lags() uses SciPy versions >= 1.6.1 ([scipy-docs-lag]_)
References
----------
.. [scipy-docs-lag] https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.correlation_lags.html
Parameters
----------
dt: float, optional, default ``1.0``
Expand Down Expand Up @@ -244,9 +249,20 @@ def cal_timeshift(self, dt=1.0):
self._make_corr(self.lc1, self.lc2)

self.n = len(self.corr)
dur = int(self.n / 2)
# Correlation against all possible lags, positive as well as negative lags are stored
x_lags = np.linspace(-dur, dur, self.n)

if self.cross is not None:
# Obtains correlation lags if a cross spectrum object is given
# Correlation against all possible lags, positive as well as negative lags are stored
# signal.correlation_lags() method uses SciPy versions >= 1.6.1
x_lags = signal.correlation_lags(self.n, self.n, self.mode)

else:
# Obtains correlation lags if two light curves are porvided
# Correlation against all possible lags, positive as well as negative lags are stored
# signal.correlation_lags() method uses SciPy versions >= 1.6.1
x_lags = \
signal.correlation_lags(np.size(self.lc1.counts), np.size(self.lc2.counts), self.mode)

self.time_lags = x_lags * self.dt
# time_shift is the time lag for max. correlation
self.time_shift = self.time_lags[np.argmax(self.corr)]
Expand Down
34 changes: 32 additions & 2 deletions stingray/tests/test_crosscorrelation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def setup_class(cls):
cls.lc_s = Lightcurve([1, 2, 3], [5, 3, 2])
# lc with different time resolution
cls.lc_u = Lightcurve([1, 3, 5, 7, 9], [4, 8, 1, 9, 11])
# Light curve with odd number of data points
cls.lc_odd = Lightcurve([1, 2, 3, 4, 5], [2, 3, 2, 4, 1])
# Light curve with even number of data points
cls.lc_even = Lightcurve([1, 2, 3, 4, 5, 6], [2, 3, 2, 4, 1, 3])

def test_empty_cross_correlation(self):
cr = CrossCorrelation()
Expand Down Expand Up @@ -117,15 +121,15 @@ def test_crossparam_input(self):

def test_cross_correlation_with_unequal_lc(self):
result = np.array([-0.66666667, -0.33333333, -1., 0.66666667, 3.13333333])
lags_result = np.array([-2, -1, 0, 1, 2])
lags_result = np.array([-1., 0., 1., 2., 3.])
cr = CrossCorrelation(self.lc1, self.lc_s)
assert np.allclose(cr.lc1, self.lc1)
assert np.allclose(cr.lc2, self.lc_s)
assert np.allclose(cr.corr, result)
assert np.isclose(cr.dt, self.lc1.dt)
assert cr.n == 5
assert np.isclose(cr.time_shift, 3.0)
assert np.allclose(cr.time_lags, lags_result)
assert np.isclose(cr.time_shift, 2.0)
assert cr.mode == 'same'
assert cr.auto is False

Expand Down Expand Up @@ -252,3 +256,29 @@ def test_auto_correlation_with_full_mode(self):
assert np.isclose(ac.time_shift, 0.0)
assert ac.mode == 'full'
assert ac.auto is True

def test_cross_correlation_with_identical_lc_oddlength(self):
result = np.array([ 1.68, -3.36, 5.2, -3.36, 1.68])
lags_result = np.array([-2, -1, 0, 1, 2])
cr = CrossCorrelation(self.lc_odd,self.lc_odd)
assert np.allclose(cr.lc1, cr.lc2)
assert np.allclose(cr.corr, result)
assert np.isclose(cr.dt, self.lc_odd.dt)
assert cr.n == 5
assert np.allclose(cr.time_lags, lags_result)
assert np.isclose(cr.time_shift,0.0)
assert cr.mode == 'same'
assert cr.auto is False

def test_cross_correlation_with_identical_lc_evenlength(self):
result = np.array([-1.75, 2.5, -4.25, 5.5, -4.25, 2.5])
lags_result = np.array([-3, -2, -1, 0, 1, 2])
cr = CrossCorrelation(self.lc_even,self.lc_even)
assert np.allclose(cr.lc1, cr.lc2)
assert np.allclose(cr.corr, result)
assert np.isclose(cr.dt, self.lc_even.dt)
assert cr.n == 6
assert np.allclose(cr.time_lags, lags_result)
assert np.isclose(cr.time_shift,0.0)
assert cr.mode == 'same'
assert cr.auto is False

0 comments on commit cde201a

Please sign in to comment.