Skip to content

Commit

Permalink
Final changes for Jeremy
Browse files Browse the repository at this point in the history
  • Loading branch information
GourlieK committed Mar 12, 2024
1 parent 7f12540 commit d6deb3b
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions hasasia/sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import scipy.linalg as sl
import os, pickle
from astropy import units as u
import jax.numpy as jnp
import jax.scipy as jsc

import hasasia
from .utils import create_design_matrix
Expand Down Expand Up @@ -62,14 +64,14 @@ def R_matrix(designmatrix, N):
n,m = M.shape
L = np.linalg.cholesky(N)
Linv = np.linalg.inv(L)
U,s,_ = np.linalg.svd(np.matmul(Linv,M), full_matrices=True)
U,s,_ = np.linalg.svd(jnp.matmul(Linv,M), full_matrices=True)
Id = np.eye(M.shape[0])
S = np.zeros_like(M)
S[:m,:m] = np.diag(s)
inner = np.linalg.inv(np.matmul(S.T,S))
outer = np.matmul(S,np.matmul(inner,S.T))
inner = np.linalg.inv(jnp.matmul(S.T,S))
outer = jnp.matmul(S,jnp.matmul(inner,S.T))

return Id - np.matmul(L,np.matmul(np.matmul(U,outer),np.matmul(U.T,Linv)))
return Id - jnp.matmul(L,jnp.matmul(jnp.matmul(U,outer),jnp.matmul(U.T,Linv)))

def G_matrix(designmatrix):
"""
Expand Down Expand Up @@ -169,7 +171,7 @@ def get_Tf(designmatrix, toas, N=None, nf=200, fmin=None, fmax=2e-7,
m = G.shape[1]
Gtilde = np.zeros((ff.size,G.shape[1]),dtype='complex128')
Gtilde = np.dot(np.exp(1j*2*np.pi*ff[:,np.newaxis]*toas),G)
Tmat = np.matmul(np.conjugate(Gtilde),Gtilde.T)/N_TOA
Tmat = jnp.matmul(np.conjugate(Gtilde),Gtilde.T)/N_TOA
if twofreqs:
Tmat = np.real(Tmat)
else:
Expand Down Expand Up @@ -261,10 +263,14 @@ def get_NcalInv(psr, nf=200, fmin=None, fmax=2e-7, freqs=None,
Gtilde = np.dot(np.exp(1j*2*np.pi*ff[:,np.newaxis]*toas),G)
# N_freq x N_TOA-N_par

Ncal = np.matmul(G.T,np.matmul(psr.N,G)) #N_TOA-N_par x N_TOA-N_par
NcalInv = np.linalg.inv(Ncal) #N_TOA-N_par x N_TOA-N_par

TfN = np.matmul(np.conjugate(Gtilde),np.matmul(NcalInv,Gtilde.T)) / 2
L = jsc.linalg.cholesky(psr.N)
A = jnp.matmul(L,G)
del L
Ncal = jnp.matmul(A.T,A)
del A
NcalInv = jnp.linalg.inv(Ncal)

TfN = jnp.matmul(np.conjugate(Gtilde),jnp.matmul(NcalInv,Gtilde.T)) / 2
if return_Gtilde_Ncal:
return np.real(TfN), Gtilde, Ncal
elif full_matrix:
Expand Down Expand Up @@ -810,7 +816,7 @@ def get_NcalInvIJ(psrs, A_GWB, freqs, full_matrix=False,
# C_h = sl.block_diag(*[corr_from_psd(freqs=freqs, psd=psd,
# toas=p.toas, fast=True) for p in psrs])
C = C_n + C_h
Ncal = np.matmul(G.T, np.matmul(C, G)) #N_TOA-N_par x N_TOA-N_par
Ncal = jnp.matmul(G.T, jnp.matmul(C, G)) #N_TOA-N_par x N_TOA-N_par
NcalInv = np.linalg.inv(Ncal) #N_TOA-N_par x N_TOA-N_par

TfN = NcalInv#np.matmul(G, np.matmul(NcalInv, G.T))
Expand Down Expand Up @@ -1066,7 +1072,7 @@ def corr_from_psd(freqs, psd, toas, fast=True):
df = np.diff(freqs)
df = np.append(df,df[-1])
tm = np.sqrt(psd*df)*np.exp(1j*2*np.pi*freqs*toas[:,np.newaxis])
integrand = np.matmul(tm, np.conjugate(tm.T))
integrand = jnp.matmul(tm, np.conjugate(tm.T))
return np.real(integrand)
else: #Makes much larger arrays, but uses np.trapz
t1, t2 = np.meshgrid(toas, toas, indexing='ij')
Expand Down Expand Up @@ -1107,7 +1113,7 @@ def corr_from_psdIJ(freqs, psd, toasI, toasJ, fast=True):
df = np.append(df,df[-1])
tmI = np.sqrt(psd*df)*np.exp(1j*2*np.pi*freqs*toasI[:,np.newaxis])
tmJ = np.sqrt(psd*df)*np.exp(1j*2*np.pi*freqs*toasJ[:,np.newaxis])
integrand = np.matmul(tmI, np.conjugate(tmJ.T))
integrand = jnp.matmul(tmI, np.conjugate(tmJ.T))
return np.real(integrand)
else: #Makes much larger arrays, but uses np.trapz
t1, t2 = np.meshgrid(toasI, toasJ, indexing='ij')
Expand Down

0 comments on commit d6deb3b

Please sign in to comment.