Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom type definitions #1756

Merged
merged 3 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/pint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import numpy as np
import pkg_resources
from astropy.units import si
from pathlib import Path

from typing import Union, IO

from pint import logging
from pint.extern._version import get_versions
Expand Down Expand Up @@ -115,3 +118,11 @@
def print_info():
"""Print the OS version, Python version, PINT version, versions of the dependencies etc."""
print(info_string(detailed=True))


# custom types
# Something that is a Quantity or can behave like one (with units assumed)
quantity_like = Union[float, np.ndarray, u.Quantity]
# Something that is a Time or can behave like one
time_like = Union[float, np.ndarray, u.Quantity, time.Time]
file_like = Union[str, Path, IO]
60 changes: 20 additions & 40 deletions src/pint/models/astrometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from loguru import logger as log

from erfa import ErfaWarning, pmsafe
from pint import ls
from pint import ls, time_like, quantity_like
from pint.models.parameter import (
AngleParameter,
MJDParameter,
Expand Down Expand Up @@ -59,9 +59,7 @@ def __init__(self):
self.delay_funcs_component += [self.solar_system_geometric_delay]
self.register_deriv_funcs(self.d_delay_astrometry_d_PX, "PX")

def ssb_to_psb_xyz_ICRS(
self, epoch: Union[float, u.Quantity, Time] = None
) -> u.Quantity:
def ssb_to_psb_xyz_ICRS(self, epoch: Optional[time_like] = None) -> u.Quantity:
"""Returns unit vector(s) from SSB to pulsar system barycenter under ICRS.

If epochs (MJD) are given, proper motion is included in the calculation.
Expand All @@ -83,7 +81,7 @@ def ssb_to_psb_xyz_ICRS(
return self.coords_as_ICRS(epoch=epoch).cartesian.xyz.transpose()

def ssb_to_psb_xyz_ECL(
self, epoch: Union[float, u.Quantity, Time] = None, ecl: str = None
self, epoch: Optional[time_like] = None, ecl: str = None
) -> u.Quantity:
"""Returns unit vector(s) from SSB to pulsar system barycenter under Ecliptic coordinates.

Expand Down Expand Up @@ -243,7 +241,7 @@ def d_delay_astrometry_d_POSEPOCH(self, toas, param="", acc_delay=None):
"""Calculate the derivative wrt POSEPOCH"""
pass

def change_posepoch(self, new_epoch):
def change_posepoch(self, new_epoch: time_like):
"""Change POSEPOCH to a new value and update the position accordingly.

Parameters
Expand Down Expand Up @@ -350,9 +348,7 @@ def barycentric_radio_freq(self, toas: pint.toa.TOAs) -> u.Quantity:
v_dot_L_array = np.sum(tbl["ssb_obs_vel"] * L_hat, axis=1)
return tbl["freq"] * (1.0 - v_dot_L_array / const.c)

def get_psr_coords(
self, epoch: Union[float, u.Quantity, Time] = None
) -> coords.SkyCoord:
def get_psr_coords(self, epoch: Optional[time_like] = None) -> coords.SkyCoord:
"""Returns pulsar sky coordinates as an astropy ICRS object instance.

Parameters
Expand Down Expand Up @@ -392,9 +388,7 @@ def get_psr_coords(

return position_then

def coords_as_ICRS(
self, epoch: Union[float, u.Quantity, Time] = None
) -> coords.SkyCoord:
def coords_as_ICRS(self, epoch: Optional[time_like] = None) -> coords.SkyCoord:
"""Return the pulsar's ICRS coordinates as an astropy coordinate object.

Parameters
Expand All @@ -409,7 +403,7 @@ def coords_as_ICRS(
return self.get_psr_coords(epoch)

def coords_as_ECL(
self, epoch: Union[float, u.Quantity, Time] = None, ecl: str = None
self, epoch: Optional[time_like] = None, ecl: str = None
) -> coords.SkyCoord:
"""Return the pulsar's ecliptic coordinates as an astropy coordinate object.

Expand All @@ -434,9 +428,7 @@ def coords_as_ECL(
pos_icrs = self.get_psr_coords(epoch=epoch)
return pos_icrs.transform_to(PulsarEcliptic(ecl=ecl))

def coords_as_GAL(
self, epoch: Union[float, u.Quantity, Time] = None
) -> coords.SkyCoord:
def coords_as_GAL(self, epoch: Optional[time_like] = None) -> coords.SkyCoord:
"""Return the pulsar's galactic coordinates as an astropy coordinate object.

Parameters
Expand All @@ -459,9 +451,7 @@ def get_params_as_ICRS(self) -> dict:
"PMDEC": self.PMDEC.quantity,
}

def ssb_to_psb_xyz_ICRS(
self, epoch: Union[float, u.Quantity, Time] = None
) -> u.Quantity:
def ssb_to_psb_xyz_ICRS(self, epoch: Optional[time_like] = None) -> u.Quantity:
"""Returns unit vector(s) from SSB to pulsar system barycenter under ICRS.

If epochs (MJD) are given, proper motion is included in the calculation.
Expand Down Expand Up @@ -621,7 +611,7 @@ def d_delay_astrometry_d_PMDEC(
# We want to return sec / (mas / yr)
return dd_dpmdec.decompose(u.si.bases) / (u.mas / u.year)

def change_posepoch(self, new_epoch: Union[float, u.Quantity, Time]):
def change_posepoch(self, new_epoch: time_like):
"""Change POSEPOCH to a new value and update the position accordingly.

Parameters
Expand All @@ -643,9 +633,7 @@ def change_posepoch(self, new_epoch: Union[float, u.Quantity, Time]):
self.DECJ.value = new_coords.dec
self.POSEPOCH.value = new_epoch

def as_ICRS(
self, epoch: Union[float, u.Quantity, Time] = None
) -> "AstrometryEquatorial":
def as_ICRS(self, epoch: Optional[time_like] = None) -> "AstrometryEquatorial":
"""Return pint.models.astrometry.Astrometry object in ICRS frame.

Parameters
Expand All @@ -664,7 +652,7 @@ def as_ICRS(
return m

def as_ECL(
self, epoch: Union[float, u.Quantity, Time] = None, ecl: str = "IERS2010"
self, epoch: Optional[time_like] = None, ecl: str = "IERS2010"
) -> "AstrometryEcliptic":
"""Return pint.models.astrometry.Astrometry object in PulsarEcliptic frame.

Expand Down Expand Up @@ -842,9 +830,7 @@ def barycentric_radio_freq(self, toas: pint.toa.TOAs) -> u.Quantity:
v_dot_L_array = np.sum(tbl["ssb_obs_vel_ecl"] * L_hat, axis=1)
return tbl["freq"] * (1.0 - v_dot_L_array / const.c)

def get_psr_coords(
self, epoch: Union[float, u.Quantity, Time] = None
) -> coords.SkyCoord:
def get_psr_coords(self, epoch: Optional[time_like] = None) -> coords.SkyCoord:
"""Returns pulsar sky coordinates as an astropy ecliptic coordinate instance.

Parameters
Expand Down Expand Up @@ -888,9 +874,7 @@ def get_psr_coords(
position_then = position_now.apply_space_motion(new_obstime=newepoch)
return remove_dummy_distance(position_then)

def coords_as_ICRS(
self, epoch: Union[float, u.Quantity, Time] = None
) -> coords.SkyCoord:
def coords_as_ICRS(self, epoch: Optional[time_like] = None) -> coords.SkyCoord:
"""Return the pulsar's ICRS coordinates as an astropy coordinate object.

Parameters
Expand All @@ -905,9 +889,7 @@ def coords_as_ICRS(
pos_ecl = self.get_psr_coords(epoch=epoch)
return pos_ecl.transform_to(coords.ICRS)

def coords_as_GAL(
self, epoch: Union[float, u.Quantity, Time] = None
) -> coords.SkyCoord:
def coords_as_GAL(self, epoch: Optional[time_like] = None) -> coords.SkyCoord:
"""Return the pulsar's galactic coordinates as an astropy coordinate object.

Parameters
Expand All @@ -923,7 +905,7 @@ def coords_as_GAL(
return pos_ecl.transform_to(coords.Galactic)

def coords_as_ECL(
self, epoch: Union[float, u.Quantity, Time] = None, ecl: str = None
self, epoch: Optional[time_like] = None, ecl: str = None
) -> coords.SkyCoord:
"""Return the pulsar's ecliptic coordinates as an astropy coordinate object.

Expand All @@ -947,7 +929,7 @@ def coords_as_ECL(
return pos_ecl

def ssb_to_psb_xyz_ECL(
self, epoch: Union[float, u.Quantity, Time] = None, ecl: str = None
self, epoch: Optional[time_like] = None, ecl: str = None
) -> u.Quantity:
"""Returns unit vector(s) from SSB to pulsar system barycenter under ECL.

Expand Down Expand Up @@ -1177,7 +1159,7 @@ def print_par(self, format: str = "pint") -> str:
result += getattr(self, p).as_parfile_line(format=format)
return result

def change_posepoch(self, new_epoch: Union[float, u.Quantity, Time]):
def change_posepoch(self, new_epoch: time_like):
"""Change POSEPOCH to a new value and update the position accordingly.

Parameters
Expand All @@ -1198,7 +1180,7 @@ def change_posepoch(self, new_epoch: Union[float, u.Quantity, Time]):
self.POSEPOCH.value = new_epoch

def as_ECL(
self, epoch: Union[float, u.Quantity, Time] = None, ecl: str = "IERS2010"
self, epoch: Optional[time_like] = None, ecl: str = "IERS2010"
) -> "AstrometryEcliptic":
"""Return pint.models.astrometry.Astrometry object in PulsarEcliptic frame.

Expand Down Expand Up @@ -1290,9 +1272,7 @@ def as_ECL(

return m_ecl

def as_ICRS(
self, epoch: Union[float, u.Quantity, Time] = None
) -> "AstrometryEquatorial":
def as_ICRS(self, epoch: Optional[time_like] = None) -> "AstrometryEquatorial":
"""Return pint.models.astrometry.Astrometry object in ICRS frame.

Parameters
Expand Down
15 changes: 8 additions & 7 deletions src/pint/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from loguru import logger as log
from astropy import time

from pint import time_like, quantity_like, file_like
import pint.residuals
import pint.toa
import pint.fitter
Expand Down Expand Up @@ -221,8 +222,8 @@ def update_fake_dms(


def make_fake_toas_uniform(
startMJD: Union[float, u.Quantity, time.Time],
endMJD: Union[float, u.Quantity, time.Time],
startMJD: time_like,
endMJD: time_like,
ntoas: int,
model: pint.models.timing_model.TimingModel,
fuzz: u.Quantity = 0,
Expand Down Expand Up @@ -363,7 +364,7 @@ def make_fake_toas_uniform(


def make_fake_toas_fromMJDs(
MJDs: Union[u.Quantity, time.Time, np.ndarray],
MJDs: time_like,
model: pint.models.timing_model.TimingModel,
freq: u.Quantity = 1400 * u.MHz,
obs: str = "GBT",
Expand Down Expand Up @@ -498,7 +499,7 @@ def make_fake_toas_fromMJDs(


def make_fake_toas_fromtim(
timfile: Union[str, List[str], pathlib.Path],
timfile: Union[file_like, List[str]],
model: pint.models.timing_model.TimingModel,
add_noise: bool = False,
add_correlated_noise: bool = False,
Expand Down Expand Up @@ -690,12 +691,12 @@ def calculate_random_models(


def _get_freqs_and_times(
start: Union[float, u.Quantity, time.Time],
end: Union[float, u.Quantity, time.Time],
start: time_like,
end: time_like,
ntoas: int,
freqs: u.Quantity,
multi_freqs_in_epoch: bool = True,
) -> Tuple[Union[float, u.Quantity, time.Time], np.ndarray]:
) -> Tuple[time_like, np.ndarray]:
freqs = np.atleast_1d(freqs)
assert (
len(freqs.shape) == 1 and len(freqs) <= ntoas
Expand Down
Loading