From da07be47dc34a0b9ba613006f0624c99b8b5f3ff Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Tue, 5 Mar 2024 20:31:59 +0000 Subject: [PATCH] Annotating pint.utils --- src/pint/observatory/__init__.py | 4 +- src/pint/utils.py | 208 +++++++++++++++++++------------ 2 files changed, 128 insertions(+), 84 deletions(-) diff --git a/src/pint/observatory/__init__.py b/src/pint/observatory/__init__.py index 8564e2a38..caf3af767 100644 --- a/src/pint/observatory/__init__.py +++ b/src/pint/observatory/__init__.py @@ -577,10 +577,10 @@ def compare_t2_observatories_dat(t2dir: Optional[str] = None) -> Dict[str, List[ with open(filename) as f: for line in interesting_lines(f, comments="#"): try: - x, y, z, full_name, short_name = line.split() + x_str, y_str, z_str, full_name, short_name = line.split() except ValueError as e: raise ValueError(f"unrecognized line '{line}'") from e - x, y, z = float(x), float(y), float(z) + x, y, z = float(x_str), float(y_str), float(z_str) full_name, short_name = full_name.lower(), short_name.lower() topo_obs_entry = textwrap.dedent( f""" diff --git a/src/pint/utils.py b/src/pint/utils.py index 1331ef53b..3b72c175b 100644 --- a/src/pint/utils.py +++ b/src/pint/utils.py @@ -28,6 +28,7 @@ has moved to :mod:`pint.simulation`. """ + import configparser import datetime import getpass @@ -37,11 +38,13 @@ import re import sys import textwrap +import warnings +from collections.abc import Generator, Iterable from contextlib import contextmanager +from copy import deepcopy from pathlib import Path +from typing import IO, Any, Optional, Tuple, Union, List, Dict, Type, Mapping, cast from warnings import warn -from scipy.optimize import minimize -from numdifftools import Hessian import astropy.constants as const import astropy.coordinates as coords @@ -50,16 +53,15 @@ from astropy import constants from astropy.time import Time from loguru import logger as log -from scipy.special import fdtrc +from numdifftools import Hessian from scipy.linalg import cho_factor, cho_solve -from copy import deepcopy -import warnings +from scipy.optimize import minimize +from scipy.special import fdtrc import pint import pint.pulsar_ecliptic from pint.toa_select import TOASelect - __all__ = [ "PINTPrecisionError", "check_longdouble_precision", @@ -114,8 +116,17 @@ "get_unit", ] -COLOR_NAMES = ["black", "red", "green", "yellow", "blue", "magenta", "cyan", "white"] -TEXT_ATTRIBUTES = [ +COLOR_NAMES: list[str] = [ + "black", + "red", + "green", + "yellow", + "blue", + "magenta", + "cyan", + "white", +] +TEXT_ATTRIBUTES: list[str] = [ "normal", "bold", "subdued", @@ -145,7 +156,7 @@ def check_longdouble_precision(): return np.finfo(np.longdouble).eps < 2e-19 -def require_longdouble_precision(): +def require_longdouble_precision() -> None: """Raise an exception if long doubles do not have enough precision. Raises RuntimeError if PINT cannot be run with high precision on this @@ -181,7 +192,13 @@ class PosVel: """ - def __init__(self, pos, vel, obj=None, origin=None): + def __init__( + self, + pos: Union[u.Quantity, np.ndarray], + vel: Union[u.Quantity, np.ndarray], + obj=None, + origin=None, + ): if len(pos) != 3: raise ValueError(f"Position vector has length {len(pos)} instead of 3") self.pos = pos if isinstance(pos, u.Quantity) else np.asarray(pos) @@ -207,13 +224,13 @@ def __init__(self, pos, vel, obj=None, origin=None): self.origin = origin # FIXME: what about dtype compatibility? - def _has_labels(self): + def _has_labels(self) -> bool: return (self.obj is not None) and (self.origin is not None) - def __neg__(self): + def __neg__(self) -> "PosVel": return PosVel(-self.pos, -self.vel, obj=self.origin, origin=self.obj) - def __add__(self, other): + def __add__(self, other: "PosVel") -> "PosVel": obj = None origin = None if self._has_labels() and other._has_labels(): @@ -234,17 +251,17 @@ def __add__(self, other): self.pos + other.pos, self.vel + other.vel, obj=obj, origin=origin ) - def __sub__(self, other): + def __sub__(self, other: "PosVel") -> "PosVel": return self.__add__(other.__neg__()) - def __str__(self): + def __str__(self) -> str: return ( f"PosVel({str(self.pos)}, {str(self.vel)} {self.origin}->{self.obj})" if self._has_labels() else f"PosVel({str(self.pos)}, {str(self.vel)})" ) - def __getitem__(self, k): + def __getitem__(self, k: Union[int, Tuple[int, ...]]) -> "PosVel": """Allow extraction of slices of the contained arrays""" colon = slice(None, None, None) ix = (colon,) + k if isinstance(k, tuple) else (colon, k) @@ -305,7 +322,7 @@ def check_all_partials(f, args, delta=1e-6, atol=1e-4, rtol=1e-4): raise -def has_astropy_unit(x): +def has_astropy_unit(x) -> bool: """Test whether x has a unit attribute containing an astropy unit. This is useful, because different data types can still have units @@ -328,7 +345,7 @@ class PrefixError(ValueError): pass -def split_prefixed_name(name): +def split_prefixed_name(name: str) -> Tuple[str, str, int]: """Split a prefixed name. Parameters @@ -365,17 +382,16 @@ def split_prefixed_name(name): """ for pt in prefix_pattern: - try: - prefix_part, index_part = pt.match(name).groups() + m = pt.match(name) + if m is not None: + prefix_part, index_part = m.groups() break - except AttributeError: - continue else: raise PrefixError(f"Unrecognized prefix name pattern '{name}'.") return prefix_part, index_part, int(index_part) -def taylor_horner(x, coeffs): +def taylor_horner(x: Union[float, np.ndarray, u.Quantity], coeffs): """Evaluate a Taylor series of coefficients at x via the Horner scheme. For example, if we want: 10 + 3*x/1! + 4*x^2/2! + 12*x^3/3! with @@ -444,7 +460,10 @@ def taylor_horner_deriv(x, coeffs, deriv_order=1): @contextmanager -def open_or_use(f, mode="r"): +def open_or_use( + f: Union[str, bytes, Path, IO[Any]], + mode: str = "r", +) -> Generator[IO[Any], None, None]: """Open a filename or use an open file. Specifically, if f is a string, try to use it as an argument to @@ -459,7 +478,7 @@ def open_or_use(f, mode="r"): yield f -def lines_of(f): +def lines_of(f: Union[str, bytes, Path, IO[str]]) -> Generator[str, None, None]: """Iterate over the lines of a file, an open file, or an iterator. If ``f`` is a string, try to open a file of that name. Otherwise @@ -472,7 +491,10 @@ def lines_of(f): yield from fo -def interesting_lines(lines, comments=None): +def interesting_lines( + lines: Iterable[str], + comments: Union[None, str, Iterable[Union[str]]] = None, +) -> Generator[str, None, None]: """Iterate over lines skipping whitespace and comments. Each line has its whitespace stripped and then it is checked whether @@ -480,6 +502,7 @@ def interesting_lines(lines, comments=None): a list of strings. """ + cc: Tuple[str, ...] if comments is None: cc = () elif isinstance(comments, (str, bytes)): @@ -490,8 +513,8 @@ def interesting_lines(lines, comments=None): cs = c.strip() if not cs or not c.startswith(cs): raise ValueError( - "Unable to deal with comments that start with whitespace, " - "but comment string {!r} was requested.".format(c) + f"Unable to deal with comments that start with whitespace, " + f"but comment string {c:!r} was requested." ) for ln in lines: ln = ln.strip() @@ -1077,7 +1100,7 @@ def dmxparse(fitter, save=False): } -def get_prefix_timerange(model, prefixname): +def get_prefix_timerange(model, prefixname: str) -> Tuple[Time, Time]: """Get time range for a prefix quantity like DMX or SWX Parameters @@ -1105,7 +1128,7 @@ def get_prefix_timerange(model, prefixname): return getattr(model, r1).quantity, getattr(model, r2).quantity -def get_prefix_timeranges(model, prefixname): +def get_prefix_timeranges(model, prefixname: str) -> Tuple[np.ndarray, Time, Time]: """Get all time ranges and indices for a prefix quantity like DMX or SWX Parameters @@ -1142,7 +1165,9 @@ def get_prefix_timeranges(model, prefixname): ) -def find_prefix_bytime(model, prefixname, t): +def find_prefix_bytime( + model, prefixname: str, t: Union[Time, u.Quantity] +) -> Union[int, np.ndarray]: """Identify matching index(es) for a prefix parameter like DMX Parameters @@ -1163,11 +1188,14 @@ def find_prefix_bytime(model, prefixname, t): indices, r1, r2 = get_prefix_timeranges(model, prefixname) matches = np.where((t >= r1) & (t < r2))[0] if len(matches) == 1: - matches = int(matches) - return indices[matches] + return int(indices[int(matches)]) + else: + return indices[matches] -def merge_dmx(model, index1, index2, value="mean", frozen=True): +def merge_dmx( + model, index1: int, index2: int, value: str = "mean", frozen: bool = True +) -> int: """Merge two DMX bins Parameters @@ -1197,7 +1225,7 @@ def merge_dmx(model, index1, index2, value="mean", frozen=True): ) if value.lower() == "first": dmx = getattr(model, f"DMX_{index1:04d}").quantity - elif value.lower == "second": + elif value.lower() == "second": dmx = getattr(model, f"DMX_{index2:04d}").quantity elif value.lower() == "mean": dmx = ( @@ -1205,14 +1233,13 @@ def merge_dmx(model, index1, index2, value="mean", frozen=True): + getattr(model, f"DMX_{index2:04d}").quantity ) / 2 # add the new one before we delete previous ones to make sure we have >=1 present - newindex = model.add_DMX_range(tstart, tend, dmx=dmx, frozen=frozen) + newindex: int = model.add_DMX_range(tstart, tend, dmx=dmx, frozen=frozen) model.remove_DMX_range([index1, index2]) return newindex -def split_dmx(model, time): - """ - Split an existing DMX bin at the desired time +def split_dmx(model, time: Time) -> Tuple[int, int]: + """Split an existing DMX bin at the desired time. Parameters ---------- @@ -1234,10 +1261,10 @@ def split_dmx(model, time): dmx_epochs = [f"{x:04d}" for x in DMX_mapping.keys()] DMX_R1 = np.zeros(len(dmx_epochs)) DMX_R2 = np.zeros(len(dmx_epochs)) - for ii, epoch in enumerate(dmx_epochs): - DMX_R1[ii] = getattr(model, "DMXR1_{:}".format(epoch)).value - DMX_R2[ii] = getattr(model, "DMXR2_{:}".format(epoch)).value - ii = np.where((time.mjd > DMX_R1) & (time.mjd < DMX_R2))[0] + for iii, epoch in enumerate(dmx_epochs): + DMX_R1[iii] = getattr(model, "DMXR1_{:}".format(epoch)).value + DMX_R2[iii] = getattr(model, "DMXR2_{:}".format(epoch)).value + ii: np.ndarray = np.where((time.mjd > DMX_R1) & (time.mjd < DMX_R2))[0] if len(ii) == 0: raise ValueError(f"Time {time} not in any DMX bins") ii = ii[0] @@ -1255,9 +1282,8 @@ def split_dmx(model, time): return index, newindex -def split_swx(model, time): - """ - Split an existing SWX bin at the desired time +def split_swx(model, time: Time) -> Tuple[int, int]: + """Split an existing SWX bin at the desired time. Parameters ---------- @@ -1270,7 +1296,6 @@ def split_swx(model, time): Index of existing bin that was split newindex : int Index of new bin that was added - """ try: SWX_mapping = model.get_prefix_mapping("SWX_") @@ -1279,9 +1304,9 @@ def split_swx(model, time): swx_epochs = [f"{x:04d}" for x in SWX_mapping.keys()] SWX_R1 = np.zeros(len(swx_epochs)) SWX_R2 = np.zeros(len(swx_epochs)) - for ii, epoch in enumerate(swx_epochs): - SWX_R1[ii] = getattr(model, "SWXR1_{:}".format(epoch)).value - SWX_R2[ii] = getattr(model, "SWXR2_{:}".format(epoch)).value + for iii, epoch in enumerate(swx_epochs): + SWX_R1[iii] = getattr(model, "SWXR1_{:}".format(epoch)).value + SWX_R2[iii] = getattr(model, "SWXR2_{:}".format(epoch)).value ii = np.where((time.mjd > SWX_R1) & (time.mjd < SWX_R2))[0] if len(ii) == 0: raise ValueError(f"Time {time} not in any SWX bins") @@ -1301,7 +1326,8 @@ def split_swx(model, time): def wavex_setup(model, T_span, freqs=None, n_freqs=None, freeze_params=False): - """ + """Set up a WaveX model. + Set-up a WaveX model based on either an array of user-provided frequencies or the wave number frequency calculation. Sine and Cosine amplitudes are initially set to zero @@ -1725,7 +1751,13 @@ def translate_wavex_to_wave(model): return new_model -def weighted_mean(arrin, weights_in, inputmean=None, calcerr=False, sdev=False): +def weighted_mean( + arrin: np.ndarray, + weights_in: np.ndarray, + inputmean: Optional[float] = None, + calcerr: bool = False, + sdev: bool = False, +) -> Union[Tuple[float, float], Tuple[float, float, float]]: """Compute weighted mean of input values Calculate the weighted mean, error, and optionally standard deviation of @@ -1736,10 +1768,10 @@ def weighted_mean(arrin, weights_in, inputmean=None, calcerr=False, sdev=False): Parameters ---------- arrin : array - Array containing the numbers whose weighted mean is desired. + Array containing the numbers whose weighted mean is desired. weights: array - A set of weights for each element in array. For measurements with - uncertainties, these should be 1/sigma^2. + A set of weights for each element in array. For measurements with + uncertainties, these should be 1/sigma^2. inputmean: float, optional An input mean value, around which the mean is calculated. calcerr : bool, optional @@ -1753,8 +1785,8 @@ def weighted_mean(arrin, weights_in, inputmean=None, calcerr=False, sdev=False): Returns ------- wmean, werr: tuple - A tuple of the weighted mean and error. If sdev=True the - tuple will also contain sdev: wmean,werr,wsdev + A tuple of the weighted mean and error. If sdev=True the + tuple will also contain sdev: wmean,werr,wsdev Notes ----- @@ -1839,7 +1871,7 @@ def ELL1_check( return False -def FTest(chi2_1, dof_1, chi2_2, dof_2): +def FTest(chi2_1: float, dof_1: int, chi2_2: float, dof_2: int) -> float: """Run F-test. Compute an F-test to see if a model with extra parameters is @@ -1876,7 +1908,7 @@ def FTest(chi2_1, dof_1, chi2_2, dof_2): delta_dof = dof_1 - dof_2 new_redchi2 = chi2_2 / dof_2 F = float((delta_chi2 / delta_dof) / new_redchi2) # fdtr doesn't like float128 - return fdtrc(delta_dof, dof_2, F) + return float(fdtrc(delta_dof, dof_2, F)) elif dof_1 == dof_2: log.warning("Models have equal degrees of freedom, cannot perform F-test.") return np.nan @@ -1887,7 +1919,9 @@ def FTest(chi2_1, dof_1, chi2_2, dof_2): return 1.0 -def add_dummy_distance(c, distance=1 * u.kpc): +def add_dummy_distance( + c: coords.SkyCoord, distance: u.Quantity = 1 * u.kpc +) -> coords.SkyCoord: """Adds a dummy distance to a SkyCoord object for applying proper motion Parameters @@ -1959,7 +1993,7 @@ def add_dummy_distance(c, distance=1 * u.kpc): return c -def remove_dummy_distance(c): +def remove_dummy_distance(c: coords.SkyCoord) -> coords.SkyCoord: """Removes a dummy distance from a SkyCoord object after applying proper motion Parameters @@ -2024,7 +2058,9 @@ def remove_dummy_distance(c): return c -def info_string(prefix_string="# ", comment=None, detailed=False): +def info_string( + prefix_string: str = "# ", comment: Optional[str] = None, detailed: bool = False +) -> str: """Returns an informative string about the current state of PINT. Adds: @@ -2132,7 +2168,7 @@ def info_string(prefix_string="# ", comment=None, detailed=False): # user-level git config c = git.GitConfigParser() - username = c.get_value("user", option="name") + f" ({getpass.getuser()})" + username = str(c.get_value("user", option="name")) + f" ({getpass.getuser()})" except (configparser.NoOptionError, configparser.NoSectionError, ImportError): username = getpass.getuser() @@ -2146,13 +2182,14 @@ def info_string(prefix_string="# ", comment=None, detailed=False): } if detailed: - from numpy import __version__ as numpy_version - from scipy import __version__ as scipy_version from astropy import __version__ as astropy_version from erfa import __version__ as erfa_version from jplephem import __version__ as jpleph_version + from loguru import __version__ as loguru_version # type: ignore[attr-defined] from matplotlib import __version__ as matplotlib_version - from loguru import __version__ as loguru_version + from numpy import __version__ as numpy_version + from scipy import __version__ as scipy_version + from pint import __file__ as pint_file info_dict.update( @@ -2205,7 +2242,7 @@ def info_string(prefix_string="# ", comment=None, detailed=False): return s -def list_parameters(class_=None): +def list_parameters(class_=None) -> List[Dict]: """List parameters understood by PINT. Parameters @@ -2265,7 +2302,7 @@ def list_parameters(class_=None): results = {} ct = pint.models.timing_model.Component.component_types.copy() - ct["TimingModel"] = pint.models.timing_model.TimingModel + ct["TimingModel"] = pint.models.timing_model.TimingModel # type: ignore[assignment] for v in ct.values(): for d in list_parameters(v): n = d["name"] @@ -2284,7 +2321,12 @@ def list_parameters(class_=None): return sorted(results.values(), key=lambda d: d["name"]) -def colorize(text, fg_color=None, bg_color=None, attribute=None): +def colorize( + text: str, + fg_color: Optional[str] = None, + bg_color: Optional[str] = None, + attribute: Optional[str] = None, +) -> str: """Colorizes a string (including unicode strings) for printing on the terminal For an example of usage, as well as a demonstration as to what the @@ -2311,9 +2353,11 @@ def colorize(text, fg_color=None, bg_color=None, attribute=None): The colorized string using the defined text attribute. """ COLOR_FORMAT = "\033[%dm\033[%d;%dm%s\033[0m" - FOREGROUND = dict(zip(COLOR_NAMES, list(range(30, 38)))) - BACKGROUND = dict(zip(COLOR_NAMES, list(range(40, 48)))) - ATTRIBUTE = dict(zip(TEXT_ATTRIBUTES, [0, 1, 2, 3, 4, 5, 7, 8])) + FOREGROUND: Dict[Optional[str], int] = dict(zip(COLOR_NAMES, list(range(30, 38)))) + BACKGROUND: Dict[Optional[str], int] = dict(zip(COLOR_NAMES, list(range(40, 48)))) + ATTRIBUTE: Dict[Optional[str], int] = dict( + zip(TEXT_ATTRIBUTES, [0, 1, 2, 3, 4, 5, 7, 8]) + ) fg = FOREGROUND.get(fg_color, 39) bg = BACKGROUND.get(bg_color, 49) att = ATTRIBUTE.get(attribute, 0) @@ -2332,7 +2376,7 @@ def print_color_examples(): print("") -def group_iterator(items): +def group_iterator(items: np.ndarray) -> Generator[Tuple[Any, np.ndarray], None, None]: """An iterator to step over identical items in a :class:`numpy.ndarray` Example @@ -2349,7 +2393,7 @@ def group_iterator(items): yield item, np.where(items == item)[0] -def compute_hash(filename): +def compute_hash(filename: Union[str, Path, IO[bytes]]) -> bytes: """Compute a unique hash of a file. This is designed to keep around to detect changes, not to be @@ -2378,9 +2422,10 @@ def compute_hash(filename): return h.digest() -def get_conjunction(coord, t0, precision="low", ecl="IERS2010"): - """ - Find first time of Solar conjuction after t0 and approximate elongation at conjunction +def get_conjunction( + coord: coords.SkyCoord, t0: Time, precision: str = "low", ecl: str = "IERS2010" +) -> Tuple[Time, u.Quantity]: + """Find first time of Solar conjuction after t0 and approximate elongation at conjunction. Offers a low-precision version (based on analytic expression of Solar longitude) Or a higher-precision version (based on interpolating :func:`astropy.coordinates.get_sun`) @@ -2445,9 +2490,8 @@ def get_conjunction(coord, t0, precision="low", ecl="IERS2010"): return conjunction, csun.separation(coord) -def divide_times(t, t0, offset=0.5): - """ - Divide input times into years relative to t0 +def divide_times(t: Time, t0: Time, offset: float = 0.5) -> np.ndarray: + """Divide input times into years relative to t0. Years are centered around the requested offset value @@ -2479,7 +2523,7 @@ def divide_times(t, t0, offset=0.5): """ dt = t - t0 values = (dt.to(u.yr).value + offset) // 1 - return np.digitize(values, np.unique(values), right=True) + return cast(np.ndarray, np.digitize(values, np.unique(values), right=True)) def convert_dispersion_measure(dm, dmconst=None): @@ -2735,8 +2779,8 @@ def woodbury_dot(Ndiag, U, Phidiag, x, y): def _get_wx2pl_lnlike(model, component_name, ignore_fyr=True): - from pint.models.noise_model import powerlaw from pint import DMconst + from pint.models.noise_model import powerlaw assert component_name in ["WaveX", "DMWaveX"] prefix = "WX" if component_name == "WaveX" else "DMWX"