Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Astrometry faster & avoid long double operations where possible #1743

Closed
wants to merge 26 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG-unreleased.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ the released changes.

## Unreleased
### Changed
- Avoided unnecessary creation of `SkyCoord` objects in `AstrometryEquatorial`.
- Avoided unnecessary array slices in `SolarSystemShapiro`
### Added
- `tdbfloat` column in the TOAs table (avoid long double operations as much as possible)
### Fixed
### Removed
19 changes: 12 additions & 7 deletions src/pint/models/astrometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def sun_angle(
osv = tbl["obs_sun_pos"].quantity.copy()
else:
osv = -tbl["ssb_obs_pos"].quantity.copy()
psr_vec = self.ssb_to_psb_xyz_ICRS(epoch=tbl["tdbld"])
psr_vec = self.ssb_to_psb_xyz_ICRS(epoch=tbl["tdbfloat"])
r = (osv**2).sum(axis=1) ** 0.5
osv /= r[:, None]
cos = (osv * psr_vec).sum(axis=1)
Expand All @@ -159,7 +159,7 @@ def solar_system_geometric_delay(
# c selects the non-barycentric TOAs that need actual calculation
c = np.logical_and.reduce(tbl["ssb_obs_pos"] != 0, axis=1)
if np.any(c):
L_hat = self.ssb_to_psb_xyz_ICRS(epoch=tbl["tdbld"][c].astype(np.float64))
L_hat = self.ssb_to_psb_xyz_ICRS(epoch=tbl["tdbfloat"][c])
re_dot_L = np.sum(tbl["ssb_obs_pos"][c] * L_hat, axis=1)
delay[c] = -re_dot_L.to(ls).value
if self.PX.value != 0.0:
Expand All @@ -179,10 +179,10 @@ def get_d_delay_quantities(self, toas: pint.toa.TOAs) -> dict:
# TODO: Should delay not have units of u.second?
delay = self._parent.delay(toas)

# TODO: tbl['tdbld'].quantity should have units of u.day
# TODO: tbl['tdbfloat'].quantity should have units of u.day
# NOTE: Do we need to include the delay here?
tbl = toas.table
rd = {"epoch": tbl["tdbld"].quantity * u.day}
rd = {"epoch": tbl["tdbfloat"].quantity * u.day}
# Distance from SSB to observatory, and from SSB to psr
ssb_obs = tbl["ssb_obs_pos"].quantity
ssb_psr = self.ssb_to_psb_xyz_ICRS(epoch=np.array(rd["epoch"]))
Expand Down Expand Up @@ -346,7 +346,7 @@ def print_par(self, format: str = "pint") -> str:
def barycentric_radio_freq(self, toas: pint.toa.TOAs) -> u.Quantity:
"""Return radio frequencies (MHz) of the toas corrected for Earth motion"""
tbl = toas.table
L_hat = self.ssb_to_psb_xyz_ICRS(epoch=tbl["tdbld"].astype(np.float64))
L_hat = self.ssb_to_psb_xyz_ICRS(epoch=tbl["tdbfloat"])
v_dot_L_array = np.sum(tbl["ssb_obs_vel"] * L_hat, axis=1)
return tbl["freq"] * (1.0 - v_dot_L_array / const.c)

Expand Down Expand Up @@ -485,7 +485,9 @@ def ssb_to_psb_xyz_ICRS(
# does, which is to use https://github.com/liberfa/erfa/blob/master/src/starpm.c
# and then just use the relevant pieces of that
if epoch is None or (self.PMRA.quantity == 0 and self.PMDEC.quantity == 0):
return self.coords_as_ICRS(epoch=epoch).cartesian.xyz.transpose()
ra, dec = self.RAJ.quantity, self.DECJ.quantity
return self.xyz_from_radec(ra, dec)
# return self.coords_as_ICRS(epoch=epoch).cartesian.xyz.transpose()

if isinstance(epoch, Time):
jd1 = epoch.jd1
Expand Down Expand Up @@ -518,6 +520,9 @@ def ssb_to_psb_xyz_ICRS(
)
# ra,dec now in radians
ra, dec = starpmout[0], starpmout[1]
return self.xyz_from_radec(ra, dec)

def xyz_from_radec(self, ra, dec):
x = np.cos(ra) * np.cos(dec)
y = np.sin(ra) * np.cos(dec)
z = np.sin(dec)
Expand Down Expand Up @@ -833,7 +838,7 @@ def barycentric_radio_freq(self, toas: pint.toa.TOAs) -> u.Quantity:
obliquity = OBL[self.ECL.value]
toas.add_vel_ecl(obliquity)
tbl = toas.table
L_hat = self.ssb_to_psb_xyz_ECL(epoch=tbl["tdbld"].astype(np.float64))
L_hat = self.ssb_to_psb_xyz_ECL(epoch=tbl["tdbfloat"])
v_dot_L_array = np.sum(tbl["ssb_obs_vel_ecl"] * L_hat, axis=1)
return tbl["freq"] * (1.0 - v_dot_L_array / const.c)

Expand Down
4 changes: 2 additions & 2 deletions src/pint/models/dispersion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def base_dm(self, toas):
f"DMEPOCH not set but some derivatives are not zero: {dm_terms}"
)
else:
dt = (toas["tdbld"] - DMEPOCH) * u.day
dt = (toas["tdbfloat"] - DMEPOCH) * u.day
dt_value = dt.to_value(u.yr)
else:
dt_value = np.zeros(len(toas), dtype=np.longdouble)
Expand Down Expand Up @@ -266,7 +266,7 @@ def d_dm_d_DMs(
DMEPOCH = 0
else:
DMEPOCH = self.DMEPOCH.value
dt = (toas["tdbld"] - DMEPOCH) * u.day
dt = (toas["tdbfloat"] - DMEPOCH) * u.day
dt_value = (dt.to(u.yr)).value
return taylor_horner(dt_value, dm_terms) * (self.DM.units / par.units)

Expand Down
7 changes: 4 additions & 3 deletions src/pint/models/dmwavex.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""DM variations expressed as a sum of sinusoids."""

import astropy.units as u
import numpy as np
from loguru import logger as log
Expand Down Expand Up @@ -350,7 +351,7 @@ def dmwavex_dm(self, toas):
dmwave_sins = self.get_prefix_mapping_component("DMWXSIN_")
dmwave_cos = self.get_prefix_mapping_component("DMWXCOS_")

base_phase = toas.table["tdbld"].data * u.d - self.DMWXEPOCH.value * u.d
base_phase = toas.table["tdbfloat"].data * u.d - self.DMWXEPOCH.value * u.d
for idx, param in dmwave_freqs.items():
freq = getattr(self, param).quantity
dmwxsin = getattr(self, dmwave_sins[idx]).quantity
Expand All @@ -365,15 +366,15 @@ def dmwavex_delay(self, toas, acc_delay=None):
def d_dm_d_DMWXSIN(self, toas, param, acc_delay=None):
par = getattr(self, param)
freq = getattr(self, f"DMWXFREQ_{int(par.index):04d}").quantity
base_phase = toas.table["tdbld"].data * u.d - self.DMWXEPOCH.value * u.d
base_phase = toas.table["tdbfloat"].data * u.d - self.DMWXEPOCH.value * u.d
arg = 2.0 * np.pi * freq * base_phase
deriv = np.sin(arg.value)
return deriv * dmu / par.units

def d_dm_d_DMWXCOS(self, toas, param, acc_delay=None):
par = getattr(self, param)
freq = getattr(self, f"DMWXFREQ_{int(par.index):04d}").quantity
base_phase = toas.table["tdbld"].data * u.d - self.DMWXEPOCH.value * u.d
base_phase = toas.table["tdbfloat"].data * u.d - self.DMWXEPOCH.value * u.d
arg = 2.0 * np.pi * freq * base_phase
deriv = np.cos(arg.value)
return deriv * dmu / par.units
5 changes: 3 additions & 2 deletions src/pint/models/glitch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Pulsar timing glitches."""

import astropy.units as u
import numpy as np

Expand Down Expand Up @@ -199,7 +200,7 @@ def glitch_phase(self, toas, delay):
dF0 = getattr(self, "GLF0_%d" % idx).quantity
dF1 = getattr(self, "GLF1_%d" % idx).quantity
dF2 = getattr(self, "GLF2_%d" % idx).quantity
dt = (tbl["tdbld"] - eph) * u.day - delay
dt = (tbl["tdbfloat"] - eph) * u.day - delay
dt = dt.to(u.second)
affected = dt > 0.0 # TOAs affected by glitch
# decay term
Expand Down Expand Up @@ -228,7 +229,7 @@ def deriv_prep(self, toas, param, delay):
tbl = toas.table
p, ids, idv = split_prefixed_name(param)
eph = getattr(self, f"GLEP_{ids}").value
dt = (tbl["tdbld"] - eph) * u.day - delay
dt = (tbl["tdbfloat"] - eph) * u.day - delay
dt = dt.to(u.second)
affected = np.where(dt > 0.0)[0]
return tbl, p, ids, idv, dt, affected
Expand Down
3 changes: 2 additions & 1 deletion src/pint/models/ifunc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tabulated extra delays."""

import astropy.units as u
import numpy as np

Expand Down Expand Up @@ -110,7 +111,7 @@ def ifunc_phase(self, toas, delays):
# the MJDs(x) and offsets (y) of the interpolation points
x, y = np.asarray([t.quantity for t in terms]).T
# form barycentered times
ts = toas.table["tdbld"] - delays.to(u.day).value
ts = toas.table["tdbfloat"] - delays.to(u.day).value
times = np.zeros(len(ts))

# Determine what type of interpolation we are doing.
Expand Down
12 changes: 6 additions & 6 deletions src/pint/models/noise_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def get_noise_basis(self, toas):
A quantization matrix maps TOAs to observing epochs.
"""
tbl = toas.table
t = (tbl["tdbld"].quantity * u.day).to(u.s).value
t = (tbl["tdbfloat"].quantity * u.day).to(u.s).value
ecorrs = self.get_ecorrs()
umats = []
for ec in ecorrs:
Expand All @@ -408,7 +408,7 @@ def get_noise_weights(self, toas, nweights=None):
"""
ecorrs = self.get_ecorrs()
if nweights is None:
ts = (toas.table["tdbld"].quantity * u.day).to(u.s).value
ts = (toas.table["tdbfloat"].quantity * u.day).to(u.s).value
nweights = [
get_ecorr_nweights(ts[ec.select_toa_mask(toas)]) for ec in ecorrs
]
Expand Down Expand Up @@ -508,7 +508,7 @@ def get_noise_basis(self, toas):
See the documentation for pl_dm_basis_weight_pair function for details."""

tbl = toas.table
t = (tbl["tdbld"].quantity * u.day).to(u.s).value
t = (tbl["tdbfloat"].quantity * u.day).to(u.s).value
freqs = self._parent.barycentric_radio_freq(toas).to(u.MHz)
fref = 1400 * u.MHz
D = (fref.value / freqs.value) ** 2
Expand All @@ -522,7 +522,7 @@ def get_noise_weights(self, toas):
See the documentation for pl_dm_basis_weight_pair for details."""

tbl = toas.table
t = (tbl["tdbld"].quantity * u.day).to(u.s).value
t = (tbl["tdbfloat"].quantity * u.day).to(u.s).value
amp, gam, nf = self.get_pl_vals()
Ffreqs = get_rednoise_freqs(t, nf)
return powerlaw(Ffreqs, amp, gam) * Ffreqs[0]
Expand Down Expand Up @@ -639,7 +639,7 @@ def get_noise_basis(self, toas):
See the documentation for pl_rn_basis_weight_pair function for details."""

tbl = toas.table
t = (tbl["tdbld"].quantity * u.day).to(u.s).value
t = (tbl["tdbfloat"].quantity * u.day).to(u.s).value
nf = self.get_pl_vals()[2]
return create_fourier_design_matrix(t, nf)

Expand All @@ -649,7 +649,7 @@ def get_noise_weights(self, toas):
See the documentation for pl_rn_basis_weight_pair for details."""

tbl = toas.table
t = (tbl["tdbld"].quantity * u.day).to(u.s).value
t = (tbl["tdbfloat"].quantity * u.day).to(u.s).value
amp, gam, nf = self.get_pl_vals()
Ffreqs = get_rednoise_freqs(t, nf)
return powerlaw(Ffreqs, amp, gam) * Ffreqs[0]
Expand Down
3 changes: 2 additions & 1 deletion src/pint/models/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
parameters PINT understands.

"""

import numbers
from warnings import warn

Expand Down Expand Up @@ -577,7 +578,7 @@ def from_parfile_line(self, line):
return True

def value_as_latex(self):
return f"${self.as_ufloat():.1uSL}$" if not self.frozen else f"{self.value:f}"
return f"{self.value:f}" if self.frozen else f"${self.as_ufloat():.1uSL}$"

def as_latex(self):
try:
Expand Down
5 changes: 3 additions & 2 deletions src/pint/models/piecewise.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Pulsar timing piecewise spin-down solution."""

import astropy.units as u
import numpy as np

Expand Down Expand Up @@ -176,9 +177,9 @@ def get_dt_and_affected(self, toas, delay, glepnm):
idx = glep.index
start = getattr(self, "PWSTART_%d" % idx).value
stop = getattr(self, "PWSTOP_%d" % idx).value
affected = (tbl["tdbld"] >= start) & (tbl["tdbld"] < stop)
affected = (tbl["tdbfloat"] >= start) & (tbl["tdbfloat"] < stop)
phsepoch_ld = glep.quantity.tdb.mjd_long
dt = (tbl["tdbld"][affected] - phsepoch_ld) * u.day - delay[affected]
dt = (tbl["tdbfloat"][affected] - phsepoch_ld) * u.day - delay[affected]
return dt, affected

def piecewise_phase(self, toas, delay):
Expand Down
13 changes: 6 additions & 7 deletions src/pint/models/pulsar_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
as PINT timing models.
"""


import astropy.units as u
import contextlib
import numpy as np
Expand Down Expand Up @@ -310,11 +309,11 @@ def check_required_params(self, required_params):
method_name = f"{p.lower()}_func"
try:
par_method = getattr(self.binary_instance, method_name)
except AttributeError:
except AttributeError as e:
raise MissingParameter(
self.binary_model_name,
f"{p} is required for '{self.binary_model_name}'.",
)
) from e
par_method()

# With new parameter class set up, do we need this?
Expand Down Expand Up @@ -377,7 +376,7 @@ def update_binary_object(self, toas, acc_delay=None):
# it's already in ICRS
updates["obs_pos"] = tbl["ssb_obs_pos"].quantity
updates["psr_pos"] = self._parent.ssb_to_psb_xyz_ICRS(
epoch=tbl["tdbld"].astype(np.float64)
epoch=tbl["tdbfloat"]
)
elif "AstrometryEcliptic" in self._parent.components:
# convert from ICRS to ECL
Expand All @@ -390,7 +389,7 @@ def update_binary_object(self, toas, acc_delay=None):
PulsarEcliptic(ecl=self._parent.ECL.value)
).cartesian.xyz.transpose()
updates["psr_pos"] = self._parent.ssb_to_psb_xyz_ECL(
epoch=tbl["tdbld"].astype(np.float64), ecl=self._parent.ECL.value
epoch=tbl["tdbfloat"], ecl=self._parent.ECL.value
)
for par in self.binary_instance.binary_params:
if par in self.binary_instance.param_aliases.keys():
Expand All @@ -402,13 +401,13 @@ def update_binary_object(self, toas, acc_delay=None):
if hasattr(self._parent, par) or set(alias).intersection(self.params):
try:
pint_bin_name = self._parent.match_param_aliases(par)
except UnknownParameter:
except UnknownParameter as e:
if par in self.internal_params:
pint_bin_name = par
else:
raise UnknownParameter(
f"Unable to find {par} in the parent model"
)
) from e
binObjpar = getattr(self._parent, pint_bin_name)

# make sure we aren't passing along derived parameters to the binary instance
Expand Down
9 changes: 4 additions & 5 deletions src/pint/models/solar_system_shapiro.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,20 +98,19 @@ def solar_system_shapiro_delay(self, toas, acc_delay=None):
tbl = toas.table
delay = numpy.zeros(len(tbl))
for key, grp in toas.get_obs_groups():
tbl_grp = tbl[grp]
if key.lower() == "barycenter":
log.debug("Skipping Shapiro delay for Barycentric TOAs")
continue
psr_dir = self._parent.ssb_to_psb_xyz_ICRS(
epoch=tbl[grp]["tdbld"].astype(numpy.float64)
)
psr_dir = self._parent.ssb_to_psb_xyz_ICRS(epoch=tbl_grp["tdbfloat"])
delay[grp] += self.ss_obj_shapiro_delay(
tbl[grp]["obs_sun_pos"], psr_dir, self._ss_mass_sec["sun"]
tbl_grp["obs_sun_pos"], psr_dir, self._ss_mass_sec["sun"]
)
try:
if self.PLANET_SHAPIRO.value:
for pl in ("jupiter", "saturn", "venus", "uranus", "neptune"):
delay[grp] += self.ss_obj_shapiro_delay(
tbl[grp][f"obs_{pl}_pos"],
tbl_grp[f"obs_{pl}_pos"],
psr_dir,
self._ss_mass_sec[pl],
)
Expand Down
Loading
Loading