Skip to content

Commit

Permalink
Fix fftfit_classic against PRESTO
Browse files Browse the repository at this point in the history
  • Loading branch information
aarchiba committed Jan 10, 2021
1 parent 351a76a commit 8efe495
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 5 deletions.
5 changes: 3 additions & 2 deletions src/pint/profile/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ def fftfit_cprof(template):
of that returned by ``np.fft.rfft``.
"""
tc = rfft(template)
return tc[0], np.abs(tc)[1:], -np.angle(tc)[1:]
tc *= np.exp(-2.j*np.pi*np.arange(len(tc))/len(template))
return 2*tc, 2*np.abs(tc)[1:], -np.angle(tc)[1:]


def fftfit_classic(profile, template_amplitudes, template_angles, code="aarchiba"):
Expand Down Expand Up @@ -270,7 +271,7 @@ def fftfit_classic(profile, template_amplitudes, template_angles, code="aarchiba
template = irfft(template_f)
r = fftfit_full(template, profile, code=code)

shift = (r.shift % 1) * len(profile) + 1
shift = (r.shift % 1) * len(profile)
eshift = r.uncertainty * len(profile)
snr = np.nan
esnr = np.nan
Expand Down
189 changes: 189 additions & 0 deletions src/pint/profile/fftfit_nustar.py~aside
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# by mateobachetti
# from https://github.com/NuSTAR/nustar-clock-utils/blob/master/nuclockutils/diagnostics/fftfit.py
import numpy as np
from scipy.optimize import minimize, brentq


def find_delay_with_ccf(amp, pha):
nh = 32
nprof = nh * 2
CCF = np.zeros(64, dtype=np.complex)
CCF[:nh] = amp[:nh] * np.cos(pha[:nh]) + 1.0j * amp[:nh] * np.sin(pha[:nh])
CCF[nprof : nprof - nh : -1] = np.conj(CCF[nprof : nprof - nh : -1])
CCF[nh // 2 : nh] = 0
CCF[nprof - nh // 2 : nprof - nh : -1] = 0
ccf = np.fft.ifft(CCF)

imax = np.argmax(ccf.real)
cmax = ccf[imax]
shift = normalize_phase_0d5(imax / nprof)

# plt.figure()
# plt.plot(ccf.real)
# plt.show()
# fb=np.real(cmax)
# ia=imax-1
# if(ia == -1): ia=nprof-1
# fa=np.real(ccf[ia])
# ic=imax+1
# if(ic == nprof): ic=0
# fc=np.real(ccf[ic])
# if ((2*fb-fc-fa) != 0):
# shift=imax+0.5*(fa-fc)/(2*fb-fc-fa)
# shift = normalize_phase_0d5(shift / nprof)
return shift


def best_phase_func(tau, amp, pha, ngood=20):
# tau = params['tau']
# good = slice(1, idx.size // 2 + 1)
good = slice(1, ngood + 1)
idx = np.arange(1, ngood + 1, dtype=int)
res = np.sum(idx * amp[good] * np.sin(-pha[good] + TWOPI * idx * tau))
# print(tau, res)
return res


TWOPI = 2 * np.pi


def chi_sq(b, tau, P, S, theta, phi, ngood=20):
# tau = params['tau']
# good = slice(1, idx.size // 2 + 1)
good = slice(1, ngood + 1)
idx = np.arange(1, ngood + 1, dtype=int)
angle_diff = phi[good] - theta[good] + TWOPI * idx * tau
exp_term = np.exp(1.0j * angle_diff)

to_square = P[good] - b * S[good] * exp_term
res = np.sum((to_square * to_square.conj()))

return res.real


def chi_sq_alt(b, tau, P, S, theta, phi, ngood=20):
# tau = params['tau']
# good = slice(1, idx.size // 2 + 1)
good = slice(1, ngood + 1)
idx = np.arange(1, ngood + 1, dtype=int)
angle_diff = phi[good] - theta[good] + TWOPI * idx * tau
chisq_1 = P[good] ** 2 + b ** 2 * S[good] ** 2
chisq_2 = -2 * b * P[good] * S[good] * np.cos(angle_diff)
res = np.sum(chisq_1 + chisq_2)

return res


def fftfit(prof, template):
"""Align a template to a pulse profile.
Parameters
----------
prof : array
The pulse profile
template : array, default None
The template of the pulse used to perform the TOA calculation. If None,
a simple sinusoid is used
Returns
-------
mean_amp, std_amp : floats
Mean and standard deviation of the amplitude
mean_phase, std_phase : floats
Mean and standard deviation of the phase
"""
prof = prof - np.mean(prof)

nbin = len(prof)

template = template - np.mean(template)

temp_ft = np.fft.fft(template)
prof_ft = np.fft.fft(prof)
freq = np.fft.fftfreq(prof.size)
good = freq == freq

P = np.abs(prof_ft[good])
theta = np.angle(prof_ft[good])
S = np.abs(temp_ft[good])
phi = np.angle(temp_ft[good])

assert np.allclose(temp_ft[good], S * np.exp(1.0j * phi))
assert np.allclose(prof_ft[good], P * np.exp(1.0j * theta))

amp = P * S
pha = theta - phi

mean = np.mean(amp)
ngood = np.count_nonzero(amp >= mean)

dph_ccf = find_delay_with_ccf(amp, pha)

idx = np.arange(0, len(P), dtype=int)
sigma = np.std(prof_ft[good])

def func_to_minimize(tau):
return best_phase_func(-tau, amp, pha, ngood=ngood)

start_val = dph_ccf
start_sign = np.sign(func_to_minimize(start_val))

count_down = 0
count_up = 0
trial_val_up = start_val
trial_val_down = start_val
while True:
if np.sign(func_to_minimize(trial_val_up)) != start_sign:
best_dph = trial_val_up
break
if np.sign(func_to_minimize(trial_val_down)) != start_sign:
best_dph = trial_val_down
break
trial_val_down -= 1 / nbin
count_down += 1
trial_val_up += 1 / nbin
count_up += 1

a, b = best_dph - 2 / nbin, best_dph + 2 / nbin

shift, res = brentq(func_to_minimize, a, b, full_output=True)

nmax = ngood
good = slice(1, nmax)

big_sum = np.sum(
idx[good] ** 2 * amp[good] * np.cos(-pha[good] + 2 * np.pi * idx[good] * -shift)
)

b = np.sum(
amp[good] * np.cos(-pha[good] + 2 * np.pi * idx[good] * -shift)
) / np.sum(S[good] ** 2)

eshift = sigma ** 2 / (2 * b * big_sum)

eb = sigma ** 2 / (2 * np.sum(S[good] ** 2))

return b, np.sqrt(eb), normalize_phase_0d5(shift), np.sqrt(eshift)


def normalize_phase_0d5(phase):
"""Normalize phase between -0.5 and 0.5
Examples
--------
>>> normalize_phase_0d5(0.5)
0.5
>>> normalize_phase_0d5(-0.5)
0.5
>>> normalize_phase_0d5(4.25)
0.25
>>> normalize_phase_0d5(-3.25)
-0.25
"""
while phase > 0.5:
phase -= 1
while phase <= -0.5:
phase += 1
return phase


def fftfit_basic(template, profile):
n, seb, shift, eshift = fftfit(profile, template)
return shift
7 changes: 4 additions & 3 deletions src/pint/profile/fftfit_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ def fftfit_full(template, profile):
raise ValueError(
"template has length {} which is too long".format(len(template))
)
_, amp, pha = pint.profile.fftfit_cprof(template)
#_, amp, pha = pint.profile.fftfit_cprof(template)
_, amp, pha = presto.fftfit.cprof(template)
shift, eshift, snr, esnr, b, errb, ngood = presto.fftfit.fftfit(
profile,
profile, amp, pha
)
r = pint.profile.FFTFITResult()
# Need to add 1 to the shift for some reason
r.shift = pint.profile.wrap((shift + 1) / len(template))
r.shift = pint.profile.wrap(shift / len(template))
r.uncertainty = eshift / len(template)
return r

Expand Down
39 changes: 39 additions & 0 deletions tests/test_fftfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
fftfit_full,
fftfit_nustar,
fftfit_presto,
fftfit_cprof,
fftfit_classic,
)

NO_PRESTO = fftfit_presto.presto is None
Expand Down Expand Up @@ -860,3 +862,40 @@ def value_within_one_sigma():

# Must be pessimistic
assert_happens_with_probability(value_within_one_sigma, ONE_SIGMA, p_upper=1)


@pytest.mark.parametrize("n", [32,128,1024])
def test_fftfit_cprof_compare(n):
template = vonmises_profile(100, n)
import presto.fftfit
c, amp, pha = fftfit_cprof(template)
cp, ampp, phap = presto.fftfit.cprof(template)

assert_allclose(c, cp, atol=5e-6)
assert_allclose(amp, ampp, atol=5e-6)
assert_allclose(cp[1:], ampp*np.exp(1.j*phap), atol=5e-6)


def test_fftfit_classic_runs():
template = vonmises_profile(100, 1024)
_, amp, pha = fftfit_cprof(template)
shift, eshift, snr, esnr, b, errb, ngood = fftfit_classic(template, amp, pha, code="presto")


@pytest.mark.parametrize("n,s", [(32, 0.1),(128,0.3),(1024,0.05)])
def test_fftfit_classic_compare(n, s):
template = vonmises_profile(100, n)
_, amp, pha = fftfit_cprof(template)

profile = pint.profile.shift(template, s) + 1e-3*np.random.randn(len(template))

shift, eshift, snr, esnr, b, errb, ngood = fftfit_classic(profile, amp, pha, code="aarchiba")
shift_p, eshift_p, snr_p, esnr_p, b_p, errb_p, ngood_p = fftfit_classic(profile, amp, pha, code="presto")

assert_allclose_phase(shift, shift_p, atol=1e-2)
#assert_allclose(eshift, eshift_p)
#assert_allclose(snr, snr_p)
#assert_allclose(esnr, esnr_p)
#assert_allclose(b, b_p)
#assert_allclose(errb, errb_p)
#assert_allclose(ngood, ngood_p)

0 comments on commit 8efe495

Please sign in to comment.