From dc15c079e87f32fde1013d5e450d421aa2f420c0 Mon Sep 17 00:00:00 2001 From: David Kaplan Date: Fri, 3 May 2024 10:52:30 -0500 Subject: [PATCH] trying some custom type definitions --- src/pint/__init__.py | 11 +++++++ src/pint/models/astrometry.py | 60 ++++++++++++----------------------- src/pint/simulation.py | 15 +++++---- 3 files changed, 39 insertions(+), 47 deletions(-) diff --git a/src/pint/__init__.py b/src/pint/__init__.py index f2bb62790..c88451186 100644 --- a/src/pint/__init__.py +++ b/src/pint/__init__.py @@ -18,6 +18,9 @@ import numpy as np import pkg_resources from astropy.units import si +from pathlib import Path + +from typing import Union, IO from pint import logging from pint.extern._version import get_versions @@ -115,3 +118,11 @@ def print_info(): """Print the OS version, Python version, PINT version, versions of the dependencies etc.""" print(info_string(detailed=True)) + + +# custom types +# Something that is a Quantity or can behave like one (with units assumed) +quantity_like = Union[float, np.ndarray, u.Quantity] +# Something that is a Time or can behave like one +time_like = Union[float, np.ndarray, u.Quantity, time.Time] +file_like = Union[str, Path, IO] diff --git a/src/pint/models/astrometry.py b/src/pint/models/astrometry.py index a23d616cf..8f5d4ea30 100644 --- a/src/pint/models/astrometry.py +++ b/src/pint/models/astrometry.py @@ -14,7 +14,7 @@ from loguru import logger as log from erfa import ErfaWarning, pmsafe -from pint import ls +from pint import ls, time_like, quantity_like from pint.models.parameter import ( AngleParameter, MJDParameter, @@ -59,9 +59,7 @@ def __init__(self): self.delay_funcs_component += [self.solar_system_geometric_delay] self.register_deriv_funcs(self.d_delay_astrometry_d_PX, "PX") - def ssb_to_psb_xyz_ICRS( - self, epoch: Union[float, u.Quantity, Time] = None - ) -> u.Quantity: + def ssb_to_psb_xyz_ICRS(self, epoch: Optional[time_like] = None) -> u.Quantity: """Returns unit vector(s) from SSB to pulsar system barycenter under ICRS. If epochs (MJD) are given, proper motion is included in the calculation. @@ -83,7 +81,7 @@ def ssb_to_psb_xyz_ICRS( return self.coords_as_ICRS(epoch=epoch).cartesian.xyz.transpose() def ssb_to_psb_xyz_ECL( - self, epoch: Union[float, u.Quantity, Time] = None, ecl: str = None + self, epoch: Optional[time_like] = None, ecl: str = None ) -> u.Quantity: """Returns unit vector(s) from SSB to pulsar system barycenter under Ecliptic coordinates. @@ -243,7 +241,7 @@ def d_delay_astrometry_d_POSEPOCH(self, toas, param="", acc_delay=None): """Calculate the derivative wrt POSEPOCH""" pass - def change_posepoch(self, new_epoch): + def change_posepoch(self, new_epoch: time_like): """Change POSEPOCH to a new value and update the position accordingly. Parameters @@ -350,9 +348,7 @@ def barycentric_radio_freq(self, toas: pint.toa.TOAs) -> u.Quantity: v_dot_L_array = np.sum(tbl["ssb_obs_vel"] * L_hat, axis=1) return tbl["freq"] * (1.0 - v_dot_L_array / const.c) - def get_psr_coords( - self, epoch: Union[float, u.Quantity, Time] = None - ) -> coords.SkyCoord: + def get_psr_coords(self, epoch: Optional[time_like] = None) -> coords.SkyCoord: """Returns pulsar sky coordinates as an astropy ICRS object instance. Parameters @@ -392,9 +388,7 @@ def get_psr_coords( return position_then - def coords_as_ICRS( - self, epoch: Union[float, u.Quantity, Time] = None - ) -> coords.SkyCoord: + def coords_as_ICRS(self, epoch: Optional[time_like] = None) -> coords.SkyCoord: """Return the pulsar's ICRS coordinates as an astropy coordinate object. Parameters @@ -409,7 +403,7 @@ def coords_as_ICRS( return self.get_psr_coords(epoch) def coords_as_ECL( - self, epoch: Union[float, u.Quantity, Time] = None, ecl: str = None + self, epoch: Optional[time_like] = None, ecl: str = None ) -> coords.SkyCoord: """Return the pulsar's ecliptic coordinates as an astropy coordinate object. @@ -434,9 +428,7 @@ def coords_as_ECL( pos_icrs = self.get_psr_coords(epoch=epoch) return pos_icrs.transform_to(PulsarEcliptic(ecl=ecl)) - def coords_as_GAL( - self, epoch: Union[float, u.Quantity, Time] = None - ) -> coords.SkyCoord: + def coords_as_GAL(self, epoch: Optional[time_like] = None) -> coords.SkyCoord: """Return the pulsar's galactic coordinates as an astropy coordinate object. Parameters @@ -459,9 +451,7 @@ def get_params_as_ICRS(self) -> dict: "PMDEC": self.PMDEC.quantity, } - def ssb_to_psb_xyz_ICRS( - self, epoch: Union[float, u.Quantity, Time] = None - ) -> u.Quantity: + def ssb_to_psb_xyz_ICRS(self, epoch: Optional[time_like] = None) -> u.Quantity: """Returns unit vector(s) from SSB to pulsar system barycenter under ICRS. If epochs (MJD) are given, proper motion is included in the calculation. @@ -621,7 +611,7 @@ def d_delay_astrometry_d_PMDEC( # We want to return sec / (mas / yr) return dd_dpmdec.decompose(u.si.bases) / (u.mas / u.year) - def change_posepoch(self, new_epoch: Union[float, u.Quantity, Time]): + def change_posepoch(self, new_epoch: time_like): """Change POSEPOCH to a new value and update the position accordingly. Parameters @@ -643,9 +633,7 @@ def change_posepoch(self, new_epoch: Union[float, u.Quantity, Time]): self.DECJ.value = new_coords.dec self.POSEPOCH.value = new_epoch - def as_ICRS( - self, epoch: Union[float, u.Quantity, Time] = None - ) -> "AstrometryEquatorial": + def as_ICRS(self, epoch: Optional[time_like] = None) -> "AstrometryEquatorial": """Return pint.models.astrometry.Astrometry object in ICRS frame. Parameters @@ -664,7 +652,7 @@ def as_ICRS( return m def as_ECL( - self, epoch: Union[float, u.Quantity, Time] = None, ecl: str = "IERS2010" + self, epoch: Optional[time_like] = None, ecl: str = "IERS2010" ) -> "AstrometryEcliptic": """Return pint.models.astrometry.Astrometry object in PulsarEcliptic frame. @@ -842,9 +830,7 @@ def barycentric_radio_freq(self, toas: pint.toa.TOAs) -> u.Quantity: v_dot_L_array = np.sum(tbl["ssb_obs_vel_ecl"] * L_hat, axis=1) return tbl["freq"] * (1.0 - v_dot_L_array / const.c) - def get_psr_coords( - self, epoch: Union[float, u.Quantity, Time] = None - ) -> coords.SkyCoord: + def get_psr_coords(self, epoch: Optional[time_like] = None) -> coords.SkyCoord: """Returns pulsar sky coordinates as an astropy ecliptic coordinate instance. Parameters @@ -888,9 +874,7 @@ def get_psr_coords( position_then = position_now.apply_space_motion(new_obstime=newepoch) return remove_dummy_distance(position_then) - def coords_as_ICRS( - self, epoch: Union[float, u.Quantity, Time] = None - ) -> coords.SkyCoord: + def coords_as_ICRS(self, epoch: Optional[time_like] = None) -> coords.SkyCoord: """Return the pulsar's ICRS coordinates as an astropy coordinate object. Parameters @@ -905,9 +889,7 @@ def coords_as_ICRS( pos_ecl = self.get_psr_coords(epoch=epoch) return pos_ecl.transform_to(coords.ICRS) - def coords_as_GAL( - self, epoch: Union[float, u.Quantity, Time] = None - ) -> coords.SkyCoord: + def coords_as_GAL(self, epoch: Optional[time_like] = None) -> coords.SkyCoord: """Return the pulsar's galactic coordinates as an astropy coordinate object. Parameters @@ -923,7 +905,7 @@ def coords_as_GAL( return pos_ecl.transform_to(coords.Galactic) def coords_as_ECL( - self, epoch: Union[float, u.Quantity, Time] = None, ecl: str = None + self, epoch: Optional[time_like] = None, ecl: str = None ) -> coords.SkyCoord: """Return the pulsar's ecliptic coordinates as an astropy coordinate object. @@ -947,7 +929,7 @@ def coords_as_ECL( return pos_ecl def ssb_to_psb_xyz_ECL( - self, epoch: Union[float, u.Quantity, Time] = None, ecl: str = None + self, epoch: Optional[time_like] = None, ecl: str = None ) -> u.Quantity: """Returns unit vector(s) from SSB to pulsar system barycenter under ECL. @@ -1177,7 +1159,7 @@ def print_par(self, format: str = "pint") -> str: result += getattr(self, p).as_parfile_line(format=format) return result - def change_posepoch(self, new_epoch: Union[float, u.Quantity, Time]): + def change_posepoch(self, new_epoch: time_like): """Change POSEPOCH to a new value and update the position accordingly. Parameters @@ -1198,7 +1180,7 @@ def change_posepoch(self, new_epoch: Union[float, u.Quantity, Time]): self.POSEPOCH.value = new_epoch def as_ECL( - self, epoch: Union[float, u.Quantity, Time] = None, ecl: str = "IERS2010" + self, epoch: Optional[time_like] = None, ecl: str = "IERS2010" ) -> "AstrometryEcliptic": """Return pint.models.astrometry.Astrometry object in PulsarEcliptic frame. @@ -1290,9 +1272,7 @@ def as_ECL( return m_ecl - def as_ICRS( - self, epoch: Union[float, u.Quantity, Time] = None - ) -> "AstrometryEquatorial": + def as_ICRS(self, epoch: Optional[time_like] = None) -> "AstrometryEquatorial": """Return pint.models.astrometry.Astrometry object in ICRS frame. Parameters diff --git a/src/pint/simulation.py b/src/pint/simulation.py index 59eabde2f..03d904b2b 100644 --- a/src/pint/simulation.py +++ b/src/pint/simulation.py @@ -11,6 +11,7 @@ from loguru import logger as log from astropy import time +from pint import time_like, quantity_like, file_like import pint.residuals import pint.toa import pint.fitter @@ -221,8 +222,8 @@ def update_fake_dms( def make_fake_toas_uniform( - startMJD: Union[float, u.Quantity, time.Time], - endMJD: Union[float, u.Quantity, time.Time], + startMJD: time_like, + endMJD: time_like, ntoas: int, model: pint.models.timing_model.TimingModel, fuzz: u.Quantity = 0, @@ -363,7 +364,7 @@ def make_fake_toas_uniform( def make_fake_toas_fromMJDs( - MJDs: Union[u.Quantity, time.Time, np.ndarray], + MJDs: time_like, model: pint.models.timing_model.TimingModel, freq: u.Quantity = 1400 * u.MHz, obs: str = "GBT", @@ -498,7 +499,7 @@ def make_fake_toas_fromMJDs( def make_fake_toas_fromtim( - timfile: Union[str, List[str], pathlib.Path], + timfile: Union[file_like, List[str]], model: pint.models.timing_model.TimingModel, add_noise: bool = False, add_correlated_noise: bool = False, @@ -690,12 +691,12 @@ def calculate_random_models( def _get_freqs_and_times( - start: Union[float, u.Quantity, time.Time], - end: Union[float, u.Quantity, time.Time], + start: time_like, + end: time_like, ntoas: int, freqs: u.Quantity, multi_freqs_in_epoch: bool = True, -) -> Tuple[Union[float, u.Quantity, time.Time], np.ndarray]: +) -> Tuple[time_like, np.ndarray]: freqs = np.atleast_1d(freqs) assert ( len(freqs.shape) == 1 and len(freqs) <= ntoas