Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use cached_property and types #1718

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 106 additions & 63 deletions src/pint/observatory/__init__.py
Original file line number Diff line number Diff line change
@@ -21,22 +21,25 @@
necessary.
"""

from copy import deepcopy
import os
import textwrap
from collections import defaultdict
from collections.abc import Callable
from copy import deepcopy
from io import StringIO
from pathlib import Path
from typing import Optional, Union, List, Dict, Literal

import astropy.coordinates
import astropy.time
import astropy.units as u
import numpy as np
from astropy.coordinates import EarthLocation
from loguru import logger as log

from pint.config import runtimefile
from pint.pulsar_mjd import Time
from pint.utils import interesting_lines
from pint.utils import interesting_lines, PosVel

# Include any files that define observatories here. This will start
# with the standard distribution files, then will read any system- or
@@ -87,7 +90,7 @@ class ClockCorrectionOutOfRange(ClockCorrectionError):
_bipm_clock_versions = {}


def _load_gps_clock():
def _load_gps_clock() -> None:
global _gps_clock
if _gps_clock is None:
log.info("Loading global GPS clock file")
@@ -97,7 +100,7 @@ def _load_gps_clock():
)


def _load_bipm_clock(bipm_version):
def _load_bipm_clock(bipm_version: str) -> None:
bipm_version = bipm_version.lower()
if bipm_version not in _bipm_clock_versions:
try:
@@ -136,34 +139,43 @@ class Observatory:
position.
"""

fullname: str
"""Full human-readable name of the observatory."""
include_gps: bool
"""Whether to include GPS clock corrections."""
include_bipm: bool
"""Whether to include BIPM clock corrections."""
bipm_version: str
"""Version of the BIPM clock file to use."""

# This is a dict containing all defined Observatory instances,
# keyed on standard observatory name.
_registry = {}
_registry: Dict[str, "Observatory"] = {}

# This is a dict mapping any defined aliases to the corresponding
# standard name.
_alias_map = {}
_alias_map: Dict[str, str] = {}

def __init__(
self,
name,
fullname=None,
aliases=None,
include_gps=True,
include_bipm=True,
bipm_version=bipm_default,
overwrite=False,
name: str,
fullname: Optional[str] = None,
aliases: Optional[List[str]] = None,
include_gps: bool = True,
include_bipm: bool = True,
bipm_version: str = bipm_default,
overwrite: bool = False,
):
self._name = name.lower()
self._aliases = (
self._name: str = name.lower()
self._aliases: List[str] = (
list(set(map(str.lower, aliases))) if aliases is not None else []
)
if aliases is not None:
Observatory._add_aliases(self, aliases)
self.fullname = fullname if fullname is not None else name
self.include_gps = include_gps
self.include_bipm = include_bipm
self.bipm_version = bipm_version
self.fullname: str = fullname if fullname is not None else name
self.include_gps: bool = include_gps
self.include_bipm: bool = include_bipm
self.bipm_version: str = bipm_version

if name.lower() in Observatory._registry:
if not overwrite:
@@ -175,16 +187,18 @@ def __init__(
Observatory._register(self, name)

@classmethod
def _register(cls, obs, name):
"""Add an observatory to the registry using the specified name
(which will be converted to lower case). If an existing observatory
def _register(cls, obs: "Observatory", name: str) -> None:
"""Add an observatory to the registry using the specified name (which will be converted to lower case).
If an existing observatory
of the same name exists, it will be replaced with the new one.
The Observatory instance's name attribute will be updated for
consistency."""
consistency.
"""
cls._registry[name.lower()] = obs

@classmethod
def _add_aliases(cls, obs, aliases):
def _add_aliases(cls, obs: "Observatory", aliases: List[str]) -> None:
"""Add aliases for the specified Observatory. Aliases
should be given as a list. If any of the new aliases are already in
use, they will be replaced. Aliases are not checked against the
@@ -196,14 +210,17 @@ def _add_aliases(cls, obs, aliases):
cls._alias_map[a.lower()] = obs.name

@staticmethod
def gps_correction(t, limits="warn"):
def gps_correction(t: astropy.time.Time, limits: str = "warn") -> u.Quantity:
"""Compute the GPS clock corrections for times t."""
log.info("Applying GPS to UTC clock correction (~few nanoseconds)")
_load_gps_clock()
assert _gps_clock is not None
return _gps_clock.evaluate(t, limits=limits)

@staticmethod
def bipm_correction(t, bipm_version=bipm_default, limits="warn"):
def bipm_correction(
t: astropy.time.Time, bipm_version: str = bipm_default, limits: str = "warn"
) -> u.Quantity:
"""Compute the GPS clock corrections for times t."""
log.info(f"Applying TT(TAI) to TT({bipm_version}) clock correction (~27 us)")
tt2tai = 32.184 * 1e6 * u.us
@@ -214,7 +231,7 @@ def bipm_correction(t, bipm_version=bipm_default, limits="warn"):
)

@classmethod
def clear_registry(cls):
def clear_registry(cls) -> None:
"""Clear registry for ground-based observatories."""
cls._registry = {}
cls._alias_map = {}
@@ -229,7 +246,7 @@ def names(cls):
return cls._registry.keys()

@classmethod
def names_and_aliases(cls):
def names_and_aliases(cls) -> Dict[str, List[str]]:
"""List all observatories and their aliases"""
import pint.observatory.topo_obs # noqa
import pint.observatory.special_locations # noqa
@@ -241,15 +258,24 @@ def names_and_aliases(cls):
# setter methods that update the registries appropriately.

@property
def name(self):
def name(self) -> str:
"""Short name of the observatory.
This is the name used in TOA files and in the observatory registry.
"""
return self._name

@property
def aliases(self):
def aliases(self) -> List[str]:
"""List of aliases for the observatory.
These are short names also used to specify this observatory.
Includes ITOA and TEMPO codes, and any other common names.
"""
return self._aliases

@classmethod
def get(cls, name):
def get(cls, name: str) -> "Observatory":
"""Returns the Observatory instance for the specified name/alias.
If the name has not been defined, an error will be raised. Aside
@@ -303,9 +329,12 @@ def get(cls, name):
# Any which raise NotImplementedError below must be implemented in
# derived classes.

def earth_location_itrf(self, time=None):
"""Returns observatory geocentric position as an astropy
EarthLocation object. For observatories where this is not
def earth_location_itrf(
self, time: Optional[astropy.time.Time] = None
) -> Union[None, np.ndarray]:
"""Returns observatory geocentric position as an astropy EarthLocation object.
For observatories where this is not
relevant, None can be returned.
The location is in the International Terrestrial Reference Frame (ITRF).
@@ -319,8 +348,9 @@ def earth_location_itrf(self, time=None):
"""
return None

def get_gcrs(self, t, ephem=None):
"""Return position vector of observatory in GCRS
def get_gcrs(self, t: astropy.time.Time, ephem: Optional[str] = None):
"""Return position vector of observatory in GCRS.
t is an astropy.Time or array of astropy.Time objects
ephem is a link to an ephemeris file. Needed for SSB observatory
Returns a 3-vector of Quantities representing the position
@@ -329,14 +359,17 @@ def get_gcrs(self, t, ephem=None):
raise NotImplementedError

@property
def timescale(self):
"""Returns the timescale that TOAs from this observatory will be in,
once any clock corrections have been applied. This should be a
def timescale(self) -> str:
"""Returns the timescale that TOAs from this observatory will be in, once any clock corrections have been applied.
This should be a
string suitable to be passed directly to the scale argument of
astropy.time.Time()."""
raise NotImplementedError

def clock_corrections(self, t, limits="warn"):
def clock_corrections(
self, t: astropy.time.Time, limits: str = "warn"
) -> u.Quantity:
"""Compute clock corrections for a Time array.
Given an array-valued Time, return the clock corrections
@@ -356,7 +389,7 @@ def clock_corrections(self, t, limits="warn"):

return corr

def last_clock_correction_mjd(self):
def last_clock_correction_mjd(self) -> float:
"""Return the MJD of the last available clock correction.
Returns ``np.inf`` if no clock corrections are relevant.
@@ -365,6 +398,7 @@ def last_clock_correction_mjd(self):

if self.include_gps:
_load_gps_clock()
assert _gps_clock is not None
t = min(t, _gps_clock.last_correction_mjd())
if self.include_bipm:
_load_bipm_clock(self.bipm_version)
@@ -374,7 +408,13 @@ def last_clock_correction_mjd(self):
)
return t

def get_TDBs(self, t, method="default", ephem=None, options=None):
def get_TDBs(
self,
t: astropy.time.Time,
method: Union[str, Callable] = "default",
ephem: Optional[str] = None,
options: Optional[dict] = None,
):
"""This is a high level function for converting TOAs to TDB time scale.
Different method can be applied to obtain the result. Current supported
@@ -409,13 +449,13 @@ def get_TDBs(self, t, method="default", ephem=None, options=None):
t = Time([t])
if t.scale == "tdb":
return t
# Check the method. This pattern is from numpy minimize
meth = "_custom" if callable(method) else method.lower()
if options is None:
options = {}
if meth == "_custom":
if callable(method):
options = dict(options)
return method(t, **options)
else:
meth = method.lower()
if meth == "default":
return self._get_TDB_default(t, ephem)
elif meth == "ephemeris":
@@ -428,17 +468,17 @@ def get_TDBs(self, t, method="default", ephem=None, options=None):
else:
raise ValueError(f"Unknown method '{method}'.")

def _get_TDB_default(self, t, ephem):
def _get_TDB_default(self, t: astropy.time.Time, ephem: Optional[str]):
return t.tdb

def _get_TDB_ephem(self, t, ephem):
def _get_TDB_ephem(self, t: astropy.time.Time, ephem: Optional[str]):
"""Read the ephem TDB-TT column.
This column is provided by DE4XXt version of ephemeris.
"""
raise NotImplementedError

def posvel(self, t, ephem, group=None):
def posvel(self, t: astropy.time.Time, ephem: Optional[str], group=None) -> PosVel:
"""Return observatory position and velocity for the given times.
Position is relative to solar system barycenter; times are
@@ -451,7 +491,10 @@ def posvel(self, t, ephem, group=None):


def get_observatory(
name, include_gps=None, include_bipm=None, bipm_version=bipm_default
name: str,
include_gps: Optional[bool] = None,
include_bipm: Optional[bool] = None,
bipm_version: str = bipm_default,
):
"""Convenience function to get observatory object with options.
@@ -491,14 +534,14 @@ def get_observatory(
return Observatory.get(name)


def earth_location_distance(loc1, loc2):
def earth_location_distance(loc1: EarthLocation, loc2: EarthLocation) -> u.Quantity:
"""Compute the distance between two EarthLocations."""
return (
sum((u.Quantity(loc1.to_geocentric()) - u.Quantity(loc2.to_geocentric())) ** 2)
) ** 0.5


def compare_t2_observatories_dat(t2dir=None):
def compare_t2_observatories_dat(t2dir: Optional[str] = None) -> Dict[str, List[Dict]]:
"""Read a tempo2 observatories.dat file and compare with PINT
Produces a report including lines that can be added to PINT's
@@ -589,7 +632,7 @@ def compare_t2_observatories_dat(t2dir=None):
return report


def compare_tempo_obsys_dat(tempodir=None):
def compare_tempo_obsys_dat(tempodir: Optional[str] = None) -> Dict[str, List[Dict]]:
"""Read a tempo obsys.dat file and compare with PINT.
Produces a report including lines that can be added to PINT's
@@ -629,8 +672,8 @@ def compare_tempo_obsys_dat(tempodir=None):
y = float(line_io.read(15))
z = float(line_io.read(15))
line_io.read(2)
icoord = line_io.read(1).strip()
icoord = int(icoord) if icoord else 0
icoord_str = line_io.read(1).strip()
icoord = int(icoord_str) if icoord_str else 0
line_io.read(2)
obsnam = line_io.read(20).strip().lower()
tempo_code = line_io.read(1)
@@ -713,7 +756,7 @@ def convert_angle(x):
return report


def list_last_correction_mjds():
def list_last_correction_mjds() -> None:
"""Print out a list of the last MJD each clock correction is good for.
Each observatory lists the clock files it uses and their last dates,
@@ -744,7 +787,7 @@ def list_last_correction_mjds():
print(f" {c.friendly_name:<20} MISSING")


def update_clock_files(bipm_versions=None):
def update_clock_files(bipm_versions: Optional[List[str]] = None) -> None:
"""Obtain an up-to-date version of all clock files.
This up-to-date version will be stored in the Astropy cache;
@@ -786,13 +829,13 @@ def update_clock_files(bipm_versions=None):

# Both topo_obs and special_locations need this
def find_clock_file(
name,
format,
bogus_last_correction=False,
url_base=None,
clock_dir=None,
valid_beyond_ends=False,
):
name: str,
format: Literal["tempo", "tempo2"],
bogus_last_correction: bool = False,
url_base: Optional[str] = None,
clock_dir: Union[str, Path, None] = None,
valid_beyond_ends: bool = False,
) -> "ClockFile":
"""Locate and return a ClockFile in one of several places.
PINT looks for clock files in three places, in order:
141 changes: 84 additions & 57 deletions src/pint/observatory/topo_obs.py
Original file line number Diff line number Diff line change
@@ -17,12 +17,15 @@
--------
:mod:`pint.observatory.special_locations`
"""
import copy
import json
import os
from functools import cached_property
from pathlib import Path
import copy
from typing import Optional, Union, List, Any, Dict

import astropy.constants as c
import astropy.time
import astropy.units as u
import numpy as np
from astropy.coordinates import EarthLocation
@@ -36,13 +39,13 @@
NoClockCorrections,
Observatory,
bipm_default,
earth_location_distance,
find_clock_file,
get_observatory,
earth_location_distance,
)
from pint.pulsar_mjd import Time
from pint.solar_system_ephemerides import get_tdb_tt_ephem_geocenter, objPosVel_wrt_SSB
from pint.utils import has_astropy_unit, open_or_use
from pint.utils import has_astropy_unit, open_or_use, PosVel

# environment variables that can override clock location and observatory location
pint_obs_env_var = "PINT_OBS_OVERRIDE"
@@ -147,38 +150,63 @@
"""

tempo_code: Optional[str]
"""One-character TEMPO code."""
itoa_code: Optional[str]
"""Two-character ITOA code."""
location: EarthLocation
"""Location of the observatory."""
clock_files: List[str]
"""List of files to read for clock corrections. If empty, no clock corrections are applied."""
clock_fmt: str
"""Format of the clock files.
See :class:`pint.observatory.clock_file.ClockFile` for allowed values.
"""
bogus_last_correction: bool
"""Clock correction files include a bogus last correction.
This is common with TEMPO/TEMPO2 clock files since neither program does
a good job with times past the end ot the table. It makes detecting values
past the end of real calibration difficult if it's not marked as bogus.
"""
clock_dir: Optional[Union[str, Path]]
"""Where to look for the clock files."""
origin: Optional[str]
"""Documentation of the origin/author/date for the information."""

def __init__(
self,
name,
name: str,
*,
fullname=None,
tempo_code=None,
itoa_code=None,
aliases=None,
location=None,
fullname: Optional[str] = None,
tempo_code: Optional[str] = None,
itoa_code: Optional[str] = None,
aliases: Optional[List[str]] = None,
location: Optional[EarthLocation] = None,
itrf_xyz=None,
lat=None,
lon=None,
lat: Optional[float] = None,
lon: Optional[float] = None,
height=None,
clock_file="",
clock_fmt="tempo",
clock_dir=None,
include_gps=True,
include_bipm=True,
bipm_version=bipm_default,
origin=None,
overwrite=False,
bogus_last_correction=False,
clock_file: str = "",
clock_fmt: str = "tempo",
clock_dir: Union[str, Path, None] = None,
include_gps: bool = True,
include_bipm: bool = True,
bipm_version: str = bipm_default,
origin: Optional[str] = None,
overwrite: bool = False,
bogus_last_correction: bool = False,
):
input_values = [lat is not None, lon is not None, height is not None]
if sum(input_values) > 0 and sum(input_values) < 3:
if any(input_values) and not all(input_values):
raise ValueError("All of lat, lon, height are required for observatory")
input_values = [
location is not None,
itrf_xyz is not None,
(lat is not None and lon is not None and height is not None),
]
if sum(input_values) == 0:
if not any(input_values):
raise ValueError(
f"EarthLocation, ITRF coordinates, or lat/lon/height are required for observatory '{name}'"
)
@@ -209,11 +237,12 @@

# Save clock file info, the data will be read only if clock
# corrections for this site are requested.
self.clock_files = [clock_file] if isinstance(clock_file, str) else clock_file
self.clock_files = [c for c in self.clock_files if c != ""]
self.clock_fmt = clock_fmt
clock_files: List[str] = (
[clock_file] if isinstance(clock_file, str) else clock_file
)
self.clock_files: List[str] = [c for c in clock_files if c != ""]
self.clock_fmt: str = clock_fmt
self.clock_dir = clock_dir
self._clock = None # The ClockFile objects, will be read on demand

# If using TEMPO time.dat we need to know the 1-char tempo-style
# observatory code.
@@ -248,7 +277,7 @@
overwrite=overwrite,
)

def __repr__(self):
def __repr__(self) -> str:
aliases = [f"'{x}'" for x in self.aliases]
origin = (
f"{self.fullname}\n{self.origin}"
@@ -258,10 +287,10 @@
return f"TopoObs('{self.name}' ({','.join(aliases)}) at [{self.location.x}, {self.location.y} {self.location.z}]:\n{origin})"

@property
def timescale(self):
def timescale(self) -> str:
return "utc"

def get_dict(self):
def get_dict(self) -> Dict[str, Dict[str, Any]]:
"""Return as a dict with limited/changed info"""
# start with the default __dict__
# copy some attributes to rename them and remove those that aren't needed for initialization
@@ -276,12 +305,12 @@
output["itrf_xyz"] = [x.to_value(u.m) for x in self.location.geocentric]
return {self.name: output}

def get_json(self):
"""Return as a JSON string"""
def get_json(self) -> str:
"""Return as a JSON string."""
return json.dumps(self.get_dict())

def separation(self, other, method="cartesian"):
"""Return separation between two TopoObs objects
def separation(self, other: "TopoObs", method: str = "cartesian") -> u.Quantity:
"""Return separation between two TopoObs objects.
Parameters
----------
@@ -312,30 +341,30 @@
)
return (c.R_earth * dsigma).to(u.m, equivalencies=u.dimensionless_angles())

def earth_location_itrf(self, time=None):
def earth_location_itrf(self, time=None) -> EarthLocation:
return self.location

def _load_clock_corrections(self):
if self._clock is not None:
return
self._clock = []
@cached_property
def _clock(self) -> list:
clock = []
for cf in self.clock_files:
if cf == "":
continue
kwargs = dict(bogus_last_correction=self.bogus_last_correction)
if isinstance(cf, dict):
kwargs.update(cf)
cf = kwargs.pop("name")
self._clock.append(
clock.append(
find_clock_file(
cf,
format=self.clock_fmt,
clock_dir=self.clock_dir,
**kwargs,
)
)
return clock

def clock_corrections(self, t, limits="warn"):
def clock_corrections(self, t: Time, limits: str = "warn") -> u.Quantity:
"""Compute the total clock corrections,
Parameters
@@ -344,17 +373,16 @@
The time when the clock correcions are applied.
"""

corr = super().clock_corrections(t, limits=limits)
# Read clock file if necessary
self._load_clock_corrections()
corr: u.Quantity = super().clock_corrections(t, limits=limits)
if self._clock:
log.info(
f"Applying observatory clock corrections for observatory='{self.name}'."
)
for clock in self._clock:
corr += clock.evaluate(t, limits=limits)

elif self.clock_files:
# clock_files is not empty, but no clock corrections found
# FIXME: what if only some were found?
msg = f"No clock corrections found for observatory {self.name} taken from file {self.clock_files}"
if limits == "warn":
log.warning(msg)
@@ -365,19 +393,18 @@
log.info(f"Observatory {self.name} requires no clock corrections.")
return corr

def last_clock_correction_mjd(self):
def last_clock_correction_mjd(self) -> float:
"""Return the MJD of the last clock correction.
Combines constraints based on Earth orientation parameters and on the
available clock corrections specific to the telescope.
"""
t = super().last_clock_correction_mjd()
self._load_clock_corrections()
for clock in self._clock:
t = min(t, clock.last_correction_mjd())
return t

def _get_TDB_ephem(self, t, ephem):
def _get_TDB_ephem(self, t: Time, ephem: Optional[str]) -> Time:
"""Read the ephem TDB-TT column.
This column is provided by DE4XXt version of ephemeris. This function is only
@@ -389,8 +416,8 @@
# Topocenter to Geocenter
# Since earth velocity is not going to change a lot in 3ms. The
# differences between TT and TDB can be ignored.
earth_pv = objPosVel_wrt_SSB("earth", t.tdb, ephem)
obs_geocenter_pv = gcrs_posvel_from_itrf(
earth_pv: PosVel = objPosVel_wrt_SSB("earth", t.tdb, ephem)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the advantage to typing of this sort? Is this to prevent any changes to the API?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This tells the type checker to signal a problem if the function does not actually return a PosVel, or if the surrounding code uses the value in a way incompatible with a PosVel. More, it informs the reader what type this has, in case they don't know off the top of their head what type to expect. In this case it seemed useful because I wasn't ready to go digging about and annotate those two functions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, that makes sense.

obs_geocenter_pv: PosVel = gcrs_posvel_from_itrf(
self.earth_location_itrf(), t, obsname=self.name
)
# NOTE
@@ -406,7 +433,7 @@
location=self.earth_location_itrf(),
)

def get_gcrs(self, t, ephem=None):
def get_gcrs(self, t: astropy.time.Time, ephem: Optional[str] = None):
"""Return position vector of TopoObs in GCRS
Parameters
@@ -418,22 +445,22 @@
np.array
a 3-vector of Quantities representing the position in GCRS coordinates.
"""
obs_geocenter_pv = gcrs_posvel_from_itrf(
obs_geocenter_pv: PosVel = gcrs_posvel_from_itrf(

Check warning on line 448 in src/pint/observatory/topo_obs.py

Codecov / codecov/patch

src/pint/observatory/topo_obs.py#L448

Added line #L448 was not covered by tests
self.earth_location_itrf(), t, obsname=self.name
)
return obs_geocenter_pv.pos

def posvel(self, t, ephem, group=None):
def posvel(self, t: astropy.time.Time, ephem: Optional[str], group=None) -> PosVel:
if t.isscalar:
t = Time([t])
earth_pv = objPosVel_wrt_SSB("earth", t, ephem)
obs_geocenter_pv = gcrs_posvel_from_itrf(
earth_pv: PosVel = objPosVel_wrt_SSB("earth", t, ephem)
obs_geocenter_pv: PosVel = gcrs_posvel_from_itrf(
self.earth_location_itrf(), t, obsname=self.name
)
return obs_geocenter_pv + earth_pv


def export_all_clock_files(directory):
def export_all_clock_files(directory: Union[str, Path]) -> None:
"""Export all clock files PINT is using.
This will export all the clock files PINT is using - every clock file used
@@ -465,7 +492,7 @@
clock.export(directory / Path(clock.filename).name)


def load_observatories(filename=observatories_json, overwrite=False):
def load_observatories(filename=observatories_json, overwrite: bool = False) -> None:
"""Load observatory definitions from JSON and create :class:`pint.observatory.topo_obs.TopoObs` objects, registering them
Set `overwrite` to ``True`` if you want to re-read a file with updated definitions.
@@ -499,7 +526,7 @@
TopoObs(name=obsname, **obsdict)


def load_observatories_from_usual_locations(clear=False):
def load_observatories_from_usual_locations(clear: bool = False) -> None:
"""Load observatories from the default JSON file as well as ``$PINT_OBS_OVERRIDE``, optionally clearing the registry
Running with ``clear=True`` will return PINT to the state it is on import.