Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce experimental BasisCommonGP2 and FourierBasisCommonGP2 that accept selections #290

Open
wants to merge 8 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions enterprise/signals/gp_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,153 @@ def get_dm_chi2(self, params, use_mean_dm=False): # 'DM' chi-sqaured
return WidebandTimingModel


# experimental versions of FourierBasisCommonGP and BasisCommonGP2 that support selections
# note that Tspan must be provided to FourierBasisCommonGP2


def FourierBasisCommonGP2(
spectrum,
orf,
coefficients=False,
combine=True,
selection=Selection(selections.no_selection),
components=20,
Tspan=None,
modes=None,
name="common_fourier",
pshift=False,
pseed=None,
):
if coefficients:
raise NotImplementedError("Coefficients are not implemented yet.")

if Tspan is None:
raise ValueError("Please specify Tspan explicitly.")

basis = utils.createfourierdesignmatrix_red(nmodes=components, Tspan=Tspan, modes=modes, pshift=pshift, pseed=pseed)

return BasisCommonGP2(spectrum, basis, orf, combine=combine, selection=selection, name=name)


def BasisCommonGP2(
priorFunction,
basisFunction,
orfFunction,
coefficients=False,
combine=True,
selection=Selection(selections.no_selection),
name="",
):
if coefficients:
raise NotImplementedError("Coefficients are not implemented yet.")

class BasisCommonGP2(signal_base.CommonSignal):
signal_type = "common basis"
signal_name = "common"
signal_id = name

basis_combine = combine

def __init__(self, psr):
super(BasisCommonGP2, self).__init__(psr)
self.name = self.psrname + "_" + self.signal_id
self._do_selection(psr, priorFunction, basisFunction, orfFunction, selection)
self._psrpos = psr.pos

def _do_selection(self, psr, priorfn, basisfn, orffn, selection):
sel = selection(psr)

self._keys = sorted(sel.masks.keys())
self._masks = [sel.masks[key] for key in self._keys]
self._prior, self._bases, self._orf = {}, {}, {}
self._params, self._coefficients = {}, {}

for key, mask in zip(self._keys, self._masks):
pnames = [name, key]
pname = "_".join([n for n in pnames if n])

self._prior[key] = priorfn(pname, psr=psr)
self._bases[key] = basisfn(pname, psr=psr)
self._orf[key] = orffn(pname, psr=psr)

for par in itertools.chain(
self._prior[key]._params.values(),
self._bases[key]._params.values(),
self._orf[key]._params.values(),
):
self._params[par.name] = par

@property
def basis_params(self):
"""Get any varying basis parameters."""
ret = []
for basis in self._bases.values():
ret.extend([pp.name for pp in basis.params])
return ret

@signal_base.cache_call("basis_params", limit=1)
def _construct_basis(self, params={}):
basis, self._labels = {}, {}
for key, mask in zip(self._keys, self._masks):
basis[key], self._labels[key] = self._bases[key](params=params, mask=mask)

nc = sum(F.shape[1] for F in basis.values())
self._basis = np.zeros((len(self._masks[0]), nc))

# TODO: should this be defined here? it will cache phi
self._phi = KernelMatrix(nc)

self._slices = {}
nctot = 0
for key, mask in zip(self._keys, self._masks):
Fmat = basis[key]
nn = Fmat.shape[1]
self._basis[mask, nctot : nn + nctot] = Fmat
self._slices.update({key: slice(nctot, nn + nctot)})
nctot += nn

@property
def delay_params(self):
return []

def get_delay(self, params={}):
return 0

def get_basis(self, params={}):
self._construct_basis(params)

return self._basis

def get_phi(self, params):
self._construct_basis(params)

for key, slc in self._slices.items():
phislc = self._prior[key](self._labels[key], params=params)
orfslc = self._orf[key](self._psrpos, self._psrpos, params=params)

self._phi = self._phi.set(phislc * orfslc, slc)

return self._phi

@classmethod
def get_phicross(cls, signal1, signal2, params):
sl1, sl2 = [sum(slc.stop - slc.start for slc in signal._slices.values()) for signal in [signal1, signal2]]

phic = np.zeros((sl1, sl2))

for key in set(signal1._keys) & set(signal2._keys):
phislc = signal1._prior[key](signal1._labels[key], params=params)
orfslc = signal1._orf[key](signal1._psrpos, signal2._psrpos, params=params)

r1, r2 = [range(signal._slices[key].start, signal._slices[key].stop) for signal in [signal1, signal2]]

phic[r1, r2] = phislc * orfslc

return phic

return BasisCommonGP2


def MarginalizingTimingModel(name="marginalizing_linear_timing_model", use_svd=False, normed=True):
"""Class factory for marginalizing (fast-likelihood) linear timing model signals."""

Expand Down
56 changes: 24 additions & 32 deletions enterprise/signals/signal_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,8 @@
derived from these base classes.
"""
import collections
from collections.abc import Sequence

try:
from collections.abc import Sequence
except:
from collections import Sequence

import itertools
import logging

import numpy as np
Expand Down Expand Up @@ -275,18 +270,15 @@ def __add__(self, other):

@property
def params(self):
ret = set()

for signalcollection in self._signalcollections:
for param in signalcollection.params:
for par in param.params:
ret.add(par)
# return only one parameter with the same name
ret = {
par.name: par
for signalcollection in self._signalcollections
for param in signalcollection.params
for par in param.params
}

return sorted(list(ret), key=lambda par: par.name)

# return sorted({par for signalcollection in self._signalcollections
# for par in signalcollection.params},
# key=lambda par: par.name)
return sorted(ret.values(), key=lambda par: par.name)

@property
def param_names(self):
Expand Down Expand Up @@ -470,7 +462,7 @@ def get_phiinv_byfreq_partition(self, params, logdet=False):

if crossdiag.ndim == 2:
raise NotImplementedError(
"get_phiinv with method='partition' does not " "support dense phi matrices."
"get_phiinv with method='partition' does not support dense or rectangular phi matrices."
)

invert[:, i, j] += crossdiag
Expand Down Expand Up @@ -643,21 +635,23 @@ def get_phi(self, params, cliques=False):
if cliques:
self._setcliques(slices, csdict)

# now iterate over all pairs of common signal instances
pairs = itertools.combinations(csdict.items(), 2)
for cs1, csc1 in csdict.items():
for cs2, csc2 in csdict.items():
if cs1 != cs2:
crossdiag = csclass.get_phicross(cs1, cs2, params)

for (cs1, csc1), (cs2, csc2) in pairs:
crossdiag = csclass.get_phicross(cs1, cs2, params)
block1, idx1 = slices[csc1], csc1._idx[cs1]
block2, idx2 = slices[csc2], csc2._idx[cs2]

block1, idx1 = slices[csc1], csc1._idx[cs1]
block2, idx2 = slices[csc2], csc2._idx[cs2]
if crossdiag.ndim == 1:
Phi[block1, block2][idx1, idx2] += crossdiag
else:
if cliques and crossdiag.shape[0] != crossdiag.shape[1]:
raise NotImplementedError(
"get_phi with cliques=True does not support rectangular phicross matrices"
)

if crossdiag.ndim == 1:
Phi[block1, block2][idx1, idx2] += crossdiag
Phi[block2, block1][idx2, idx1] += crossdiag
else:
Phi[block1, block2][np.ix_(idx1, idx2)] += crossdiag
Phi[block2, block1][np.ix_(idx2, idx1)] += crossdiag
Phi[block1, block2][np.ix_(idx1, idx2)] += crossdiag

return Phi
else:
Expand Down Expand Up @@ -1062,7 +1056,6 @@ def _add_diag(self, other):
return self._binopt(other_diag, "_plus_")

def __add__(self, other):

if isinstance(other, (np.ndarray, ndarray_alt)) and other.ndim == 1:
return self._add_diag(other)
else:
Expand Down Expand Up @@ -1197,7 +1190,6 @@ def _get_logdet(self):
return logdet

def solve(self, other, left_array=None, logdet=False):

if other.ndim not in [1, 2]:
raise TypeError
if left_array is not None:
Expand Down
8 changes: 7 additions & 1 deletion tests/test_gp_coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def setUpClass(cls):

# initialize Pulsar class
cls.psr = Pulsar(datadir + "/B1855+09_NANOGrav_9yv1.gls.par", datadir + "/B1855+09_NANOGrav_9yv1.tim")

cls.psr2 = Pulsar(datadir + "/B1937+21_NANOGrav_9yv1.gls.par", datadir + "/B1937+21_NANOGrav_9yv1.tim")

def test_ephemeris(self):
Expand Down Expand Up @@ -392,6 +391,13 @@ def setUpClass(cls):
timing_package="pint",
)

cls.psr2 = Pulsar(
datadir + "/B1937+21_NANOGrav_9yv1.gls.par",
datadir + "/B1937+21_NANOGrav_9yv1.tim",
ephem="DE430",
timing_package="pint",
)

def test_ephemeris(self):
# skipping ephemeris with PINT
pass
70 changes: 69 additions & 1 deletion tests/test_gp_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import scipy.linalg as sl

from enterprise.pulsar import Pulsar
from enterprise.signals import gp_signals, parameter, selections, signal_base, utils
from enterprise.signals import gp_signals, white_signals, parameter, selections, signal_base, utils
from enterprise.signals.selections import Selection
from tests.enterprise_test_data import datadir

Expand Down Expand Up @@ -710,6 +710,70 @@ def test_combine_signals(self):
msg = "Basis matrix shape incorrect size for combined signal."
assert m.get_basis(params).shape == T.shape, msg

def test_gp_common_selection(self):
psr2 = Pulsar(datadir + "/B1937+21_NANOGrav_9yv1.gls.par", datadir + "/B1937+21_NANOGrav_9yv1.tim")

mn = white_signals.MeasurementNoise()

pl = utils.powerlaw(log10_A=parameter.Uniform(-20, -11), gamma=parameter.Uniform(0, 7))
orf = utils.hd_orf()
Tspan = max(self.psr.toas.max(), psr2.toas.max()) - min(self.psr.toas.min(), psr2.toas.min())

prn = gp_signals.FourierBasisGP(pl, components=1, Tspan=Tspan)
rn = gp_signals.FourierBasisCommonGP2(
pl, orf, selection=Selection(selections.by_telescope), components=1, Tspan=Tspan
)

model = mn + prn + rn
pta = signal_base.PTA([model(self.psr), model(psr2)])

telescopes = sorted(np.unique(psr2.telescope))

parnames = [par.name for par in pta.params]

msg = "Per-telescope common-noise parameters not in PTA"
assert all("common_fourier_{}_gamma".format(telescope) in parnames for telescope in telescopes), msg

p0 = parameter.sample(pta.params)

# will throw if there are problems
pta.get_lnlikelihood(params=p0, phiinv_method="sparse")
pta.get_lnprior(params=p0)

# should throw since phiinv_method is not 'sparse'
with self.assertRaises(NotImplementedError):
pta.get_lnlikelihood(params=p0)

msg = "Wrong nonzero element count in Phi matrices"
assert len(pta.pulsarmodels[0].get_phi(p0)) == 2, msg
assert len(pta.pulsarmodels[1].get_phi(p0)) == 6, msg

Phi = pta.get_phi(p0)
assert sum(sum(Phi != 0)) == 12, msg

# determine order of GP components in psr2
b0 = pta.pulsarmodels[1].get_basis()[:, 0] != 0
b1 = pta.pulsarmodels[1].get_basis()[:, 2] != 0
b2 = pta.pulsarmodels[1].get_basis()[:, 4] != 0

# a0 is arecibo/ao since telescopes is sorted
a0 = [np.all(b == (psr2.telescope == telescopes[0])) for b in [b0, b1, b2]]
a1 = [np.all(b == (psr2.telescope == telescopes[1])) for b in [b0, b1, b2]]

msg = "Wrong telescope masks for psr2"
assert sum(a0) == sum(a1) == 1, msg

# check cross-pulsar correlations are in the right place
i = len(pta.pulsarmodels[0].get_phi(p0)) + 2 * a0.index(True)
msg = "Wrong Phi cross terms"
assert Phi[0, i] != 0 and Phi[0, i] == Phi[i, 0], msg
assert Phi[1, i + 1] != 0 and Phi[1, i + 1] == Phi[i + 1, 1], msg

msg = "Discrepant Phi inverse"
assert np.allclose(
pta.get_phiinv(params=p0, method="sparse").toarray(), np.linalg.inv(pta.get_phi(params=p0))
), msg


class TestGPSignalsPint(TestGPSignals):
@classmethod
Expand All @@ -723,3 +787,7 @@ def setUpClass(cls):
ephem="DE430",
timing_package="pint",
)

# won't work because one PSR will have telescope == 'arecibo', the other 'ao'
def test_gp_common_selection(self):
pass