From 5cbd67da4c60e3a262326e51bd6aa9220c658f7d Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Thu, 11 Jul 2024 10:50:50 +0200 Subject: [PATCH 01/31] add cmwavex --- src/pint/models/cmwavex.py | 387 +++++++++++++++++++++++++++++++++++++ src/pint/models/dmwavex.py | 4 +- 2 files changed, 389 insertions(+), 2 deletions(-) create mode 100644 src/pint/models/cmwavex.py diff --git a/src/pint/models/cmwavex.py b/src/pint/models/cmwavex.py new file mode 100644 index 000000000..465700565 --- /dev/null +++ b/src/pint/models/cmwavex.py @@ -0,0 +1,387 @@ +"""Chromatic variations expressed as a sum of sinusoids.""" + +import astropy.units as u +import numpy as np +from loguru import logger as log +from warnings import warn + +from pint.models.parameter import MJDParameter, prefixParameter +from pint.models.timing_model import MissingParameter +from pint.models.chromatic_model import Chromatic, cmu +from pint import DMconst + + +class CMWaveX(Chromatic): + """ + Fourier representation of chromatic variations. + + Used for decomposition of chromatic noise into a series of sine/cosine components with the amplitudes as fitted parameters. + + Parameters supported: + + .. paramtable:: + :class: pint.models.cmwavex.CMWaveX + + To set up a CMWaveX model, users can use the `pint.utils` function `cmwavex_setup()` with either a list of frequencies or a choice + of harmonics of a base frequency determined by 2 * pi /Timespan + """ + + register = True + category = "cmwavex" + + def __init__(self): + super().__init__() + self.add_param( + MJDParameter( + name="CMWXEPOCH", + description="Reference epoch for Fourier representation of chromatic noise", + time_scale="tdb", + tcb2tdb_scale_factor=u.Quantity(1), + ) + ) + self.add_cmwavex_component(0.1, index=1, cmwxsin=0, cmwxcos=0, frozen=False) + self.set_special_params(["CMWXFREQ_0001", "CMWXSIN_0001", "CMWXCOS_0001"]) + self.cm_value_funcs += [self.cmwavex_cm] + self.delay_funcs_component += [self.cmwavex_delay] + + def add_cmwavex_component( + self, cmwxfreq, index=None, cmwxsin=0, cmwxcos=0, frozen=True + ): + """ + Add CMWaveX component + + Parameters + ---------- + + cmwxfreq : float or astropy.quantity.Quantity + Base frequency for CMWaveX component + index : int, None + Interger label for CMWaveX component. If None, will increment largest used index by 1. + cmwxsin : float or astropy.quantity.Quantity + Sine amplitude for CMWaveX component + cmwxcos : float or astropy.quantity.Quantity + Cosine amplitude for CMWaveX component + frozen : iterable of bool or bool + Indicates whether CMWaveX parameters will be fit + + Returns + ------- + + index : int + Index that has been assigned to new CMWaveX component + """ + + #### If index is None, increment the current max CMWaveX index by 1. Increment using CMWXFREQ + if index is None: + dct = self.get_prefix_mapping_component("CMWXFREQ_") + index = np.max(list(dct.keys())) + 1 + i = f"{int(index):04d}" + + if int(index) in self.get_prefix_mapping_component("CMWXFREQ_"): + raise ValueError( + f"Index '{index}' is already in use in this model. Please choose another" + ) + + if isinstance(cmwxsin, u.quantity.Quantity): + cmwxsin = cmwxsin.to_value(cmu) + if isinstance(cmwxcos, u.quantity.Quantity): + cmwxcos = cmwxcos.to_value(cmu) + if isinstance(cmwxfreq, u.quantity.Quantity): + cmwxfreq = cmwxfreq.to_value(1 / u.d) + self.add_param( + prefixParameter( + name=f"CMWXFREQ_{i}", + description="Component frequency for Fourier representation of chromatic noise", + units="1/d", + value=cmwxfreq, + parameter_type="float", + tcb2tdb_scale_factor=u.Quantity(1), + ) + ) + self.add_param( + prefixParameter( + name=f"CMWXSIN_{i}", + description="Sine amplitudes for Fourier representation of chromatic noise", + units=cmu, + value=cmwxsin, + frozen=frozen, + parameter_type="float", + tcb2tdb_scale_factor=DMconst, + ) + ) + self.add_param( + prefixParameter( + name=f"CMWXCOS_{i}", + description="Cosine amplitudes for Fourier representation of chromatic noise", + units=cmu, + value=cmwxcos, + frozen=frozen, + parameter_type="float", + tcb2tdb_scale_factor=DMconst, + ) + ) + self.setup() + self.validate() + return index + + def add_cmwavex_components( + self, cmwxfreqs, indices=None, cmwxsins=0, cmwxcoses=0, frozens=True + ): + """ + Add CMWaveX components with specified base frequencies + + Parameters + ---------- + + cmwxfreqs : iterable of float or astropy.quantity.Quantity + Base frequencies for CMWaveX components + indices : iterable of int, None + Interger labels for CMWaveX components. If None, will increment largest used index by 1. + cmwxsins : iterable of float or astropy.quantity.Quantity + Sine amplitudes for CMWaveX components + cmwxcoses : iterable of float or astropy.quantity.Quantity + Cosine amplitudes for CMWaveX components + frozens : iterable of bool or bool + Indicates whether sine and cosine amplitudes of CMwavex components will be fit + + Returns + ------- + + indices : list + Indices that have been assigned to new CMWaveX components + """ + + if indices is None: + indices = [None] * len(cmwxfreqs) + cmwxsins = np.atleast_1d(cmwxsins) + cmwxcoses = np.atleast_1d(cmwxcoses) + if len(cmwxsins) == 1: + cmwxsins = np.repeat(cmwxsins, len(cmwxfreqs)) + if len(cmwxcoses) == 1: + cmwxcoses = np.repeat(cmwxcoses, len(cmwxfreqs)) + if len(cmwxsins) != len(cmwxfreqs): + raise ValueError( + f"Number of base frequencies {len(cmwxfreqs)} doesn't match number of sine ampltudes {len(cmwxsins)}" + ) + if len(cmwxcoses) != len(cmwxfreqs): + raise ValueError( + f"Number of base frequencies {len(cmwxfreqs)} doesn't match number of cosine ampltudes {len(cmwxcoses)}" + ) + frozens = np.atleast_1d(frozens) + if len(frozens) == 1: + frozens = np.repeat(frozens, len(cmwxfreqs)) + if len(frozens) != len(cmwxfreqs): + raise ValueError( + "Number of base frequencies must match number of frozen values" + ) + #### If indices is None, increment the current max CMWaveX index by 1. Increment using CMWXFREQ + dct = self.get_prefix_mapping_component("CMWXFREQ_") + last_index = np.max(list(dct.keys())) + added_indices = [] + for cmwxfreq, index, cmwxsin, cmwxcos, frozen in zip( + cmwxfreqs, indices, cmwxsins, cmwxcoses, frozens + ): + if index is None: + index = last_index + 1 + last_index += 1 + elif index in list(dct.keys()): + raise ValueError( + f"Attempting to insert CMWXFREQ_{index:04d} but it already exists" + ) + added_indices.append(index) + i = f"{int(index):04d}" + + if int(index) in dct: + raise ValueError( + f"Index '{index}' is already in use in this model. Please choose another" + ) + if isinstance(cmwxfreq, u.quantity.Quantity): + cmwxfreq = cmwxfreq.to_value(u.d**-1) + if isinstance(cmwxsin, u.quantity.Quantity): + cmwxsin = cmwxsin.to_value(cmu) + if isinstance(cmwxcos, u.quantity.Quantity): + cmwxcos = cmwxcos.to_value(cmu) + log.trace(f"Adding CMWXSIN_{i} and CMWXCOS_{i} at frequency CMWXFREQ_{i}") + self.add_param( + prefixParameter( + name=f"CMWXFREQ_{i}", + description="Component frequency for Fourier representation of chromatic noise", + units="1/d", + value=cmwxfreq, + parameter_type="float", + tcb2tdb_scale_factor=u.Quantity(1), + ) + ) + self.add_param( + prefixParameter( + name=f"CMWXSIN_{i}", + description="Sine amplitude for Fourier representation of chromatic noise", + units=cmu, + value=cmwxsin, + parameter_type="float", + frozen=frozen, + tcb2tdb_scale_factor=DMconst, + ) + ) + self.add_param( + prefixParameter( + name=f"CMWXCOS_{i}", + description="Cosine amplitude for Fourier representation of chromatic noise", + units=cmu, + value=cmwxcos, + parameter_type="float", + frozen=frozen, + tcb2tdb_scale_factor=DMconst, + ) + ) + self.setup() + self.validate() + return added_indices + + def remove_cmwavex_component(self, index): + """ + Remove all CMWaveX components associated with a given index or list of indices + + Parameters + ---------- + index : float, int, list, np.ndarray + Number or list/array of numbers corresponding to CMWaveX indices to be removed from model. + """ + + if isinstance(index, (int, float, np.int64)): + indices = [index] + elif isinstance(index, (list, set, np.ndarray)): + indices = index + else: + raise TypeError( + f"index most be a float, int, set, list, or array - not {type(index)}" + ) + for index in indices: + index_rf = f"{int(index):04d}" + for prefix in ["CMWXFREQ_", "CMWXSIN_", "CMWXCOS_"]: + self.remove_param(prefix + index_rf) + self.validate() + + def get_indices(self): + """ + Returns an array of intergers corresponding to CMWaveX component parameters using CMWXFREQs + + Returns + ------- + inds : np.ndarray + Array of CMWaveX indices in model. + """ + inds = [int(p.split("_")[-1]) for p in self.params if "CMWXFREQ_" in p] + return np.array(inds) + + # Initialize setup + def setup(self): + super().setup() + # Get CMWaveX mapping and register CMWXSIN and CMWXCOS derivatives + for prefix_par in self.get_params_of_type("prefixParameter"): + if prefix_par.startswith("CMWXSIN_"): + self.register_deriv_funcs(self.d_delay_d_cmparam, prefix_par) + self.register_cm_deriv_funcs(self.d_cm_d_CMWXSIN, prefix_par) + if prefix_par.startswith("CMWXCOS_"): + self.register_deriv_funcs(self.d_delay_d_cmparam, prefix_par) + self.register_cm_deriv_funcs(self.d_cm_d_CMWXCOS, prefix_par) + self.cmwavex_freqs = list( + self.get_prefix_mapping_component("CMWXFREQ_").keys() + ) + self.num_cmwavex_freqs = len(self.cmwavex_freqs) + + def validate(self): + # Validate all the CMWaveX parameters + super().validate() + self.setup() + CMWXFREQ_mapping = self.get_prefix_mapping_component("CMWXFREQ_") + CMWXSIN_mapping = self.get_prefix_mapping_component("CMWXSIN_") + CMWXCOS_mapping = self.get_prefix_mapping_component("CMWXCOS_") + if CMWXFREQ_mapping.keys() != CMWXSIN_mapping.keys(): + raise ValueError( + "CMWXFREQ_ parameters do not match CMWXSIN_ parameters." + "Please check your prefixed parameters" + ) + if CMWXFREQ_mapping.keys() != CMWXCOS_mapping.keys(): + raise ValueError( + "CMWXFREQ_ parameters do not match CMWXCOS_ parameters." + "Please check your prefixed parameters" + ) + # if len(CMWXFREQ_mapping.keys()) != len(CMWXSIN_mapping.keys()): + # raise ValueError( + # "The number of CMWXFREQ_ parameters do not match the number of CMWXSIN_ parameters." + # "Please check your prefixed parameters" + # ) + # if len(CMWXFREQ_mapping.keys()) != len(CMWXCOS_mapping.keys()): + # raise ValueError( + # "The number of CMWXFREQ_ parameters do not match the number of CMWXCOS_ parameters." + # "Please check your prefixed parameters" + # ) + if CMWXSIN_mapping.keys() != CMWXCOS_mapping.keys(): + raise ValueError( + "CMWXSIN_ parameters do not match CMWXCOS_ parameters." + "Please check your prefixed parameters" + ) + if len(CMWXSIN_mapping.keys()) != len(CMWXCOS_mapping.keys()): + raise ValueError( + "The number of CMWXSIN_ and CMWXCOS_ parameters do not match" + "Please check your prefixed parameters" + ) + wfreqs = np.zeros(len(CMWXFREQ_mapping)) + for j, index in enumerate(CMWXFREQ_mapping): + if (getattr(self, f"CMWXFREQ_{index:04d}").value == 0) or ( + getattr(self, f"CMWXFREQ_{index:04d}").quantity is None + ): + raise ValueError( + f"CMWXFREQ_{index:04d} is zero or None. Please check your prefixed parameters" + ) + if getattr(self, f"CMWXFREQ_{index:04d}").value < 0.0: + warn(f"Frequency CMWXFREQ_{index:04d} is negative") + wfreqs[j] = getattr(self, f"CMWXFREQ_{index:04d}").value + wfreqs.sort() + # if np.any(np.diff(wfreqs) <= (1.0 / (2.0 * 364.25))): + # warn("Frequency resolution is greater than 1/yr") + if self.CMWXEPOCH.value is None and self._parent is not None: + if self._parent.PEPOCH.value is None: + raise MissingParameter( + "CMWXEPOCH or PEPOCH are required if CMWaveX is being used" + ) + else: + self.CMWXEPOCH.quantity = self._parent.PEPOCH.quantity + + def validate_toas(self, toas): + return super().validate_toas(toas) + + def cmwavex_cm(self, toas): + total_cm = np.zeros(toas.ntoas) * cmu + cmwave_freqs = self.get_prefix_mapping_component("CMWXFREQ_") + cmwave_sins = self.get_prefix_mapping_component("CMWXSIN_") + cmwave_cos = self.get_prefix_mapping_component("CMWXCOS_") + + base_phase = toas.table["tdbld"].data * u.d - self.CMWXEPOCH.value * u.d + for idx, param in cmwave_freqs.items(): + freq = getattr(self, param).quantity + cmwxsin = getattr(self, cmwave_sins[idx]).quantity + cmwxcos = getattr(self, cmwave_cos[idx]).quantity + arg = 2.0 * np.pi * freq * base_phase + total_cm += cmwxsin * np.sin(arg.value) + cmwxcos * np.cos(arg.value) + return total_cm + + def cmwavex_delay(self, toas, acc_delay=None): + return self.chromatic_type_delay(toas) + + def d_cm_d_CMWXSIN(self, toas, param, acc_delay=None): + par = getattr(self, param) + freq = getattr(self, f"CMWXFREQ_{int(par.index):04d}").quantity + base_phase = toas.table["tdbld"].data * u.d - self.CMWXEPOCH.value * u.d + arg = 2.0 * np.pi * freq * base_phase + deriv = np.sin(arg.value) + return deriv * cmu / par.units + + def d_cm_d_CMWXCOS(self, toas, param, acc_delay=None): + par = getattr(self, param) + freq = getattr(self, f"CMWXFREQ_{int(par.index):04d}").quantity + base_phase = toas.table["tdbld"].data * u.d - self.CMWXEPOCH.value * u.d + arg = 2.0 * np.pi * freq * base_phase + deriv = np.cos(arg.value) + return deriv * cmu / par.units diff --git a/src/pint/models/dmwavex.py b/src/pint/models/dmwavex.py index 258eb3556..368862391 100644 --- a/src/pint/models/dmwavex.py +++ b/src/pint/models/dmwavex.py @@ -271,7 +271,7 @@ def get_indices(self): inds : np.ndarray Array of DMWaveX indices in model. """ - inds = [int(p.split("_")[-1]) for p in self.params if "WXFREQ_" in p] + inds = [int(p.split("_")[-1]) for p in self.params if "DMWXFREQ_" in p] return np.array(inds) # Initialize setup @@ -299,7 +299,7 @@ def validate(self): DMWXCOS_mapping = self.get_prefix_mapping_component("DMWXCOS_") if DMWXFREQ_mapping.keys() != DMWXSIN_mapping.keys(): raise ValueError( - "WXFREQ_ parameters do not match DMWXSIN_ parameters." + "DMWXFREQ_ parameters do not match DMWXSIN_ parameters." "Please check your prefixed parameters" ) if DMWXFREQ_mapping.keys() != DMWXCOS_mapping.keys(): From e9b0d3553d25965680b7f7ccb3b17bbad2b6234c Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Thu, 11 Jul 2024 12:35:59 +0200 Subject: [PATCH 02/31] init --- src/pint/models/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pint/models/__init__.py b/src/pint/models/__init__.py index ab7820027..4c77ce9e1 100644 --- a/src/pint/models/__init__.py +++ b/src/pint/models/__init__.py @@ -26,6 +26,7 @@ from pint.models.binary_ddk import BinaryDDK from pint.models.binary_ell1 import BinaryELL1, BinaryELL1H, BinaryELL1k from pint.models.chromatic_model import ChromaticCM +from pint.models.cmwavex import CMWaveX from pint.models.dispersion_model import ( DispersionDM, DispersionDMX, From 557ab55c4b19c821a22afef44862232707a1578f Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Thu, 11 Jul 2024 12:38:33 +0200 Subject: [PATCH 03/31] test_dmwavex --- tests/test_dmwavex.py | 45 ++++++++++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/tests/test_dmwavex.py b/tests/test_dmwavex.py index 0b37877aa..5fdaf0cba 100644 --- a/tests/test_dmwavex.py +++ b/tests/test_dmwavex.py @@ -11,23 +11,23 @@ import pytest import astropy.units as u -par = """ - PSR B1937+21 - LAMBDA 301.9732445337270 - BETA 42.2967523367957 - PMLAMBDA -0.0175 - PMBETA -0.3971 - PX 0.1515 - POSEPOCH 55321.0000 - F0 641.9282333345536244 1 0.0000000000000132 - F1 -4.330899370129D-14 1 2.149749089617D-22 - PEPOCH 55321.000000 - DM 71.016633 - UNITS TDB -""" - def test_dmwavex(): + par = """ + PSR B1937+21 + LAMBDA 301.9732445337270 + BETA 42.2967523367957 + PMLAMBDA -0.0175 + PMBETA -0.3971 + PX 0.1515 + POSEPOCH 55321.0000 + F0 641.9282333345536244 1 0.0000000000000132 + F1 -4.330899370129D-14 1 2.149749089617D-22 + PEPOCH 55321.000000 + DM 71.016633 + UNITS TDB + """ + m = get_model(StringIO(par)) with pytest.raises(ValueError): @@ -111,6 +111,21 @@ def test_dmwavex_badpar(): def test_add_dmwavex(): + par = """ + PSR B1937+21 + LAMBDA 301.9732445337270 + BETA 42.2967523367957 + PMLAMBDA -0.0175 + PMBETA -0.3971 + PX 0.1515 + POSEPOCH 55321.0000 + F0 641.9282333345536244 1 0.0000000000000132 + F1 -4.330899370129D-14 1 2.149749089617D-22 + PEPOCH 55321.000000 + DM 71.016633 + UNITS TDB + """ + m = get_model(StringIO(par)) idxs = dmwavex_setup(m, 3600, n_freqs=5) From fd87992f99227d2fa2dfac29b26bba854aaa3b36 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Thu, 11 Jul 2024 12:41:42 +0200 Subject: [PATCH 04/31] cmwavex_setup --- src/pint/utils.py | 102 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 98 insertions(+), 4 deletions(-) diff --git a/src/pint/utils.py b/src/pint/utils.py index 6d59055f1..bb80b37ee 100644 --- a/src/pint/utils.py +++ b/src/pint/utils.py @@ -1570,7 +1570,7 @@ def dmwavex_setup( model : pint.models.timing_model.TimingModel T_span : float, astropy.quantity.Quantity Time span used to calculate nyquist frequency when using freqs - Time span used to calculate WaveX frequencies when using n_freqs + Time span used to calculate DMWaveX frequencies when using n_freqs Usually to be set as the length of the timing baseline the model is being used for freqs : iterable of float or astropy.quantity.Quantity, None User inputed base frequencies @@ -1643,6 +1643,100 @@ def dmwavex_setup( return model.components["DMWaveX"].get_indices() +def cmwavex_setup( + model: "pint.models.TimingModel", + T_span: Union[float, u.Quantity], + freqs: Optional[Iterable[Union[float, u.Quantity]]] = None, + n_freqs: Optional[int] = None, + freeze_params: bool = False, +) -> List[int]: + """ + Set-up a CMWaveX model based on either an array of user-provided frequencies or the wave number + frequency calculation. Sine and Cosine amplitudes are initially set to zero + + User specifies T_span and either freqs or n_freqs. This function assumes that the timing model does not already + have any CMWaveX components. See add_cmwavex_component() or add_cmwavex_components() to add components + to an existing CMWaveX model. + + Parameters + ---------- + + model : pint.models.timing_model.TimingModel + T_span : float, astropy.quantity.Quantity + Time span used to calculate nyquist frequency when using freqs + Time span used to calculate CMWaveX frequencies when using n_freqs + Usually to be set as the length of the timing baseline the model is being used for + freqs : iterable of float or astropy.quantity.Quantity, None + User inputed base frequencies + n_freqs : int, None + Number of wave frequencies to calculate using the equation: freq_n = 2 * pi * n / T_span + Where n is the wave number, and T_span is the total time span of the toas in the fitter object + freeze_params : bool, optional + Whether the new parameters should be frozen + + Returns + ------- + + indices : list + Indices that have been assigned to new WaveX components + """ + from pint.models.cmwavex import CMWaveX + + if (freqs is None) and (n_freqs is None): + raise ValueError( + "CMWaveX component base frequencies are not specified. " + "Please input either freqs or n_freqs" + ) + + if (freqs is not None) and (n_freqs is not None): + raise ValueError( + "Both freqs and n_freqs are specified. Only one or the other should be used" + ) + + if n_freqs is not None and n_freqs <= 0: + raise ValueError("Must use a non-zero number of wave frequencies") + + model.add_component(CMWaveX()) + if isinstance(T_span, u.quantity.Quantity): + T_span.to(u.d) + else: + T_span *= u.d + + nyqist_freq = 1.0 / (2.0 * T_span) + if freqs is not None: + if isinstance(freqs, u.quantity.Quantity): + freqs.to(u.d**-1) + else: + freqs *= u.d**-1 + if len(freqs) == 1: + model.CMWXFREQ_0001.quantity = freqs + else: + freqs = np.array(freqs) + freqs.sort() + if min(np.diff(freqs)) < nyqist_freq: + warnings.warn( + "CMWaveX frequency spacing is finer than frequency resolution of data" + ) + model.CMWXFREQ_0001.quantity = freqs[0] + model.components["CMWaveX"].add_cmwavex_components(freqs[1:]) + + if n_freqs is not None: + if n_freqs == 1: + wave_freq = 1 / T_span + model.CMWXFREQ_0001.quantity = wave_freq + else: + wave_numbers = np.arange(1, n_freqs + 1) + wave_freqs = wave_numbers / T_span + model.CMWXFREQ_0001.quantity = wave_freqs[0] + model.components["CMWaveX"].add_cmwavex_components(wave_freqs[1:]) + + for p in model.params: + if p.startswith("CMWXSIN") or p.startswith("CMWXCOS"): + model[p].frozen = freeze_params + + return model.components["CMWaveX"].get_indices() + + def _translate_wave_freqs(om: Union[float, u.Quantity], k: int) -> u.Quantity: """ Use Wave model WAVEOM parameter to calculate a WaveX WXFREQ_ frequency parameter for wave number k @@ -2396,12 +2490,12 @@ def info_string( else: s += f"{os.linesep}Comment: {comment}" - if (prefix_string is not None) and (len(prefix_string) > 0): + if prefix_string is not None and prefix_string != "": s = os.linesep.join([prefix_string + x for x in s.splitlines()]) return s -def list_parameters(class_: Type = None) -> List[Dict[str, Union[str, List]]]: +def list_parameters(class_: Optional[Type] = None) -> List[Dict[str, Union[str, List]]]: """List parameters understood by PINT. Parameters @@ -3039,7 +3133,7 @@ def _get_wx2pl_lnlike( from pint.models.noise_model import powerlaw from pint import DMconst - assert component_name in ["WaveX", "DMWaveX"] + assert component_name in {"WaveX", "DMWaveX"} prefix = "WX" if component_name == "WaveX" else "DMWX" idxs = np.array(model.components[component_name].get_indices()) From 86a83b97742059a2b070f5130335cd1169df469b Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Thu, 11 Jul 2024 12:41:59 +0200 Subject: [PATCH 05/31] test_cmwavex --- tests/test_cmwavex.py | 177 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 177 insertions(+) create mode 100644 tests/test_cmwavex.py diff --git a/tests/test_cmwavex.py b/tests/test_cmwavex.py new file mode 100644 index 000000000..d0f2652fb --- /dev/null +++ b/tests/test_cmwavex.py @@ -0,0 +1,177 @@ +from io import StringIO + +import numpy as np + +from pint.models import get_model +from pint.fitter import Fitter +from pint.simulation import make_fake_toas_uniform +from pint.utils import cmwavex_setup +from pint.models.chromatic_model import cmu + +import pytest +import astropy.units as u + + +def test_cmwavex(): + par = """` + PSR B1937+21 + LAMBDA 301.9732445337270 + BETA 42.2967523367957 + PMLAMBDA -0.0175 + PMBETA -0.3971 + PX 0.1515 + POSEPOCH 55321.0000 + F0 641.9282333345536244 1 0.0000000000000132 + F1 -4.330899370129D-14 1 2.149749089617D-22 + PEPOCH 55321.000000 + DM 71.016633 + CM 0.1 + TNCHROMIDX 4 + UNITS TDB + """ + m = get_model(StringIO(par)) + + with pytest.raises(ValueError): + idxs = cmwavex_setup(m, 3600) + + idxs = cmwavex_setup(m, 3600, n_freqs=5) + + assert "CMWaveX" in m.components + assert m.num_cmwavex_freqs == len(idxs) + + m.components["CMWaveX"].remove_cmwavex_component(5) + assert m.num_cmwavex_freqs == len(idxs) - 1 + + t = make_fake_toas_uniform(54000, 56000, 200, m, add_noise=True) + + ftr = Fitter.auto(t, m) + ftr.fit_toas() + + assert ftr.resids.reduced_chi2 < 2 + + +def test_cmwavex_badpar(): + with pytest.raises(ValueError): + par = """ + PSR B1937+21 + LAMBDA 301.9732445337270 + BETA 42.2967523367957 + PMLAMBDA -0.0175 + PMBETA -0.3971 + PX 0.1515 + POSEPOCH 55321.0000 + F0 641.9282333345536244 1 0.0000000000000132 + F1 -4.330899370129D-14 1 2.149749089617D-22 + PEPOCH 55321.000000 + DM 71.016633 + CM 0.1 + TNCHROMIDX 4 + UNITS TDB + CMWXFREQ_0001 0.01 + CMWXSIN_0001 0 + CMWXSIN_0002 0 + """ + get_model(StringIO(par)) + + with pytest.raises(ValueError): + par = """ + PSR B1937+21 + LAMBDA 301.9732445337270 + BETA 42.2967523367957 + PMLAMBDA -0.0175 + PMBETA -0.3971 + PX 0.1515 + POSEPOCH 55321.0000 + F0 641.9282333345536244 1 0.0000000000000132 + F1 -4.330899370129D-14 1 2.149749089617D-22 + PEPOCH 55321.000000 + DM 71.016633 + CM 0.1 + TNCHROMIDX 4 + UNITS TDB + CMWXFREQ_0001 0.01 + CMWXCOS_0001 0 + CMWXCOS_0002 0 + """ + get_model(StringIO(par)) + + with pytest.raises(ValueError): + par = """ + PSR B1937+21 + LAMBDA 301.9732445337270 + BETA 42.2967523367957 + PMLAMBDA -0.0175 + PMBETA -0.3971 + PX 0.1515 + POSEPOCH 55321.0000 + F0 641.9282333345536244 1 0.0000000000000132 + F1 -4.330899370129D-14 1 2.149749089617D-22 + PEPOCH 55321.000000 + DM 71.016633 + CM 0.1 + TNCHROMIDX 4 + UNITS TDB + CMWXFREQ_0001 0.00 + CMWXCOS_0001 0 + """ + get_model(StringIO(par)) + + +def test_add_cmwavex(): + par = """ + PSR B1937+21 + LAMBDA 301.9732445337270 + BETA 42.2967523367957 + PMLAMBDA -0.0175 + PMBETA -0.3971 + PX 0.1515 + POSEPOCH 55321.0000 + F0 641.9282333345536244 1 0.0000000000000132 + F1 -4.330899370129D-14 1 2.149749089617D-22 + PEPOCH 55321.000000 + DM 71.016633 + CM 0.1 + TNCHROMIDX 4 + UNITS TDB + """ + m = get_model(StringIO(par)) + idxs = cmwavex_setup(m, 3600, n_freqs=5) + + with pytest.raises(ValueError): + m.components["CMWaveX"].add_cmwavex_component(1, index=5, cmwxsin=0, cmwxcos=0) + + m.components["CMWaveX"].add_cmwavex_component(1, index=6, cmwxsin=0, cmwxcos=0) + assert m.num_cmwavex_freqs == len(idxs) + 1 + + m.components["CMWaveX"].add_cmwavex_component( + 1 / u.day, index=7, cmwxsin=0 * cmu, cmwxcos=0 * cmu + ) + assert m.num_cmwavex_freqs == len(idxs) + 2 + + m.components["CMWaveX"].add_cmwavex_component(2 / u.day) + assert m.num_cmwavex_freqs == len(idxs) + 3 + + m.components["CMWaveX"].add_cmwavex_components( + np.array([3]) / u.day, + cmwxsins=np.array([0]) * cmu, + cmwxcoses=np.array([0]) * cmu, + ) + assert m.num_cmwavex_freqs == len(idxs) + 4 + + with pytest.raises(ValueError): + m.components["CMWaveX"].add_cmwavex_components( + [2 / u.day, 3 / u.day], cmwxsins=[0, 0], cmwxcoses=[0, 0, 0] + ) + + with pytest.raises(ValueError): + m.components["CMWaveX"].add_cmwavex_components( + [2 / u.day, 3 / u.day], cmwxsins=[0, 0, 0], cmwxcoses=[0, 0] + ) + + with pytest.raises(ValueError): + m.components["CMWaveX"].add_cmwavex_components( + [2 / u.day, 3 / u.day], + cmwxsins=[0, 0], + cmwxcoses=[0, 0], + frozens=[False, False, False], + ) From 19c35214ad9df346fd6e6009d4e3e9b516e99025 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Thu, 11 Jul 2024 13:07:37 +0200 Subject: [PATCH 06/31] CHANGELOG --- CHANGELOG-unreleased.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG-unreleased.md b/CHANGELOG-unreleased.md index 2899b99f8..8e7d1fc3a 100644 --- a/CHANGELOG-unreleased.md +++ b/CHANGELOG-unreleased.md @@ -10,5 +10,8 @@ the released changes. ## Unreleased ### Changed ### Added +- Fourier series representation of chromatic noise (`CMWaveX`) +- `pint.utils.cmwavex_setup` function ### Fixed +- Bug in `DMWaveX.get_indices()` function ### Removed From 5ddf707ba7fa20e2b1b3e895af9e1d0757bb1137 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Thu, 11 Jul 2024 13:17:58 +0200 Subject: [PATCH 07/31] validation for correlated noise components --- src/pint/models/timing_model.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/pint/models/timing_model.py b/src/pint/models/timing_model.py index 52c7cd6f1..469b26db5 100644 --- a/src/pint/models/timing_model.py +++ b/src/pint/models/timing_model.py @@ -501,23 +501,36 @@ def num_components_of_type(type): num_components_of_type(SolarWindDispersionBase) <= 1 ), "Model can have at most one solar wind dispersion component." - from pint.models.dispersion_model import DispersionDMX + from pint.models.dispersion_model import DispersionDM, DispersionDMX + from pint.models.chromatic_model import ChromaticCM from pint.models.wave import Wave from pint.models.wavex import WaveX from pint.models.dmwavex import DMWaveX + from pint.models.cmwavex import CMWaveX from pint.models.noise_model import PLRedNoise, PLDMNoise + from pint.models.ifunc import IFunc if num_components_of_type((DispersionDMX, PLDMNoise, DMWaveX)) > 1: log.warning( "DispersionDMX, PLDMNoise, and DMWaveX cannot be used together. " "They are ways of modelling the same effect." ) - if num_components_of_type((Wave, WaveX, PLRedNoise)) > 1: + if num_components_of_type((Wave, WaveX, PLRedNoise, IFunc)) > 1: log.warning( "Wave, WaveX, and PLRedNoise cannot be used together. " "They are ways of modelling the same effect." ) + if num_components_of_type((PLDMNoise, DMWaveX)) == 1: + assert ( + num_components_of_type(DispersionDM) == 1 + ), "PLDMNoise / DMWaveX component cannot be used without the DispersionDM component." + + if num_components_of_type((CMWaveX)) == 1: + assert ( + num_components_of_type(ChromaticCM) == 1 + ), "PLChromNoise / CMWaveX component cannot be used without the ChromaticCM component." + # def __str__(self): # result = "" # comps = self.components From 75a5cf564acc8537bc4c2306491274c4a900bba1 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Thu, 11 Jul 2024 13:19:27 +0200 Subject: [PATCH 08/31] changelog# --- CHANGELOG-unreleased.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG-unreleased.md b/CHANGELOG-unreleased.md index 8e7d1fc3a..7b6f00ddd 100644 --- a/CHANGELOG-unreleased.md +++ b/CHANGELOG-unreleased.md @@ -12,6 +12,7 @@ the released changes. ### Added - Fourier series representation of chromatic noise (`CMWaveX`) - `pint.utils.cmwavex_setup` function +- More validation for correlated noise components in `TimingModel` ### Fixed - Bug in `DMWaveX.get_indices()` function ### Removed From 9c97955e716045975fc92f204c583a436dba6d1b Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Thu, 11 Jul 2024 13:29:39 +0200 Subject: [PATCH 09/31] sourcery --- src/pint/models/timing_model.py | 249 +++++++++++++------------------- 1 file changed, 100 insertions(+), 149 deletions(-) diff --git a/src/pint/models/timing_model.py b/src/pint/models/timing_model.py index 469b26db5..2160de9b2 100644 --- a/src/pint/models/timing_model.py +++ b/src/pint/models/timing_model.py @@ -2708,16 +2708,12 @@ def use_aliases(self, reset_to_default=True, alias_translation=None): if reset_to_default: po.use_alias = None if alias_translation is not None: - if hasattr(po, "origin_name"): - try: - po.use_alias = alias_translation[po.origin_name] - except KeyError: - pass - else: - try: - po.use_alias = alias_translation[p] - except KeyError: - pass + with contextlib.suppress(KeyError): + po.use_alias = ( + alias_translation[po.origin_name] + if hasattr(po, "origin_name") + else alias_translation[p] + ) def as_parfile( self, @@ -2746,7 +2742,7 @@ def as_parfile( Parfile output format. PINT outputs in 'tempo', 'tempo2' and 'pint' formats. The defaul format is `pint`. """ - if not format.lower() in _parfile_formats: + if format.lower() not in _parfile_formats: raise ValueError(f"parfile format must be one of {_parfile_formats}") self.validate() @@ -2768,33 +2764,30 @@ def as_parfile( continue result_begin += getattr(self, p).as_parfile_line(format=format) for cat in start_order: - if cat in list(cates_comp.keys()): - # print("Starting: %s" % cat) - cp = cates_comp[cat] - for cpp in cp: - result_begin += cpp.print_par(format=format) - printed_cate.append(cat) - else: + if cat not in list(cates_comp.keys()): continue + # print("Starting: %s" % cat) + cp = cates_comp[cat] + for cpp in cp: + result_begin += cpp.print_par(format=format) + printed_cate.append(cat) for cat in last_order: - if cat in list(cates_comp.keys()): - # print("Ending: %s" % cat) - cp = cates_comp[cat] - for cpp in cp: - result_end += cpp.print_par(format=format) - printed_cate.append(cat) - else: + if cat not in list(cates_comp.keys()): continue + # print("Ending: %s" % cat) + cp = cates_comp[cat] + for cpp in cp: + result_end += cpp.print_par(format=format) + printed_cate.append(cat) for cat in list(cates_comp.keys()): if cat in printed_cate: continue - else: - cp = cates_comp[cat] - for cpp in cp: - result_middle += cpp.print_par(format=format) - printed_cate.append(cat) + cp = cates_comp[cat] + for cpp in cp: + result_middle += cpp.print_par(format=format) + printed_cate.append(cat) return result_begin + result_middle + result_end @@ -2934,8 +2927,7 @@ def __len__(self): return len(self.params) def __iter__(self): - for p in self.params: - yield p + yield from self.params def as_ECL(self, epoch=None, ecl="IERS2010"): """Return TimingModel in PulsarEcliptic frame. @@ -3237,9 +3229,7 @@ def get_derived_params(self, rms=None, ntoas=None, returndict=False): ) s += f"Pulsar mass (Shapiro Delay) = {psrmass}" outdict["Mp (Msun)"] = psrmass - if not returndict: - return s - return s, outdict + return (s, outdict) if returndict else s class ModelMeta(abc.ABCMeta): @@ -3253,8 +3243,8 @@ class ModelMeta(abc.ABCMeta): """ def __init__(cls, name, bases, dct): - regname = "component_types" if "register" in dct and cls.register: + regname = "component_types" getattr(cls, regname)[name] = cls super().__init__(name, bases, dct) @@ -3334,10 +3324,10 @@ def param_prefixs(self): for p in self.params: par = getattr(self, p) if par.is_prefix: - if par.prefix not in prefixs.keys(): - prefixs[par.prefix] = [p] - else: + if par.prefix in prefixs: prefixs[par.prefix].append(p) + else: + prefixs[par.prefix] = [p] return prefixs @property_exists @@ -3407,10 +3397,7 @@ def add_param(self, param, deriv_func=None, setup=False): exist_par = getattr(self, param.name) if exist_par.value is not None: raise ValueError( - "Tried to add a second parameter called {}. " - "Old value: {} New value: {}".format( - param.name, getattr(self, param.name), param - ) + f"Tried to add a second parameter called {param.name}. Old value: {getattr(self, param.name)} New value: {param}" ) else: setattr(self, param.name, param) @@ -3433,10 +3420,7 @@ def remove_param(self, param): param : str or pint.models.Parameter The parameter to remove. """ - if isinstance(param, str): - param_name = param - else: - param_name = param.name + param_name = param if isinstance(param, str) else param.name if param_name not in self.params: raise ValueError( f"Tried to remove parameter {param_name} but it is not listed: {self.params}" @@ -3473,10 +3457,7 @@ def get_params_of_type(self, param_type): par = getattr(self, p) par_type = type(par).__name__ par_prefix = par_type[:-9] - if ( - param_type.upper() == par_type.upper() - or param_type.upper() == par_prefix.upper() - ): + if param_type.upper() in [par_type.upper(), par_prefix.upper()]: result.append(par.name) return result @@ -3521,37 +3502,35 @@ def match_param_aliases(self, alias): # Split the alias prefix, see if it is a perfix alias try: prefix, idx_str, idx = split_prefixed_name(alias) - except PrefixError: # Not a prefixed name - if pname is not None: - par = getattr(self, pname) - if par.is_prefix: - raise UnknownParameter( - f"Prefix {alias} maps to mulitple parameters" - ". Please specify the index as well." - ) - else: + except PrefixError as e: # Not a prefixed name + if pname is None: # Not a prefix, not an alias - raise UnknownParameter(f"Unknown parameter name or alias {alias}") - # When the alias is a prefixed name but not in the parameter list yet - if pname is None: - prefix_pname = self.aliases_map.get(prefix, None) - if prefix_pname: - par = getattr(self, prefix_pname) - if par.is_prefix: - raise UnknownParameter( - f"Found a similar prefixed parameter '{prefix_pname}'" - f" But parameter {par.prefix}{idx} need to be added" - f" to the model." - ) - else: - raise UnknownParameter( - f"{par} is not a prefixed parameter, howere the input" - f" {alias} has index with it." - ) + raise UnknownParameter( + f"Unknown parameter name or alias {alias}" + ) from e + par = getattr(self, pname) + if par.is_prefix: + raise UnknownParameter( + f"Prefix {alias} maps to mulitple parameters" + ". Please specify the index as well." + ) from e + if pname is not None: + return pname + if prefix_pname := self.aliases_map.get(prefix, None): + par = getattr(self, prefix_pname) + if par.is_prefix: + raise UnknownParameter( + f"Found a similar prefixed parameter '{prefix_pname}'" + f" But parameter {par.prefix}{idx} need to be added" + f" to the model." + ) else: - raise UnknownParameter(f"Unknown parameter name or alias {alias}") + raise UnknownParameter( + f"{par} is not a prefixed parameter, howere the input" + f" {alias} has index with it." + ) else: - return pname + raise UnknownParameter(f"Unknown parameter name or alias {alias}") def register_deriv_funcs(self, func, param): """Register the derivative function in to the deriv_func dictionaries. @@ -3568,15 +3547,10 @@ def register_deriv_funcs(self, func, param): if pn not in list(self.deriv_funcs.keys()): self.deriv_funcs[pn] = [func] + elif func in self.deriv_funcs[pn]: + return else: - # TODO: - # Runing setup() mulitple times can lead to adding derivative - # function multiple times. This prevent it from happening now. But - # in the future, we should think a better way to do so. - if func in self.deriv_funcs[pn]: - return - else: - self.deriv_funcs[pn] += [func] + self.deriv_funcs[pn] += [func] def is_in_parfile(self, para_dict): """Check if this subclass included in parfile. @@ -3594,11 +3568,7 @@ def is_in_parfile(self, para_dict): """ if self.component_special_params: - for p in self.component_special_params: - if p in para_dict: - return True - return False - + return any(p in para_dict for p in self.component_special_params) pNames_inpar = list(para_dict.keys()) pNames_inModel = self.params @@ -3606,22 +3576,14 @@ def is_in_parfile(self, para_dict): # should go in them. # For solar system Shapiro delay component if hasattr(self, "PLANET_SHAPIRO"): - if "NO_SS_SHAPIRO" in pNames_inpar: - return False - else: - return True - + return "NO_SS_SHAPIRO" not in pNames_inpar try: bmn = getattr(self, "binary_model_name") except AttributeError: # This isn't a binary model, keep looking pass else: - if "BINARY" in para_dict: - return bmn == para_dict["BINARY"][0] - else: - return False - + return bmn == para_dict["BINARY"][0] if "BINARY" in para_dict else False # Compare the componets parameter names with par file parameters compr = list(set(pNames_inpar).intersection(pNames_inModel)) @@ -3654,10 +3616,9 @@ def print_par(self, format="pint"): ------- str : formatted line for par file """ - result = "" - for p in self.params: - result += getattr(self, p).as_parfile_line(format=format) - return result + return "".join( + getattr(self, p).as_parfile_line(format=format) for p in self.params + ) class DelayComponent(Component): @@ -3797,7 +3758,7 @@ def _param_unit_map(self): units = {} for k, cp in self.components.items(): for p in cp.params: - if p in units.keys() and units[p] != getattr(cp, p).units: + if p in units and units[p] != getattr(cp, p).units: raise TimingModelError( f"Units of parameter '{p}' in component '{cp}' ({getattr(cp, p).units}) do not match those of existing parameter ({units[p]})" ) @@ -3815,11 +3776,9 @@ def repeatable_param(self): for p in cp.params: par = getattr(cp, p) if par.repeatable: - repeatable.append(p) - repeatable.append(par._parfile_name) + repeatable.extend((p, par._parfile_name)) # also add the aliases to the repeatable param - for als in par.aliases: - repeatable.append(als) + repeatable.extend(iter(par.aliases)) return set(repeatable) @lazyproperty @@ -3849,10 +3808,7 @@ def component_category_map(self): The mapping from components to its categore. The key is the component name and the value is the component's category name. """ - cp_ca = {} - for k, cp in self.components.items(): - cp_ca[k] = cp.category - return cp_ca + return {k: cp.category for k, cp in self.components.items()} @lazyproperty def component_unique_params(self): @@ -3894,38 +3850,35 @@ def search_binary_components(self, system_name): model. """ all_systems = self.category_component_map["pulsar_system"] - # Search the system name first if system_name in all_systems: return self.components[system_name] - else: # search for the pulsar system aliases - for cp_name in all_systems: - if system_name == self.components[cp_name].binary_model_name: - return self.components[cp_name] - - if system_name == "BTX": - raise UnknownBinaryModel( - "`BINARY BTX` is not supported bt PINT. Use " - "`BINARY BT` instead. It supports both orbital " - "period (PB, PBDOT) and orbital frequency (FB0, ...) " - "parametrizations." - ) - elif system_name == "DDFWHE": - raise UnknownBinaryModel( - "`BINARY DDFWHE` is not supported, but the same model " - "is available as `BINARY DDH`." - ) - elif system_name in ["MSS", "EH", "H88", "DDT", "BT1P", "BT2P"]: - # Binary model list taken from - # https://tempo.sourceforge.net/ref_man_sections/binary.txt - raise UnknownBinaryModel( - f"`The binary model {system_name} is not yet implemented." - ) + for cp_name in all_systems: + if system_name == self.components[cp_name].binary_model_name: + return self.components[cp_name] + if system_name == "BTX": + raise UnknownBinaryModel( + "`BINARY BTX` is not supported bt PINT. Use " + "`BINARY BT` instead. It supports both orbital " + "period (PB, PBDOT) and orbital frequency (FB0, ...) " + "parametrizations." + ) + elif system_name == "DDFWHE": + raise UnknownBinaryModel( + "`BINARY DDFWHE` is not supported, but the same model " + "is available as `BINARY DDH`." + ) + elif system_name in ["MSS", "EH", "H88", "DDT", "BT1P", "BT2P"]: + # Binary model list taken from + # https://tempo.sourceforge.net/ref_man_sections/binary.txt raise UnknownBinaryModel( - f"Pulsar system/Binary model component" - f" {system_name} is not provided." + f"`The binary model {system_name} is not yet implemented." ) + raise UnknownBinaryModel( + f"Pulsar system/Binary model component" f" {system_name} is not provided." + ) + def alias_to_pint_param(self, alias): """Translate a alias to a PINT parameter name. @@ -3989,14 +3942,11 @@ def alias_to_pint_param(self, alias): # count length of idx_str and dectect leading zeros # TODO fix the case for searching `DMX` num_lzero = len(idx_str) - len(str(idx)) - if num_lzero > 0: # Has leading zero - fmt = len(idx_str) - else: - fmt = 0 + fmt = len(idx_str) if num_lzero > 0 else 0 first_init_par = None # Handle the case of start index from 0 and 1 for start_idx in [0, 1]: - first_init_par_alias = prefix + f"{start_idx:0{fmt}}" + first_init_par_alias = f"{prefix}{start_idx:0{fmt}}" first_init_par = self._param_alias_map.get( first_init_par_alias, None ) @@ -4006,13 +3956,14 @@ def alias_to_pint_param(self, alias): break except PrefixError: pint_par = None - else: first_init_par = pint_par + if pint_par is None: raise UnknownParameter( - "Can not find matching PINT parameter for '{}'".format(alias) + f"Can not find matching PINT parameter for '{alias}'" ) + return pint_par, first_init_par def param_to_unit(self, name): @@ -4072,7 +4023,7 @@ def __init__(self, module, param, msg=None): self.msg = msg def __str__(self): - result = self.module + "." + self.param + result = f"{self.module}.{self.param}" if self.msg is not None: result += "\n " + self.msg return result From 30c8f4bcad02de5f050fc90dddfde89b2741115f Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Thu, 11 Jul 2024 13:37:38 +0200 Subject: [PATCH 10/31] validation --- src/pint/models/timing_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pint/models/timing_model.py b/src/pint/models/timing_model.py index 2160de9b2..d381a91bc 100644 --- a/src/pint/models/timing_model.py +++ b/src/pint/models/timing_model.py @@ -507,7 +507,7 @@ def num_components_of_type(type): from pint.models.wavex import WaveX from pint.models.dmwavex import DMWaveX from pint.models.cmwavex import CMWaveX - from pint.models.noise_model import PLRedNoise, PLDMNoise + from pint.models.noise_model import PLRedNoise, PLDMNoise, PLChromNoise from pint.models.ifunc import IFunc if num_components_of_type((DispersionDMX, PLDMNoise, DMWaveX)) > 1: @@ -526,7 +526,7 @@ def num_components_of_type(type): num_components_of_type(DispersionDM) == 1 ), "PLDMNoise / DMWaveX component cannot be used without the DispersionDM component." - if num_components_of_type((CMWaveX)) == 1: + if num_components_of_type((PLChromNoise, CMWaveX)) == 1: assert ( num_components_of_type(ChromaticCM) == 1 ), "PLChromNoise / CMWaveX component cannot be used without the ChromaticCM component." From fa6a9d6a0dfd793bcfec8d7aed07b233c482b8a3 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Wed, 7 Aug 2024 13:54:40 +0200 Subject: [PATCH 11/31] -- --- src/pint/logging.py | 2 +- src/pint/models/binary_ddk.py | 29 +++++++++++++++++------------ 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/pint/logging.py b/src/pint/logging.py index e87a039bd..f7de1a5aa 100644 --- a/src/pint/logging.py +++ b/src/pint/logging.py @@ -129,7 +129,7 @@ class LogFilter: def __init__(self, onlyonce=None, never=None, onlyonce_level="INFO"): """ - Define regexs for messages that will only be seen once. Use ``\S+`` for a variable that might change. + Define regexs for messages that will only be seen once. Use ``\\S+`` for a variable that might change. If a message comes through with a new value for that variable, it will be seen. Make sure to escape other regex commands like ``()``. diff --git a/src/pint/models/binary_ddk.py b/src/pint/models/binary_ddk.py index 8a31f6209..452774ead 100644 --- a/src/pint/models/binary_ddk.py +++ b/src/pint/models/binary_ddk.py @@ -49,14 +49,19 @@ class BinaryDDK(BinaryDD): of the system from Earth, the finite size of the system, and the interaction of these with the proper motion. - From Kopeikin (1995) this includes :math:`\Delta_{\pi M}` (Equation 17), the mixed annual-orbital parallax term, which changes :math:`a_1` and :math:`\omega` - (:meth:`~pint.models.stand_alone_psr_binaries.DDK_model.DDKmodel.delta_a1_parallax` and :meth:`~pint.models.stand_alone_psr_binaries.DDK_model.DDKmodel.delta_omega_parallax`). + From Kopeikin (1995) this includes :math:`\\Delta_{\\pi M}` (Equation 17), + the mixed annual-orbital parallax term, which changes :math:`a_1` and :math:`\\omega` + (:meth:`~pint.models.stand_alone_psr_binaries.DDK_model.DDKmodel.delta_a1_parallax` + and :meth:`~pint.models.stand_alone_psr_binaries.DDK_model.DDKmodel.delta_omega_parallax`). - It does not include :math:`\Delta_{\pi P}`, the pure pulsar orbital parallax term (Equation 14). + It does not include :math:`\\Delta_{\\pi P}`, the pure pulsar orbital parallax term + (Equation 14). - From Kopeikin (1996) this includes apparent changes in :math:`\omega`, :math:`a_1`, and :math:`i` due to the proper motion - (:meth:`~pint.models.stand_alone_psr_binaries.DDK_model.DDKmodel.delta_omega_proper_motion`, :meth:`~pint.models.stand_alone_psr_binaries.DDK_model.DDKmodel.delta_a1_proper_motion`, - :meth:`~pint.models.stand_alone_psr_binaries.DDK_model.DDKmodel.delta_kin_proper_motion`) (Equations 8, 9, 10). + From Kopeikin (1996) this includes apparent changes in :math:`\\omega`, :math:`a_1`, and + :math:`i` due to the proper motion (:meth:`~pint.models.stand_alone_psr_binaries.DDK_model.DDKmodel.delta_omega_proper_motion`, + :meth:`~pint.models.stand_alone_psr_binaries.DDK_model.DDKmodel.delta_a1_proper_motion`, + :meth:`~pint.models.stand_alone_psr_binaries.DDK_model.DDKmodel.delta_kin_proper_motion`) + (Equations 8, 9, 10). The actual calculations for this are done in :class:`pint.models.stand_alone_psr_binaries.DDK_model.DDKmodel`. @@ -67,7 +72,7 @@ class BinaryDDK(BinaryDD): KIN the inclination angle: :math:`i` KOM - the longitude of the ascending node, Kopeikin (1995) Eq 9: :math:`\Omega` + the longitude of the ascending node, Kopeikin (1995) Eq 9: :math:`\\Omega` K96 flag for Kopeikin binary model proper motion correction @@ -233,19 +238,19 @@ def alternative_solutions(self): We first define the symmetry point where a1dot is zero (in equatorial coordinates): - :math:`KOM_0 = \\tan^{-1} (\mu_{\delta} / \mu_{\\alpha})` + :math:`KOM_0 = \\tan^{-1} (\\mu_{\\delta} / \\mu_{\\alpha})` The solutions are then: :math:`(KIN, KOM)` - :math:`(KIN, 2KOM_0 - KOM - 180^{\circ})` + :math:`(KIN, 2KOM_0 - KOM - 180^{\\circ})` - :math:`(180^{\circ}-KIN, KOM+180^{\circ})` + :math:`(180^{\\circ}-KIN, KOM+180^{\\circ})` - :math:`(180^{\circ}-KIN, 2KOM_0 - KOM)` + :math:`(180^{\\circ}-KIN, 2KOM_0 - KOM)` - All values will be between 0 and :math:`360^{\circ}`. + All values will be between 0 and :math:`360^{\\circ}`. Returns ------- From 4b931c3072fb2449bc606e4b619fa70c803d9213 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Wed, 7 Aug 2024 15:40:19 +0200 Subject: [PATCH 12/31] plchromnoise_from_cmwavex --- src/pint/utils.py | 75 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 69 insertions(+), 6 deletions(-) diff --git a/src/pint/utils.py b/src/pint/utils.py index 3ef7a7783..0327e79c1 100644 --- a/src/pint/utils.py +++ b/src/pint/utils.py @@ -3132,8 +3132,9 @@ def _get_wx2pl_lnlike( from pint.models.noise_model import powerlaw from pint import DMconst - assert component_name in {"WaveX", "DMWaveX"} - prefix = "WX" if component_name == "WaveX" else "DMWX" + assert component_name in {"WaveX", "DMWaveX", "CMWaveX"} + prefix_dict = {"WaveX": "WX", "DMWaveX": "DMWX", "CMWaveX": "CMWX"} + prefix = prefix_dict[component_name] idxs = np.array(model.components[component_name].get_indices()) @@ -3145,7 +3146,7 @@ def _get_wx2pl_lnlike( assert np.allclose( np.diff(np.diff(fs)), 0 - ), "[DM]WaveX frequencies must be uniformly spaced." + ), "WaveX/DMWaveX/CMWaveX frequencies must be uniformly spaced for this conversion to work." if ignore_fyr: year_mask = np.abs(((fs - fyr) / f0)) > 0.5 @@ -3156,7 +3157,15 @@ def _get_wx2pl_lnlike( ) f0 = np.min(fs) - scaling_factor = 1 if component_name == "WaveX" else DMconst / (1400 * u.MHz) ** 2 + scaling_factor = ( + 1 + if component_name == "WaveX" + else ( + DMconst / (1400 * u.MHz) ** 2 + if component_name == "DMWaveX" + else DMconst / 1400**model.TNCHROMIDX.value + ) + ) a = np.array( [ @@ -3185,14 +3194,14 @@ def _get_wx2pl_lnlike( def powl_model(params: Tuple[float, float]) -> float: """Get the powerlaw spectrum for the WaveX frequencies for a given - set of parameters. This calls the powerlaw function used by `PLRedNoise`/`PLDMNoise`. + set of parameters. This calls the powerlaw function used by `PLRedNoise`/`PLDMNoise`/`PLChromNoise`. """ gamma, log10_A = params return (powerlaw(fs, A=10**log10_A, gamma=gamma) * f0) ** 0.5 def mlnlike(params: Tuple[float, ...]) -> float: """Negative of the likelihood function that acts on the - `[DM]WaveX` amplitudes.""" + `[DM/CM]WaveX` amplitudes.""" sigma = powl_model(params) return 0.5 * np.sum( (a**2 / (sigma**2 + da**2)) @@ -3308,6 +3317,60 @@ def pldmnoise_from_dmwavex( return model1 +def plchromnoise_from_cmwavex( + model: "pint.models.TimingModel", ignore_fyr: bool = False +) -> "pint.models.TimingModel": + """Convert a `CMWaveX` representation of red noise to a `PLChromNoise` + representation. This is done by minimizing a likelihood function + that acts on the `CMWaveX` amplitudes over the powerlaw spectral + parameters. + + Parameters + ---------- + model: pint.models.timing_model.TimingModel + The timing model with a `CMWaveX` component. + + Returns + ------- + pint.models.timing_model.TimingModel + The timing model with a converted `PLChromNoise` component. + """ + from pint.models.noise_model import PLChromNoise + + mlnlike = _get_wx2pl_lnlike(model, "CMWaveX", ignore_fyr=ignore_fyr) + + result = minimize(mlnlike, [4, -13], method="Nelder-Mead") + if not result.success: + raise ValueError("Log-likelihood maximization failed to converge.") + + gamma_val, log10_A_val = result.x + + hess = Hessian(mlnlike) + + H = hess((gamma_val, log10_A_val)) + assert np.all(np.linalg.eigvals(H) > 0), "The Hessian is not positive definite!" + + Hinv = np.linalg.pinv(H) + assert np.all( + np.linalg.eigvals(Hinv) > 0 + ), "The inverse Hessian is not positive definite!" + + gamma_err, log10_A_err = np.sqrt(np.diag(Hinv)) + + tndmc = len(model.components["CMWaveX"].get_indices()) + + model1 = deepcopy(model) + model1.remove_component("CMWaveX") + model1.add_component(PLChromNoise()) + model1.TNCHROMAMP.value = log10_A_val + model1.TNCHROMGAM.value = gamma_val + model1.TNCHROMC.value = tndmc + model1.TNCHROMAMP.uncertainty_value = log10_A_err + model1.TNCHROMGAM.uncertainty_value = gamma_err + + return model1 + + def find_optimal_nharms( model: "pint.models.TimingModel", toas: "pint.toa.TOAs", From b39395aeadd75f10c5073c234cf71a2b45bc885f Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Wed, 7 Aug 2024 15:46:39 +0200 Subject: [PATCH 13/31] tests --- tests/test_wx2pl.py | 76 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/tests/test_wx2pl.py b/tests/test_wx2pl.py index 788a238dc..8f426f2e3 100644 --- a/tests/test_wx2pl.py +++ b/tests/test_wx2pl.py @@ -3,8 +3,10 @@ from pint.simulation import make_fake_toas_uniform from pint.fitter import WLSFitter from pint.utils import ( + cmwavex_setup, dmwavex_setup, find_optimal_nharms, + plchromnoise_from_cmwavex, wavex_setup, plrednoise_from_wavex, pldmnoise_from_dmwavex, @@ -107,6 +109,54 @@ def data_dmwx(): return m, t +@pytest.fixture +def data_cmwx(): + par_sim_cmwx = """ + PSR SIM3 + RAJ 05:00:00 1 + DECJ 15:00:00 1 + PEPOCH 55000 + F0 100 1 + F1 -1e-15 1 + PHOFF 0 1 + DM 15 1 + TNCHROMIDX 4 + CM 10 + TNCHROMAMP -13 + TNCHROMGAM 3.5 + TNCHROMC 10 + TZRMJD 55000 + TZRFRQ 1400 + TZRSITE gbt + UNITS TDB + EPHEM DE440 + CLOCK TT(BIPM2019) + """ + + m = get_model(StringIO(par_sim_cmwx)) + + ntoas = 200 + toaerrs = np.random.uniform(0.5, 2.0, ntoas) * u.us + freqs = np.linspace(500, 1500, 4) * u.MHz + + t = make_fake_toas_uniform( + startMJD=54001, + endMJD=56001, + ntoas=ntoas, + model=m, + freq=freqs, + obs="gbt", + error=toaerrs, + add_noise=True, + add_correlated_noise=True, + name="fake", + include_bipm=True, + multi_freqs_in_epoch=True, + ) + + return m, t + + def test_wx2pl(data_wx): m, t = data_wx @@ -147,6 +197,32 @@ def test_dmwx2pldm(data_dmwx): assert abs(m.TNDMGAM.value - m2.TNDMGAM.value) / m2.TNDMGAM.uncertainty_value < 5 +def test_cmwx2pldm(data_cmwx): + m, t = data_cmwx + + m1 = deepcopy(m) + m1.remove_component("PLChromNoise") + + Tspan = t.get_mjds().max() - t.get_mjds().min() + cmwavex_setup(m1, Tspan, n_freqs=int(m.TNCHROMC.value), freeze_params=False) + + ftr = WLSFitter(t, m1) + ftr.fit_toas(maxiter=10) + m1 = ftr.model + + m2 = plchromnoise_from_cmwavex(m1) + + assert "PLChromNoise" in m2.components + assert ( + abs(m.TNCHROMAMP.value - m2.TNCHROMAMP.value) / m2.TNCHROMAMP.uncertainty_value + < 5 + ) + assert ( + abs(m.TNCHROMGAM.value - m2.TNCHROMGAM.value) / m2.TNCHROMGAM.uncertainty_value + < 5 + ) + + def test_find_optimal_nharms_wx(data_wx): m, t = data_wx From e5f67837b5344c8ec7585cb7fb4c9f564cc1ccd6 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Wed, 7 Aug 2024 15:47:19 +0200 Subject: [PATCH 14/31] CHANGELOG --- CHANGELOG-unreleased.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG-unreleased.md b/CHANGELOG-unreleased.md index 08e7ebd6f..d201b1d22 100644 --- a/CHANGELOG-unreleased.md +++ b/CHANGELOG-unreleased.md @@ -16,7 +16,7 @@ the released changes. - Doing `model.par = something` will try to assign to `par.quantity` or `par.value` but will give warning - `PLChromNoise` component to model chromatic red noise with a power law spectrum - Fourier series representation of chromatic noise (`CMWaveX`) -- `pint.utils.cmwavex_setup` function +- `pint.utils.cmwavex_setup` and `pint.utils.plchromnoise_from_cmwavex` functions - More validation for correlated noise components in `TimingModel` ### Fixed - Bug in `DMWaveX.get_indices()` function From e9e1771e3aa4a17368650bfe4b6486e712209551 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Thu, 8 Aug 2024 12:08:52 +0200 Subject: [PATCH 15/31] float --- src/pint/utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/pint/utils.py b/src/pint/utils.py index 0327e79c1..758307558 100644 --- a/src/pint/utils.py +++ b/src/pint/utils.py @@ -3203,11 +3203,13 @@ def mlnlike(params: Tuple[float, ...]) -> float: """Negative of the likelihood function that acts on the `[DM/CM]WaveX` amplitudes.""" sigma = powl_model(params) - return 0.5 * np.sum( - (a**2 / (sigma**2 + da**2)) - + (b**2 / (sigma**2 + db**2)) - + np.log(sigma**2 + da**2) - + np.log(sigma**2 + db**2) + return 0.5 * float( + np.sum( + (a**2 / (sigma**2 + da**2)) + + (b**2 / (sigma**2 + db**2)) + + np.log(sigma**2 + da**2) + + np.log(sigma**2 + db**2) + ) ) return mlnlike From 27c7c4eee0ce225166a167df90d69f9aac844983 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Thu, 8 Aug 2024 20:07:41 +0200 Subject: [PATCH 16/31] move functions --- src/pint/models/cmwavex.py | 98 ++++++- src/pint/models/dmwavex.py | 100 ++++++- src/pint/models/wavex.py | 364 +++++++++++++++++++++++- src/pint/utils.py | 552 +------------------------------------ tests/test_cmwavex.py | 2 +- tests/test_dmwavex.py | 2 +- tests/test_wavex.py | 8 +- tests/test_wx2pl.py | 6 +- 8 files changed, 567 insertions(+), 565 deletions(-) diff --git a/src/pint/models/cmwavex.py b/src/pint/models/cmwavex.py index 465700565..4a686036f 100644 --- a/src/pint/models/cmwavex.py +++ b/src/pint/models/cmwavex.py @@ -1,12 +1,14 @@ """Chromatic variations expressed as a sum of sinusoids.""" +from typing import Iterable, List, Optional, Union +import warnings import astropy.units as u import numpy as np from loguru import logger as log from warnings import warn from pint.models.parameter import MJDParameter, prefixParameter -from pint.models.timing_model import MissingParameter +from pint.models.timing_model import MissingParameter, TimingModel from pint.models.chromatic_model import Chromatic, cmu from pint import DMconst @@ -385,3 +387,97 @@ def d_cm_d_CMWXCOS(self, toas, param, acc_delay=None): arg = 2.0 * np.pi * freq * base_phase deriv = np.cos(arg.value) return deriv * cmu / par.units + + +def cmwavex_setup( + model: TimingModel, + T_span: Union[float, u.Quantity], + freqs: Optional[Iterable[Union[float, u.Quantity]]] = None, + n_freqs: Optional[int] = None, + freeze_params: bool = False, +) -> List[int]: + """ + Set-up a CMWaveX model based on either an array of user-provided frequencies or the wave number + frequency calculation. Sine and Cosine amplitudes are initially set to zero + + User specifies T_span and either freqs or n_freqs. This function assumes that the timing model does not already + have any CMWaveX components. See add_cmwavex_component() or add_cmwavex_components() to add components + to an existing CMWaveX model. + + Parameters + ---------- + + model : pint.models.timing_model.TimingModel + T_span : float, astropy.quantity.Quantity + Time span used to calculate nyquist frequency when using freqs + Time span used to calculate CMWaveX frequencies when using n_freqs + Usually to be set as the length of the timing baseline the model is being used for + freqs : iterable of float or astropy.quantity.Quantity, None + User inputed base frequencies + n_freqs : int, None + Number of wave frequencies to calculate using the equation: freq_n = 2 * pi * n / T_span + Where n is the wave number, and T_span is the total time span of the toas in the fitter object + freeze_params : bool, optional + Whether the new parameters should be frozen + + Returns + ------- + + indices : list + Indices that have been assigned to new WaveX components + """ + from pint.models.cmwavex import CMWaveX + + if (freqs is None) and (n_freqs is None): + raise ValueError( + "CMWaveX component base frequencies are not specified. " + "Please input either freqs or n_freqs" + ) + + if (freqs is not None) and (n_freqs is not None): + raise ValueError( + "Both freqs and n_freqs are specified. Only one or the other should be used" + ) + + if n_freqs is not None and n_freqs <= 0: + raise ValueError("Must use a non-zero number of wave frequencies") + + model.add_component(CMWaveX()) + if isinstance(T_span, u.quantity.Quantity): + T_span.to(u.d) + else: + T_span *= u.d + + nyqist_freq = 1.0 / (2.0 * T_span) + if freqs is not None: + if isinstance(freqs, u.quantity.Quantity): + freqs.to(u.d**-1) + else: + freqs *= u.d**-1 + if len(freqs) == 1: + model.CMWXFREQ_0001.quantity = freqs + else: + freqs = np.array(freqs) + freqs.sort() + if min(np.diff(freqs)) < nyqist_freq: + warnings.warn( + "CMWaveX frequency spacing is finer than frequency resolution of data" + ) + model.CMWXFREQ_0001.quantity = freqs[0] + model.components["CMWaveX"].add_cmwavex_components(freqs[1:]) + + if n_freqs is not None: + if n_freqs == 1: + wave_freq = 1 / T_span + model.CMWXFREQ_0001.quantity = wave_freq + else: + wave_numbers = np.arange(1, n_freqs + 1) + wave_freqs = wave_numbers / T_span + model.CMWXFREQ_0001.quantity = wave_freqs[0] + model.components["CMWaveX"].add_cmwavex_components(wave_freqs[1:]) + + for p in model.params: + if p.startswith("CMWXSIN") or p.startswith("CMWXCOS"): + model[p].frozen = freeze_params + + return model.components["CMWaveX"].get_indices() diff --git a/src/pint/models/dmwavex.py b/src/pint/models/dmwavex.py index 368862391..fc758680e 100644 --- a/src/pint/models/dmwavex.py +++ b/src/pint/models/dmwavex.py @@ -1,12 +1,14 @@ """DM variations expressed as a sum of sinusoids.""" +from typing import Iterable, List, Optional, Union +import warnings import astropy.units as u import numpy as np from loguru import logger as log from warnings import warn from pint.models.parameter import MJDParameter, prefixParameter -from pint.models.timing_model import MissingParameter +from pint.models.timing_model import MissingParameter, TimingModel from pint.models.dispersion_model import Dispersion from pint import DMconst, dmu @@ -22,7 +24,7 @@ class DMWaveX(Dispersion): .. paramtable:: :class: pint.models.dmwavex.DMWaveX - To set up a DMWaveX model, users can use the `pint.utils` function `dmwavex_setup()` with either a list of frequencies or a choice + To set up a DMWaveX model, users can use the function `dmwavex_setup()` with either a list of frequencies or a choice of harmonics of a base frequency determined by 2 * pi /Timespan """ @@ -385,3 +387,97 @@ def d_dm_d_DMWXCOS(self, toas, param, acc_delay=None): arg = 2.0 * np.pi * freq * base_phase deriv = np.cos(arg.value) return deriv * dmu / par.units + + +def dmwavex_setup( + model: TimingModel, + T_span: Union[float, u.Quantity], + freqs: Optional[Iterable[Union[float, u.Quantity]]] = None, + n_freqs: Optional[int] = None, + freeze_params: bool = False, +) -> List[int]: + """ + Set-up a DMWaveX model based on either an array of user-provided frequencies or the wave number + frequency calculation. Sine and Cosine amplitudes are initially set to zero + + User specifies T_span and either freqs or n_freqs. This function assumes that the timing model does not already + have any DMWaveX components. See add_dmwavex_component() or add_dmwavex_components() to add components + to an existing DMWaveX model. + + Parameters + ---------- + + model : pint.models.timing_model.TimingModel + T_span : float, astropy.quantity.Quantity + Time span used to calculate nyquist frequency when using freqs + Time span used to calculate DMWaveX frequencies when using n_freqs + Usually to be set as the length of the timing baseline the model is being used for + freqs : iterable of float or astropy.quantity.Quantity, None + User inputed base frequencies + n_freqs : int, None + Number of wave frequencies to calculate using the equation: freq_n = 2 * pi * n / T_span + Where n is the wave number, and T_span is the total time span of the toas in the fitter object + freeze_params : bool, optional + Whether the new parameters should be frozen + + Returns + ------- + + indices : list + Indices that have been assigned to new WaveX components + """ + from pint.models.dmwavex import DMWaveX + + if (freqs is None) and (n_freqs is None): + raise ValueError( + "DMWaveX component base frequencies are not specified. " + "Please input either freqs or n_freqs" + ) + + if (freqs is not None) and (n_freqs is not None): + raise ValueError( + "Both freqs and n_freqs are specified. Only one or the other should be used" + ) + + if n_freqs is not None and n_freqs <= 0: + raise ValueError("Must use a non-zero number of wave frequencies") + + model.add_component(DMWaveX()) + if isinstance(T_span, u.quantity.Quantity): + T_span.to(u.d) + else: + T_span *= u.d + + nyqist_freq = 1.0 / (2.0 * T_span) + if freqs is not None: + if isinstance(freqs, u.quantity.Quantity): + freqs.to(u.d**-1) + else: + freqs *= u.d**-1 + if len(freqs) == 1: + model.DMWXFREQ_0001.quantity = freqs + else: + freqs = np.array(freqs) + freqs.sort() + if min(np.diff(freqs)) < nyqist_freq: + warnings.warn( + "DMWaveX frequency spacing is finer than frequency resolution of data" + ) + model.DMWXFREQ_0001.quantity = freqs[0] + model.components["DMWaveX"].add_dmwavex_components(freqs[1:]) + + if n_freqs is not None: + if n_freqs == 1: + wave_freq = 1 / T_span + model.DMWXFREQ_0001.quantity = wave_freq + else: + wave_numbers = np.arange(1, n_freqs + 1) + wave_freqs = wave_numbers / T_span + model.DMWXFREQ_0001.quantity = wave_freqs[0] + model.components["DMWaveX"].add_dmwavex_components(wave_freqs[1:]) + + for p in model.params: + if p.startswith("DMWXSIN") or p.startswith("DMWXCOS"): + model[p].frozen = freeze_params + + return model.components["DMWaveX"].get_indices() diff --git a/src/pint/models/wavex.py b/src/pint/models/wavex.py index 1714f3644..f3b47a717 100644 --- a/src/pint/models/wavex.py +++ b/src/pint/models/wavex.py @@ -1,12 +1,15 @@ """Delays expressed as a sum of sinusoids.""" +from copy import deepcopy +from typing import Iterable, List, Optional, Union +import warnings import astropy.units as u import numpy as np from loguru import logger as log from warnings import warn from pint.models.parameter import MJDParameter, prefixParameter -from pint.models.timing_model import DelayComponent, MissingParameter +from pint.models.timing_model import DelayComponent, MissingParameter, TimingModel class WaveX(DelayComponent): @@ -39,7 +42,7 @@ class WaveX(DelayComponent): WARNING: If the choice of WaveX frequencies in a `TimingModel` doesn't correspond to harmonics of some base freqeuncy, it will not be possible to convert it to a Wave model. - To set up a WaveX model, users can use the `pint.utils` function `wavex_setup()` with either a list of frequencies or a choice + To set up a WaveX model, users can use the function `wavex_setup()` with either a list of frequencies or a choice of harmonics of a base frequency determined by 2 * pi /Timespan """ @@ -394,3 +397,360 @@ def d_wavex_delay_d_WXCOS(self, toas, param, delays, acc_delay=None): arg = 2.0 * np.pi * freq * base_phase deriv = np.cos(arg.value) return deriv * u.s / par.units + + +def wavex_setup( + model: TimingModel, + T_span: Union[float, u.Quantity], + freqs: Optional[Iterable[Union[float, u.Quantity]]] = None, + n_freqs: Optional[int] = None, + freeze_params: bool = False, +) -> List[int]: + """ + Set-up a WaveX model based on either an array of user-provided frequencies or the wave number + frequency calculation. Sine and Cosine amplitudes are initially set to zero + + User specifies T_span and either freqs or n_freqs. This function assumes that the timing model does not already + have any WaveX components. See add_wavex_component() or add_wavex_components() to add WaveX components + to an existing WaveX model. + + Parameters + ---------- + + model : pint.models.timing_model.TimingModel + T_span : float, astropy.quantity.Quantity + Time span used to calculate nyquist frequency when using freqs + Time span used to calculate WaveX frequencies when using n_freqs + Usually to be set as the length of the timing baseline the model is being used for + freqs : iterable of float or astropy.quantity.Quantity, None + User inputed base frequencies + n_freqs : int, None + Number of wave frequencies to calculate using the equation: freq_n = 2 * pi * n / T_span + Where n is the wave number, and T_span is the total time span of the toas in the fitter object + freeze_params : bool, optional + Whether the new parameters should be frozen + + Returns + ------- + + indices : list + Indices that have been assigned to new WaveX components + """ + from pint.models.wavex import WaveX + + if (freqs is None) and (n_freqs is None): + raise ValueError( + "WaveX component base frequencies are not specified. " + "Please input either freqs or n_freqs" + ) + + if (freqs is not None) and (n_freqs is not None): + raise ValueError( + "Both freqs and n_freqs are specified. Only one or the other should be used" + ) + + if n_freqs is not None and n_freqs <= 0: + raise ValueError("Must use a non-zero number of wave frequencies") + + model.add_component(WaveX()) + if isinstance(T_span, u.quantity.Quantity): + T_span.to(u.d) + else: + T_span *= u.d + + nyqist_freq = 1.0 / (2.0 * T_span) + if freqs is not None: + if isinstance(freqs, u.quantity.Quantity): + freqs.to(u.d**-1) + else: + freqs *= u.d**-1 + if len(freqs) == 1: + model.WXFREQ_0001.quantity = freqs + else: + freqs = np.array(freqs) + freqs.sort() + if min(np.diff(freqs)) < nyqist_freq: + warnings.warn( + "Wave frequency spacing is finer than frequency resolution of data" + ) + model.WXFREQ_0001.quantity = freqs[0] + model.components["WaveX"].add_wavex_components(freqs[1:]) + + if n_freqs is not None: + if n_freqs == 1: + wave_freq = 1 / T_span + model.WXFREQ_0001.quantity = wave_freq + else: + wave_numbers = np.arange(1, n_freqs + 1) + wave_freqs = wave_numbers / T_span + model.WXFREQ_0001.quantity = wave_freqs[0] + model.components["WaveX"].add_wavex_components(wave_freqs[1:]) + + for p in model.params: + if p.startswith("WXSIN") or p.startswith("WXCOS"): + model[p].frozen = freeze_params + + return model.components["WaveX"].get_indices() + + +def get_wavex_freqs( + model: TimingModel, + index: Optional[Union[float, int, List, np.ndarray]] = None, + quantity: bool = False, +) -> List[Union[float, u.Quantity]]: + """ + Return the WaveX frequencies for a timing model. + + If index is specified, returns the frequencies corresponding to the user-provided indices. + If index isn't specified, returns all WaveX frequencies in timing model + + Parameters + ---------- + model : pint.models.timing_model.TimingModel + Timing model from which to return WaveX frequencies + index : float, int, list, np.ndarray, None + Number or list/array of numbers corresponding to WaveX frequencies to return + quantity : bool + If set to True, returns a list of astropy.quanitity.Quantity rather than a list of prefixParameters + + Returns + ------- + List of WXFREQ_ parameters + """ + if index is None: + freqs = model.components["WaveX"].get_prefix_mapping_component("WXFREQ_") + if len(freqs) == 1: + values = getattr(model.components["WaveX"], freqs.values()) + else: + values = [ + getattr(model.components["WaveX"], param) for param in freqs.values() + ] + elif isinstance(index, (int, float, np.int64)): + idx_rf = f"{int(index):04d}" + values = getattr(model.components["WaveX"], f"WXFREQ_{idx_rf}") + elif isinstance(index, (list, set, np.ndarray)): + idx_rf = [f"{int(idx):04d}" for idx in index] + values = [getattr(model.components["WaveX"], f"WXFREQ_{ind}") for ind in idx_rf] + else: + raise TypeError( + f"index most be a float, int, set, list, array, or None - not {type(index)}" + ) + if quantity: + if len(values) == 1: + values = [values[0].quantity] + else: + values = [v.quantity for v in values] + return values + + +def get_wavex_amps( + model: TimingModel, + index: Optional[Union[float, int, List, np.ndarray]] = None, + quantity: bool = False, +) -> List[Union[float, u.Quantity]]: + """ + Return the WaveX amplitudes for a timing model. + + If index is specified, returns the sine/cosine amplitudes corresponding to the user-provided indices. + If index isn't specified, returns all WaveX sine/cosine amplitudes in timing model + + Parameters + ---------- + model : pint.models.timing_model.TimingModel + Timing model from which to return WaveX frequencies + index : float, int, list, np.ndarray, None + Number or list/array of numbers corresponding to WaveX amplitudes to return + quantity : bool + If set to True, returns a list of tuples of astropy.quanitity.Quantity rather than a list of prefixParameters tuples + + Returns + ------- + List of WXSIN_ and WXCOS_ parameters + """ + if index is None: + indices = ( + model.components["WaveX"].get_prefix_mapping_component("WXSIN_").keys() + ) + if len(indices) == 1: + values = getattr( + model.components["WaveX"], f"WXSIN_{int(indices):04d}" + ), getattr(model.components["WaveX"], f"WXCOS_{int(indices):04d}") + else: + values = [ + ( + getattr(model.components["WaveX"], f"WXSIN_{int(idx):04d}"), + getattr(model.components["WaveX"], f"WXCOS_{int(idx):04d}"), + ) + for idx in indices + ] + elif isinstance(index, (int, float, np.int64)): + idx_rf = f"{int(index):04d}" + values = getattr(model.components["WaveX"], f"WXSIN_{idx_rf}"), getattr( + model.components["WaveX"], f"WXCOS_{idx_rf}" + ) + elif isinstance(index, (list, set, np.ndarray)): + idx_rf = [f"{int(idx):04d}" for idx in index] + values = [ + ( + getattr(model.components["WaveX"], f"WXSIN_{ind}"), + getattr(model.components["WaveX"], f"WXCOS_{ind}"), + ) + for ind in idx_rf + ] + else: + raise TypeError( + f"index most be a float, int, set, list, array, or None - not {type(index)}" + ) + if quantity: + if isinstance(values, tuple): + values = tuple(v.quantity for v in values) + if isinstance(values, list): + values = [(v[0].quantity, v[1].quantity) for v in values] + return values + + +def translate_wave_to_wavex( + model: TimingModel, +) -> TimingModel: + """ + Go from a Wave model to a WaveX model + + WaveX frequencies get calculated based on the Wave model WAVEOM parameter and the number of WAVE parameters. + WXFREQ_000k = [WAVEOM * (k+1)] / [2 * pi] + + WaveX amplitudes are taken from the WAVE pair parameters + + Paramters + --------- + model : pint.models.timing_model.TimingModel + TimingModel containing a Wave model to be converted to a WaveX model + + Returns + ------- + pint.models.timing_model.TimingModel + New timing model with converted WaveX model included + """ + from pint.models.wavex import WaveX + + new_model = deepcopy(model) + wave_names = [ + f"WAVE{ii}" for ii in range(1, model.components["Wave"].num_wave_terms + 1) + ] + wave_terms = [getattr(model.components["Wave"], name) for name in wave_names] + wave_om = model.components["Wave"].WAVE_OM.quantity + wave_epoch = model.components["Wave"].WAVEEPOCH.quantity + new_model.remove_component("Wave") + new_model.add_component(WaveX()) + new_model.WXEPOCH.value = wave_epoch.value + for k, wave_term in enumerate(wave_terms): + wave_sin_amp, wave_cos_amp = wave_term.quantity + wavex_freq = _translate_wave_freqs(wave_om, k) + if k == 0: + new_model.WXFREQ_0001.value = wavex_freq.value + new_model.WXSIN_0001.value = -wave_sin_amp.value + new_model.WXCOS_0001.value = -wave_cos_amp.value + else: + new_model.components["WaveX"].add_wavex_component( + wavex_freq, wxsin=-wave_sin_amp, wxcos=-wave_cos_amp + ) + return new_model + + +def translate_wavex_to_wave( + model: TimingModel, +) -> TimingModel: + """ + Go from a WaveX timing model to a Wave timing model. + WARNING: Not every WaveX model can be appropriately translated into a Wave model. This is dependent on the user's choice of frequencies in the WaveX model. + In order for a WaveX model to be able to be converted into a Wave model, every WaveX frequency must produce the same value of WAVEOM in the calculation: + + WAVEOM = [2 * pi * WXFREQ_000k] / (k + 1) + Paramters + --------- + model : pint.models.timing_model.TimingModel + TimingModel containing a WaveX model to be converted to a Wave model + + Returns + ------- + pint.models.timing_model.TimingModel + New timing model with converted Wave model included + """ + from pint.models.wave import Wave + + new_model = deepcopy(model) + indices = model.components["WaveX"].get_indices() + wxfreqs = get_wavex_freqs(model, indices, quantity=True) + wave_om = _translate_wavex_freqs(wxfreqs, (indices - 1)) + if wave_om == False: + raise ValueError( + "This WaveX model cannot be properly translated into a Wave model due to the WaveX frequencies not producing a consistent WAVEOM value" + ) + wave_amps = get_wavex_amps(model, index=indices, quantity=True) + new_model.remove_component("WaveX") + new_model.add_component(Wave()) + new_model.WAVEEPOCH.quantity = model.WXEPOCH.quantity + new_model.WAVE_OM.quantity = wave_om + new_model.WAVE1.quantity = tuple(w * -1.0 for w in wave_amps[0]) + if len(indices) > 1: + for i in range(1, len(indices)): + print(wave_amps[i]) + wave_amps[i] = tuple(w * -1.0 for w in wave_amps[i]) + new_model.components["Wave"].add_wave_component( + wave_amps[i], index=indices[i] + ) + return new_model + + +def _translate_wavex_freqs(wxfreq: Union[float, u.Quantity], k: int) -> u.Quantity: + """ + Use WaveX model WXFREQ_ parameters and wave number k to calculate the Wave model WAVEOM frequency parameter. + + Parameters + ---------- + wxfreq : float or astropy.quantity.Quantity + WaveX frequency from which the WAVEOM parameter will be calculated + If float is given default units of 1/d assigned + k : int + wave number to use to calculate Wave WAVEOM parameter + + Returns + ------- + astropy.units.Quantity + WAVEOM quantity in units 1/d that can be used in Wave model + """ + if isinstance(wxfreq, u.quantity.Quantity): + wxfreq.to(u.d**-1) + else: + wxfreq *= u.d**-1 + if len(wxfreq) == 1: + return (2.0 * np.pi * wxfreq) / (k + 1.0) + wave_om = [((2.0 * np.pi * wxfreq[i]) / (k[i] + 1.0)) for i in range(len(wxfreq))] + return ( + sum(wave_om) / len(wave_om) + if np.allclose(wave_om, wave_om[0], atol=1e-3) + else False + ) + + +def _translate_wave_freqs(om: Union[float, u.Quantity], k: int) -> u.Quantity: + """ + Use Wave model WAVEOM parameter to calculate a WaveX WXFREQ_ frequency parameter for wave number k + + Parameters + ---------- + om : float or astropy.quantity.Quantity + Base frequency of Wave model solution - parameter WAVEOM + If float is given default units of 1/d assigned + k : int + wave number to use to calculate WaveX WXFREQ_ frequency parameter + + Returns + ------- + astropy.units.Quantity + WXFREQ_ quantity in units 1/d that can be used in WaveX model + """ + if isinstance(om, u.quantity.Quantity): + om.to(u.d**-1) + else: + om *= u.d**-1 + return (om * (k + 1)) / (2.0 * np.pi) diff --git a/src/pint/utils.py b/src/pint/utils.py index 758307558..883ee3dc0 100644 --- a/src/pint/utils.py +++ b/src/pint/utils.py @@ -108,11 +108,6 @@ "merge_dmx", "split_dmx", "split_swx", - "wavex_setup", - "translate_wave_to_wavex", - "get_wavex_freqs", - "get_wavex_amps", - "translate_wavex_to_wave", "weighted_mean", "ELL1_check", "FTest", @@ -1454,551 +1449,6 @@ def split_swx(model: "pint.models.TimingModel", time: Time) -> Tuple[int, int]: return index, newindex -def wavex_setup( - model: "pint.models.TimingModel", - T_span: Union[float, u.Quantity], - freqs: Optional[Iterable[Union[float, u.Quantity]]] = None, - n_freqs: Optional[int] = None, - freeze_params: bool = False, -) -> List[int]: - """ - Set-up a WaveX model based on either an array of user-provided frequencies or the wave number - frequency calculation. Sine and Cosine amplitudes are initially set to zero - - User specifies T_span and either freqs or n_freqs. This function assumes that the timing model does not already - have any WaveX components. See add_wavex_component() or add_wavex_components() to add WaveX components - to an existing WaveX model. - - Parameters - ---------- - - model : pint.models.timing_model.TimingModel - T_span : float, astropy.quantity.Quantity - Time span used to calculate nyquist frequency when using freqs - Time span used to calculate WaveX frequencies when using n_freqs - Usually to be set as the length of the timing baseline the model is being used for - freqs : iterable of float or astropy.quantity.Quantity, None - User inputed base frequencies - n_freqs : int, None - Number of wave frequencies to calculate using the equation: freq_n = 2 * pi * n / T_span - Where n is the wave number, and T_span is the total time span of the toas in the fitter object - freeze_params : bool, optional - Whether the new parameters should be frozen - - Returns - ------- - - indices : list - Indices that have been assigned to new WaveX components - """ - from pint.models.wavex import WaveX - - if (freqs is None) and (n_freqs is None): - raise ValueError( - "WaveX component base frequencies are not specified. " - "Please input either freqs or n_freqs" - ) - - if (freqs is not None) and (n_freqs is not None): - raise ValueError( - "Both freqs and n_freqs are specified. Only one or the other should be used" - ) - - if n_freqs is not None and n_freqs <= 0: - raise ValueError("Must use a non-zero number of wave frequencies") - - model.add_component(WaveX()) - if isinstance(T_span, u.quantity.Quantity): - T_span.to(u.d) - else: - T_span *= u.d - - nyqist_freq = 1.0 / (2.0 * T_span) - if freqs is not None: - if isinstance(freqs, u.quantity.Quantity): - freqs.to(u.d**-1) - else: - freqs *= u.d**-1 - if len(freqs) == 1: - model.WXFREQ_0001.quantity = freqs - else: - freqs = np.array(freqs) - freqs.sort() - if min(np.diff(freqs)) < nyqist_freq: - warnings.warn( - "Wave frequency spacing is finer than frequency resolution of data" - ) - model.WXFREQ_0001.quantity = freqs[0] - model.components["WaveX"].add_wavex_components(freqs[1:]) - - if n_freqs is not None: - if n_freqs == 1: - wave_freq = 1 / T_span - model.WXFREQ_0001.quantity = wave_freq - else: - wave_numbers = np.arange(1, n_freqs + 1) - wave_freqs = wave_numbers / T_span - model.WXFREQ_0001.quantity = wave_freqs[0] - model.components["WaveX"].add_wavex_components(wave_freqs[1:]) - - for p in model.params: - if p.startswith("WXSIN") or p.startswith("WXCOS"): - model[p].frozen = freeze_params - - return model.components["WaveX"].get_indices() - - -def dmwavex_setup( - model: "pint.models.TimingModel", - T_span: Union[float, u.Quantity], - freqs: Optional[Iterable[Union[float, u.Quantity]]] = None, - n_freqs: Optional[int] = None, - freeze_params: bool = False, -) -> List[int]: - """ - Set-up a DMWaveX model based on either an array of user-provided frequencies or the wave number - frequency calculation. Sine and Cosine amplitudes are initially set to zero - - User specifies T_span and either freqs or n_freqs. This function assumes that the timing model does not already - have any DMWaveX components. See add_dmwavex_component() or add_dmwavex_components() to add components - to an existing DMWaveX model. - - Parameters - ---------- - - model : pint.models.timing_model.TimingModel - T_span : float, astropy.quantity.Quantity - Time span used to calculate nyquist frequency when using freqs - Time span used to calculate DMWaveX frequencies when using n_freqs - Usually to be set as the length of the timing baseline the model is being used for - freqs : iterable of float or astropy.quantity.Quantity, None - User inputed base frequencies - n_freqs : int, None - Number of wave frequencies to calculate using the equation: freq_n = 2 * pi * n / T_span - Where n is the wave number, and T_span is the total time span of the toas in the fitter object - freeze_params : bool, optional - Whether the new parameters should be frozen - - Returns - ------- - - indices : list - Indices that have been assigned to new WaveX components - """ - from pint.models.dmwavex import DMWaveX - - if (freqs is None) and (n_freqs is None): - raise ValueError( - "DMWaveX component base frequencies are not specified. " - "Please input either freqs or n_freqs" - ) - - if (freqs is not None) and (n_freqs is not None): - raise ValueError( - "Both freqs and n_freqs are specified. Only one or the other should be used" - ) - - if n_freqs is not None and n_freqs <= 0: - raise ValueError("Must use a non-zero number of wave frequencies") - - model.add_component(DMWaveX()) - if isinstance(T_span, u.quantity.Quantity): - T_span.to(u.d) - else: - T_span *= u.d - - nyqist_freq = 1.0 / (2.0 * T_span) - if freqs is not None: - if isinstance(freqs, u.quantity.Quantity): - freqs.to(u.d**-1) - else: - freqs *= u.d**-1 - if len(freqs) == 1: - model.DMWXFREQ_0001.quantity = freqs - else: - freqs = np.array(freqs) - freqs.sort() - if min(np.diff(freqs)) < nyqist_freq: - warnings.warn( - "DMWaveX frequency spacing is finer than frequency resolution of data" - ) - model.DMWXFREQ_0001.quantity = freqs[0] - model.components["DMWaveX"].add_dmwavex_components(freqs[1:]) - - if n_freqs is not None: - if n_freqs == 1: - wave_freq = 1 / T_span - model.DMWXFREQ_0001.quantity = wave_freq - else: - wave_numbers = np.arange(1, n_freqs + 1) - wave_freqs = wave_numbers / T_span - model.DMWXFREQ_0001.quantity = wave_freqs[0] - model.components["DMWaveX"].add_dmwavex_components(wave_freqs[1:]) - - for p in model.params: - if p.startswith("DMWXSIN") or p.startswith("DMWXCOS"): - model[p].frozen = freeze_params - - return model.components["DMWaveX"].get_indices() - - -def cmwavex_setup( - model: "pint.models.TimingModel", - T_span: Union[float, u.Quantity], - freqs: Optional[Iterable[Union[float, u.Quantity]]] = None, - n_freqs: Optional[int] = None, - freeze_params: bool = False, -) -> List[int]: - """ - Set-up a CMWaveX model based on either an array of user-provided frequencies or the wave number - frequency calculation. Sine and Cosine amplitudes are initially set to zero - - User specifies T_span and either freqs or n_freqs. This function assumes that the timing model does not already - have any CMWaveX components. See add_cmwavex_component() or add_cmwavex_components() to add components - to an existing CMWaveX model. - - Parameters - ---------- - - model : pint.models.timing_model.TimingModel - T_span : float, astropy.quantity.Quantity - Time span used to calculate nyquist frequency when using freqs - Time span used to calculate CMWaveX frequencies when using n_freqs - Usually to be set as the length of the timing baseline the model is being used for - freqs : iterable of float or astropy.quantity.Quantity, None - User inputed base frequencies - n_freqs : int, None - Number of wave frequencies to calculate using the equation: freq_n = 2 * pi * n / T_span - Where n is the wave number, and T_span is the total time span of the toas in the fitter object - freeze_params : bool, optional - Whether the new parameters should be frozen - - Returns - ------- - - indices : list - Indices that have been assigned to new WaveX components - """ - from pint.models.cmwavex import CMWaveX - - if (freqs is None) and (n_freqs is None): - raise ValueError( - "CMWaveX component base frequencies are not specified. " - "Please input either freqs or n_freqs" - ) - - if (freqs is not None) and (n_freqs is not None): - raise ValueError( - "Both freqs and n_freqs are specified. Only one or the other should be used" - ) - - if n_freqs is not None and n_freqs <= 0: - raise ValueError("Must use a non-zero number of wave frequencies") - - model.add_component(CMWaveX()) - if isinstance(T_span, u.quantity.Quantity): - T_span.to(u.d) - else: - T_span *= u.d - - nyqist_freq = 1.0 / (2.0 * T_span) - if freqs is not None: - if isinstance(freqs, u.quantity.Quantity): - freqs.to(u.d**-1) - else: - freqs *= u.d**-1 - if len(freqs) == 1: - model.CMWXFREQ_0001.quantity = freqs - else: - freqs = np.array(freqs) - freqs.sort() - if min(np.diff(freqs)) < nyqist_freq: - warnings.warn( - "CMWaveX frequency spacing is finer than frequency resolution of data" - ) - model.CMWXFREQ_0001.quantity = freqs[0] - model.components["CMWaveX"].add_cmwavex_components(freqs[1:]) - - if n_freqs is not None: - if n_freqs == 1: - wave_freq = 1 / T_span - model.CMWXFREQ_0001.quantity = wave_freq - else: - wave_numbers = np.arange(1, n_freqs + 1) - wave_freqs = wave_numbers / T_span - model.CMWXFREQ_0001.quantity = wave_freqs[0] - model.components["CMWaveX"].add_cmwavex_components(wave_freqs[1:]) - - for p in model.params: - if p.startswith("CMWXSIN") or p.startswith("CMWXCOS"): - model[p].frozen = freeze_params - - return model.components["CMWaveX"].get_indices() - - -def _translate_wave_freqs(om: Union[float, u.Quantity], k: int) -> u.Quantity: - """ - Use Wave model WAVEOM parameter to calculate a WaveX WXFREQ_ frequency parameter for wave number k - - Parameters - ---------- - om : float or astropy.quantity.Quantity - Base frequency of Wave model solution - parameter WAVEOM - If float is given default units of 1/d assigned - k : int - wave number to use to calculate WaveX WXFREQ_ frequency parameter - - Returns - ------- - astropy.units.Quantity - WXFREQ_ quantity in units 1/d that can be used in WaveX model - """ - if isinstance(om, u.quantity.Quantity): - om.to(u.d**-1) - else: - om *= u.d**-1 - return (om * (k + 1)) / (2.0 * np.pi) - - -def _translate_wavex_freqs(wxfreq: Union[float, u.Quantity], k: int) -> u.Quantity: - """ - Use WaveX model WXFREQ_ parameters and wave number k to calculate the Wave model WAVEOM frequency parameter. - - Parameters - ---------- - wxfreq : float or astropy.quantity.Quantity - WaveX frequency from which the WAVEOM parameter will be calculated - If float is given default units of 1/d assigned - k : int - wave number to use to calculate Wave WAVEOM parameter - - Returns - ------- - astropy.units.Quantity - WAVEOM quantity in units 1/d that can be used in Wave model - """ - if isinstance(wxfreq, u.quantity.Quantity): - wxfreq.to(u.d**-1) - else: - wxfreq *= u.d**-1 - if len(wxfreq) == 1: - return (2.0 * np.pi * wxfreq) / (k + 1.0) - wave_om = [((2.0 * np.pi * wxfreq[i]) / (k[i] + 1.0)) for i in range(len(wxfreq))] - return ( - sum(wave_om) / len(wave_om) - if np.allclose(wave_om, wave_om[0], atol=1e-3) - else False - ) - - -def translate_wave_to_wavex( - model: "pint.models.TimingModel", -) -> "pint.models.TimingModel": - """ - Go from a Wave model to a WaveX model - - WaveX frequencies get calculated based on the Wave model WAVEOM parameter and the number of WAVE parameters. - WXFREQ_000k = [WAVEOM * (k+1)] / [2 * pi] - - WaveX amplitudes are taken from the WAVE pair parameters - - Paramters - --------- - model : pint.models.timing_model.TimingModel - TimingModel containing a Wave model to be converted to a WaveX model - - Returns - ------- - pint.models.timing_model.TimingModel - New timing model with converted WaveX model included - """ - from pint.models.wavex import WaveX - - new_model = deepcopy(model) - wave_names = [ - f"WAVE{ii}" for ii in range(1, model.components["Wave"].num_wave_terms + 1) - ] - wave_terms = [getattr(model.components["Wave"], name) for name in wave_names] - wave_om = model.components["Wave"].WAVE_OM.quantity - wave_epoch = model.components["Wave"].WAVEEPOCH.quantity - new_model.remove_component("Wave") - new_model.add_component(WaveX()) - new_model.WXEPOCH.value = wave_epoch.value - for k, wave_term in enumerate(wave_terms): - wave_sin_amp, wave_cos_amp = wave_term.quantity - wavex_freq = _translate_wave_freqs(wave_om, k) - if k == 0: - new_model.WXFREQ_0001.value = wavex_freq.value - new_model.WXSIN_0001.value = -wave_sin_amp.value - new_model.WXCOS_0001.value = -wave_cos_amp.value - else: - new_model.components["WaveX"].add_wavex_component( - wavex_freq, wxsin=-wave_sin_amp, wxcos=-wave_cos_amp - ) - return new_model - - -def get_wavex_freqs( - model: "pint.models.TimingModel", - index: Optional[Union[float, int, List, np.ndarray]] = None, - quantity: bool = False, -) -> List[Union[float, u.Quantity]]: - """ - Return the WaveX frequencies for a timing model. - - If index is specified, returns the frequencies corresponding to the user-provided indices. - If index isn't specified, returns all WaveX frequencies in timing model - - Parameters - ---------- - model : pint.models.timing_model.TimingModel - Timing model from which to return WaveX frequencies - index : float, int, list, np.ndarray, None - Number or list/array of numbers corresponding to WaveX frequencies to return - quantity : bool - If set to True, returns a list of astropy.quanitity.Quantity rather than a list of prefixParameters - - Returns - ------- - List of WXFREQ_ parameters - """ - if index is None: - freqs = model.components["WaveX"].get_prefix_mapping_component("WXFREQ_") - if len(freqs) == 1: - values = getattr(model.components["WaveX"], freqs.values()) - else: - values = [ - getattr(model.components["WaveX"], param) for param in freqs.values() - ] - elif isinstance(index, (int, float, np.int64)): - idx_rf = f"{int(index):04d}" - values = getattr(model.components["WaveX"], f"WXFREQ_{idx_rf}") - elif isinstance(index, (list, set, np.ndarray)): - idx_rf = [f"{int(idx):04d}" for idx in index] - values = [getattr(model.components["WaveX"], f"WXFREQ_{ind}") for ind in idx_rf] - else: - raise TypeError( - f"index most be a float, int, set, list, array, or None - not {type(index)}" - ) - if quantity: - if len(values) == 1: - values = [values[0].quantity] - else: - values = [v.quantity for v in values] - return values - - -def get_wavex_amps( - model: "pint.models.TimingModel", - index: Optional[Union[float, int, List, np.ndarray]] = None, - quantity: bool = False, -) -> List[Union[float, u.Quantity]]: - """ - Return the WaveX amplitudes for a timing model. - - If index is specified, returns the sine/cosine amplitudes corresponding to the user-provided indices. - If index isn't specified, returns all WaveX sine/cosine amplitudes in timing model - - Parameters - ---------- - model : pint.models.timing_model.TimingModel - Timing model from which to return WaveX frequencies - index : float, int, list, np.ndarray, None - Number or list/array of numbers corresponding to WaveX amplitudes to return - quantity : bool - If set to True, returns a list of tuples of astropy.quanitity.Quantity rather than a list of prefixParameters tuples - - Returns - ------- - List of WXSIN_ and WXCOS_ parameters - """ - if index is None: - indices = ( - model.components["WaveX"].get_prefix_mapping_component("WXSIN_").keys() - ) - if len(indices) == 1: - values = getattr( - model.components["WaveX"], f"WXSIN_{int(indices):04d}" - ), getattr(model.components["WaveX"], f"WXCOS_{int(indices):04d}") - else: - values = [ - ( - getattr(model.components["WaveX"], f"WXSIN_{int(idx):04d}"), - getattr(model.components["WaveX"], f"WXCOS_{int(idx):04d}"), - ) - for idx in indices - ] - elif isinstance(index, (int, float, np.int64)): - idx_rf = f"{int(index):04d}" - values = getattr(model.components["WaveX"], f"WXSIN_{idx_rf}"), getattr( - model.components["WaveX"], f"WXCOS_{idx_rf}" - ) - elif isinstance(index, (list, set, np.ndarray)): - idx_rf = [f"{int(idx):04d}" for idx in index] - values = [ - ( - getattr(model.components["WaveX"], f"WXSIN_{ind}"), - getattr(model.components["WaveX"], f"WXCOS_{ind}"), - ) - for ind in idx_rf - ] - else: - raise TypeError( - f"index most be a float, int, set, list, array, or None - not {type(index)}" - ) - if quantity: - if isinstance(values, tuple): - values = tuple(v.quantity for v in values) - if isinstance(values, list): - values = [(v[0].quantity, v[1].quantity) for v in values] - return values - - -def translate_wavex_to_wave( - model: "pint.models.TimingModel", -) -> "pint.models.TimingModel": - """ - Go from a WaveX timing model to a Wave timing model. - WARNING: Not every WaveX model can be appropriately translated into a Wave model. This is dependent on the user's choice of frequencies in the WaveX model. - In order for a WaveX model to be able to be converted into a Wave model, every WaveX frequency must produce the same value of WAVEOM in the calculation: - - WAVEOM = [2 * pi * WXFREQ_000k] / (k + 1) - Paramters - --------- - model : pint.models.timing_model.TimingModel - TimingModel containing a WaveX model to be converted to a Wave model - - Returns - ------- - pint.models.timing_model.TimingModel - New timing model with converted Wave model included - """ - from pint.models.wave import Wave - - new_model = deepcopy(model) - indices = model.components["WaveX"].get_indices() - wxfreqs = get_wavex_freqs(model, indices, quantity=True) - wave_om = _translate_wavex_freqs(wxfreqs, (indices - 1)) - if wave_om == False: - raise ValueError( - "This WaveX model cannot be properly translated into a Wave model due to the WaveX frequencies not producing a consistent WAVEOM value" - ) - wave_amps = get_wavex_amps(model, index=indices, quantity=True) - new_model.remove_component("WaveX") - new_model.add_component(Wave()) - new_model.WAVEEPOCH.quantity = model.WXEPOCH.quantity - new_model.WAVE_OM.quantity = wave_om - new_model.WAVE1.quantity = tuple(w * -1.0 for w in wave_amps[0]) - if len(indices) > 1: - for i in range(1, len(indices)): - print(wave_amps[i]) - wave_amps[i] = tuple(w * -1.0 for w in wave_amps[i]) - new_model.components["Wave"].add_wave_component( - wave_amps[i], index=indices[i] - ) - return new_model - - def weighted_mean( arrin: np.ndarray, weights_in: np.ndarray, @@ -3401,6 +2851,8 @@ def find_optimal_nharms( Array of normalized AIC values. """ from pint.fitter import Fitter + from pint.models.wavex import wavex_setup + from pint.models.dmwavex import dmwavex_setup assert component in ["WaveX", "DMWaveX"] assert ( diff --git a/tests/test_cmwavex.py b/tests/test_cmwavex.py index d0f2652fb..a3aba5869 100644 --- a/tests/test_cmwavex.py +++ b/tests/test_cmwavex.py @@ -4,8 +4,8 @@ from pint.models import get_model from pint.fitter import Fitter +from pint.models.cmwavex import cmwavex_setup from pint.simulation import make_fake_toas_uniform -from pint.utils import cmwavex_setup from pint.models.chromatic_model import cmu import pytest diff --git a/tests/test_dmwavex.py b/tests/test_dmwavex.py index 5fdaf0cba..0f74429b7 100644 --- a/tests/test_dmwavex.py +++ b/tests/test_dmwavex.py @@ -3,9 +3,9 @@ import numpy as np from pint.models import get_model +from pint.models.dmwavex import dmwavex_setup from pint.fitter import Fitter from pint.simulation import make_fake_toas_uniform -from pint.utils import dmwavex_setup from pint import dmu import pytest diff --git a/tests/test_wavex.py b/tests/test_wavex.py index 7c5c7f9f5..c20ab1205 100644 --- a/tests/test_wavex.py +++ b/tests/test_wavex.py @@ -10,9 +10,7 @@ from pint.fitter import Fitter from pint.residuals import Residuals from pint.simulation import make_fake_toas_uniform -import pint.utils -from pinttestdata import datadir -from pint.models.wavex import WaveX +from pint.models.wavex import WaveX, translate_wave_to_wavex, translate_wavex_to_wave par1 = """ PSR B1937+21 @@ -392,8 +390,8 @@ def test_wave_wavex_roundtrip_conversion(): model = get_model(StringIO(par1)) toas = make_fake_toas_uniform(55000, 55100, 500, model, obs="gbt") wave_model = get_model(StringIO(par1 + wave_par)) - wave_to_wavex_model = pint.utils.translate_wave_to_wavex(wave_model) - wavex_to_wave_model = pint.utils.translate_wavex_to_wave(wave_to_wavex_model) + wave_to_wavex_model = translate_wave_to_wavex(wave_model) + wavex_to_wave_model = translate_wavex_to_wave(wave_to_wavex_model) rs_wave = Residuals(toas, wave_model) rs_wave_to_wavex = Residuals(toas, wave_to_wavex_model) rs_wavex_to_wave = Residuals(toas, wavex_to_wave_model) diff --git a/tests/test_wx2pl.py b/tests/test_wx2pl.py index 8f426f2e3..f61a8e585 100644 --- a/tests/test_wx2pl.py +++ b/tests/test_wx2pl.py @@ -1,13 +1,13 @@ import pytest from pint.models import get_model +from pint.models.cmwavex import cmwavex_setup +from pint.models.dmwavex import dmwavex_setup +from pint.models.wavex import wavex_setup from pint.simulation import make_fake_toas_uniform from pint.fitter import WLSFitter from pint.utils import ( - cmwavex_setup, - dmwavex_setup, find_optimal_nharms, plchromnoise_from_cmwavex, - wavex_setup, plrednoise_from_wavex, pldmnoise_from_dmwavex, ) From d1dfcf81abd7544fb0b9df41296431ea07042460 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Fri, 9 Aug 2024 09:28:42 +0200 Subject: [PATCH 17/31] CHANGELOG --- CHANGELOG-unreleased.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG-unreleased.md b/CHANGELOG-unreleased.md index 318150c91..1c1f55dc6 100644 --- a/CHANGELOG-unreleased.md +++ b/CHANGELOG-unreleased.md @@ -18,6 +18,7 @@ the released changes. - Fourier series representation of chromatic noise (`CMWaveX`) - `pint.utils.cmwavex_setup` and `pint.utils.plchromnoise_from_cmwavex` functions - More validation for correlated noise components in `TimingModel` +- Moved functions related to `WaveX`, `DMWaveX`, and `CMWaveX` from `pint.utils` to their own namespaces within `pint.models` ### Fixed - Bug in `DMWaveX.get_indices()` function - Explicit type conversion in `woodbury_dot()` function From fa95894588bdd1e1884f130ea1bbc7a1b71b5cae Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Fri, 9 Aug 2024 09:28:51 +0200 Subject: [PATCH 18/31] find_optimal_nharms --- src/pint/noise_analysis.py | 112 +++++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) create mode 100644 src/pint/noise_analysis.py diff --git a/src/pint/noise_analysis.py b/src/pint/noise_analysis.py new file mode 100644 index 000000000..27a0692c2 --- /dev/null +++ b/src/pint/noise_analysis.py @@ -0,0 +1,112 @@ +from copy import deepcopy +from typing import List +from itertools import product as cartesian_product + +import numpy as np +from pint.models.chromatic_model import ChromaticCM +from pint.models.cmwavex import cmwavex_setup +from pint.models.dispersion_model import DispersionDM +from pint.models.dmwavex import dmwavex_setup +from pint.models.phase_offset import PhaseOffset +from pint.models.timing_model import TimingModel +from pint.models.wavex import wavex_setup +from pint.toa import TOAs +from pint.utils import ( + akaike_information_criterion, +) + + +def find_optimal_nharms( + model: TimingModel, + toas: TOAs, + include_components: List[str] = ["WaveX", "DMWaveX", "CMWaveX"], + nharms_max: int = 45, + chromatic_index: float = 4, +): + assert len(set(include_components).intersection(set(model.components.keys()))) == 0 + + idxs = list( + cartesian_product( + *np.repeat([np.arange(nharms_max + 1)], len(include_components), axis=0) + ) + ) + + aics = np.zeros(np.repeat(nharms_max, len(include_components))) + for ii in idxs: + aics[*ii] = compute_aic(model, toas, include_components, ii, chromatic_index) + + assert all(np.isfinite(aics)), "Infs/NaNs found in AICs!" + + aics -= np.min(aics) + + return aics, np.unravel_index(np.argmin(aics), aics.shape) + + +def compute_aic( + model: TimingModel, + toas: TOAs, + include_components: List[str], + ii: np.ndarray, + chromatic_index: float, +): + model1 = prepare_model(model, toas, include_components, ii, chromatic_index) + + from pint.fitter import Fitter + + ftr = Fitter.auto(toas, model1) + ftr.fit_toas(maxiter=10) + + return akaike_information_criterion(ftr.model, toas) + + +def prepare_model( + model: TimingModel, + toas: TOAs, + include_components: List[str], + nharms: np.ndarray, + chromatic_index: float, +): + model1 = deepcopy(model) + + Tspan = toas.get_Tspan() + + if "PhaseOffset" not in model1.components: + model1.add_component(PhaseOffset()) + model1.PHOFF.frozen = False + + if "DMWaveX" in include_components: + if "DispersionDM" not in model1.components: + model1.add_component(DispersionDM()) + + model1.DM.frozen = False + if model.DM1.quantity is None: + model.DM1.quantity = 0 * model.DM1.units + model1.DM1.frozen = False + + if "CMWaveX" in include_components: + if "ChromaticCM" not in model1.components: + model1.add_component(ChromaticCM()) + model1.TNCHROMIDX.value = chromatic_index + + model1.CM.frozen = False + if model.CM1.quantity is None: + model.CM1.quantity = 0 * model.CM1.units + model1.CM1.frozen = False + + for jj, comp in enumerate(include_components): + if comp == "WaveX": + nharms_wx = nharms[jj] + if nharms_wx > 0: + wavex_setup(model1, Tspan, n_freqs=nharms_wx, freeze_params=False) + elif comp == "DMWaveX": + nharms_dwx = nharms[jj] + if nharms_dwx > 0: + dmwavex_setup(model1, Tspan, n_freqs=nharms_dwx, freeze_params=False) + elif comp == "CMWaveX": + nharms_cwx = nharms[jj] + if nharms_cwx > 0: + cmwavex_setup(model1, Tspan, n_freqs=nharms_cwx, freeze_params=False) + else: + raise ValueError(f"Unsupported component {comp}.") + + return model1 From f51474d48ad71975ff8534914e7c728f88c849a4 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Fri, 9 Aug 2024 12:32:32 +0200 Subject: [PATCH 19/31] move function --- CHANGELOG-unreleased.md | 5 +- src/pint/noise_analysis.py | 248 +++++++++++++++++++++++++++++++++++- src/pint/utils.py | 252 ------------------------------------- 3 files changed, 250 insertions(+), 255 deletions(-) diff --git a/CHANGELOG-unreleased.md b/CHANGELOG-unreleased.md index 1c1f55dc6..17783d83b 100644 --- a/CHANGELOG-unreleased.md +++ b/CHANGELOG-unreleased.md @@ -11,14 +11,15 @@ the released changes. ### Changed - Moved the events -> TOAs and photon weights code into the function `load_events_weights` within `event_optimize`. - Updated the `maxMJD` argument in `event_optimize` to default to the current mjd +- More validation for correlated noise components in `TimingModel` +- Moved functions related to `WaveX`, `DMWaveX`, and `CMWaveX` from `pint.utils` to their own namespaces within `pint.models` +- Moved `plrednoise_from_wavex` and similar functions from `pint.utils` to `pint.noise_analysis` ### Added - Type hints in `pint.derived_quantities` - Doing `model.par = something` will try to assign to `par.quantity` or `par.value` but will give warning - `PLChromNoise` component to model chromatic red noise with a power law spectrum - Fourier series representation of chromatic noise (`CMWaveX`) - `pint.utils.cmwavex_setup` and `pint.utils.plchromnoise_from_cmwavex` functions -- More validation for correlated noise components in `TimingModel` -- Moved functions related to `WaveX`, `DMWaveX`, and `CMWaveX` from `pint.utils` to their own namespaces within `pint.models` ### Fixed - Bug in `DMWaveX.get_indices()` function - Explicit type conversion in `woodbury_dot()` function diff --git a/src/pint/noise_analysis.py b/src/pint/noise_analysis.py index 27a0692c2..74f63364d 100644 --- a/src/pint/noise_analysis.py +++ b/src/pint/noise_analysis.py @@ -1,8 +1,11 @@ from copy import deepcopy -from typing import List +from typing import List, Tuple from itertools import product as cartesian_product import numpy as np +from astropy import units as u +from scipy.optimize import minimize +from numdifftools import Hessian from pint.models.chromatic_model import ChromaticCM from pint.models.cmwavex import cmwavex_setup from pint.models.dispersion_model import DispersionDM @@ -16,6 +19,249 @@ ) +def _get_wx2pl_lnlike( + model: TimingModel, component_name: str, ignore_fyr: bool = True +) -> float: + from pint.models.noise_model import powerlaw + from pint import DMconst + + assert component_name in {"WaveX", "DMWaveX", "CMWaveX"} + prefix_dict = {"WaveX": "WX", "DMWaveX": "DMWX", "CMWaveX": "CMWX"} + prefix = prefix_dict[component_name] + + idxs = np.array(model.components[component_name].get_indices()) + + fs = np.array( + [model[f"{prefix}FREQ_{idx:04d}"].quantity.to_value(u.Hz) for idx in idxs] + ) + f0 = np.min(fs) + fyr = (1 / u.year).to_value(u.Hz) + + assert np.allclose( + np.diff(np.diff(fs)), 0 + ), "WaveX/DMWaveX/CMWaveX frequencies must be uniformly spaced for this conversion to work." + + if ignore_fyr: + year_mask = np.abs(((fs - fyr) / f0)) > 0.5 + + idxs = idxs[year_mask] + fs = np.array( + [model[f"{prefix}FREQ_{idx:04d}"].quantity.to_value(u.Hz) for idx in idxs] + ) + f0 = np.min(fs) + + scaling_factor = ( + 1 + if component_name == "WaveX" + else ( + DMconst / (1400 * u.MHz) ** 2 + if component_name == "DMWaveX" + else DMconst / 1400**model.TNCHROMIDX.value + ) + ) + + a = np.array( + [ + (scaling_factor * model[f"{prefix}SIN_{idx:04d}"].quantity).to_value(u.s) + for idx in idxs + ] + ) + da = np.array( + [ + (scaling_factor * model[f"{prefix}SIN_{idx:04d}"].uncertainty).to_value(u.s) + for idx in idxs + ] + ) + b = np.array( + [ + (scaling_factor * model[f"{prefix}COS_{idx:04d}"].quantity).to_value(u.s) + for idx in idxs + ] + ) + db = np.array( + [ + (scaling_factor * model[f"{prefix}COS_{idx:04d}"].uncertainty).to_value(u.s) + for idx in idxs + ] + ) + + def powl_model(params: Tuple[float, float]) -> float: + """Get the powerlaw spectrum for the WaveX frequencies for a given + set of parameters. This calls the powerlaw function used by `PLRedNoise`/`PLDMNoise`/`PLChromNoise`. + """ + gamma, log10_A = params + return (powerlaw(fs, A=10**log10_A, gamma=gamma) * f0) ** 0.5 + + def mlnlike(params: Tuple[float, ...]) -> float: + """Negative of the likelihood function that acts on the + `[DM/CM]WaveX` amplitudes.""" + sigma = powl_model(params) + return 0.5 * float( + np.sum( + (a**2 / (sigma**2 + da**2)) + + (b**2 / (sigma**2 + db**2)) + + np.log(sigma**2 + da**2) + + np.log(sigma**2 + db**2) + ) + ) + + return mlnlike + + +def plrednoise_from_wavex(model: TimingModel, ignore_fyr: bool = True) -> TimingModel: + """Convert a `WaveX` representation of red noise to a `PLRedNoise` + representation. This is done by minimizing a likelihood function + that acts on the `WaveX` amplitudes over the powerlaw spectral + parameters. + + Parameters + ---------- + model: pint.models.timing_model.TimingModel + The timing model with a `WaveX` component. + ignore_fyr: bool + Whether to ignore the frequency bin containinf 1 yr^-1 + while fitting for the spectral parameters. + + Returns + ------- + pint.models.timing_model.TimingModel + The timing model with a converted `PLRedNoise` component. + """ + from pint.models.noise_model import PLRedNoise + + mlnlike = _get_wx2pl_lnlike(model, "WaveX", ignore_fyr=ignore_fyr) + + result = minimize(mlnlike, [4, -13], method="Nelder-Mead") + if not result.success: + raise ValueError("Log-likelihood maximization failed to converge.") + + gamma_val, log10_A_val = result.x + + hess = Hessian(mlnlike) + gamma_err, log10_A_err = np.sqrt( + np.diag(np.linalg.pinv(hess((gamma_val, log10_A_val)))) + ) + + tnredc = len(model.components["WaveX"].get_indices()) + + model1 = deepcopy(model) + model1.remove_component("WaveX") + model1.add_component(PLRedNoise()) + model1.TNREDAMP.value = log10_A_val + model1.TNREDGAM.value = gamma_val + model1.TNREDC.value = tnredc + model1.TNREDAMP.uncertainty_value = log10_A_err + model1.TNREDGAM.uncertainty_value = gamma_err + + return model1 + + +def pldmnoise_from_dmwavex(model: TimingModel, ignore_fyr: bool = False) -> TimingModel: + """Convert a `DMWaveX` representation of red noise to a `PLDMNoise` + representation. This is done by minimizing a likelihood function + that acts on the `DMWaveX` amplitudes over the powerlaw spectral + parameters. + + Parameters + ---------- + model: pint.models.timing_model.TimingModel + The timing model with a `DMWaveX` component. + + Returns + ------- + pint.models.timing_model.TimingModel + The timing model with a converted `PLDMNoise` component. + """ + from pint.models.noise_model import PLDMNoise + + mlnlike = _get_wx2pl_lnlike(model, "DMWaveX", ignore_fyr=ignore_fyr) + + result = minimize(mlnlike, [4, -13], method="Nelder-Mead") + if not result.success: + raise ValueError("Log-likelihood maximization failed to converge.") + + gamma_val, log10_A_val = result.x + + hess = Hessian(mlnlike) + + H = hess((gamma_val, log10_A_val)) + assert np.all(np.linalg.eigvals(H) > 0), "The Hessian is not positive definite!" + + Hinv = np.linalg.pinv(H) + assert np.all( + np.linalg.eigvals(Hinv) > 0 + ), "The inverse Hessian is not positive definite!" + + gamma_err, log10_A_err = np.sqrt(np.diag(Hinv)) + + tndmc = len(model.components["DMWaveX"].get_indices()) + + model1 = deepcopy(model) + model1.remove_component("DMWaveX") + model1.add_component(PLDMNoise()) + model1.TNDMAMP.value = log10_A_val + model1.TNDMGAM.value = gamma_val + model1.TNDMC.value = tndmc + model1.TNDMAMP.uncertainty_value = log10_A_err + model1.TNDMGAM.uncertainty_value = gamma_err + + return model1 + + +def plchromnoise_from_cmwavex( + model: TimingModel, ignore_fyr: bool = False +) -> TimingModel: + """Convert a `CMWaveX` representation of red noise to a `PLChromNoise` + representation. This is done by minimizing a likelihood function + that acts on the `CMWaveX` amplitudes over the powerlaw spectral + parameters. + + Parameters + ---------- + model: pint.models.timing_model.TimingModel + The timing model with a `CMWaveX` component. + + Returns + ------- + pint.models.timing_model.TimingModel + The timing model with a converted `PLChromNoise` component. + """ + from pint.models.noise_model import PLChromNoise + + mlnlike = _get_wx2pl_lnlike(model, "CMWaveX", ignore_fyr=ignore_fyr) + + result = minimize(mlnlike, [4, -13], method="Nelder-Mead") + if not result.success: + raise ValueError("Log-likelihood maximization failed to converge.") + + gamma_val, log10_A_val = result.x + + hess = Hessian(mlnlike) + + H = hess((gamma_val, log10_A_val)) + assert np.all(np.linalg.eigvals(H) > 0), "The Hessian is not positive definite!" + + Hinv = np.linalg.pinv(H) + assert np.all( + np.linalg.eigvals(Hinv) > 0 + ), "The inverse Hessian is not positive definite!" + + gamma_err, log10_A_err = np.sqrt(np.diag(Hinv)) + + tndmc = len(model.components["CMWaveX"].get_indices()) + + model1 = deepcopy(model) + model1.remove_component("CMWaveX") + model1.add_component(PLChromNoise()) + model1.TNCHROMAMP.value = log10_A_val + model1.TNCHROMGAM.value = gamma_val + model1.TNCHROMC.value = tndmc + model1.TNCHROMAMP.uncertainty_value = log10_A_err + model1.TNCHROMGAM.uncertainty_value = gamma_err + + return model1 + + def find_optimal_nharms( model: TimingModel, toas: TOAs, diff --git a/src/pint/utils.py b/src/pint/utils.py index 883ee3dc0..bef0554fc 100644 --- a/src/pint/utils.py +++ b/src/pint/utils.py @@ -41,8 +41,6 @@ from contextlib import contextmanager from pathlib import Path from warnings import warn -from scipy.optimize import minimize -from numdifftools import Hessian from typing import ( Optional, List, @@ -69,7 +67,6 @@ from scipy.special import fdtrc from scipy.linalg import cho_factor, cho_solve from copy import deepcopy -import warnings import pint import pint.pulsar_ecliptic @@ -129,8 +126,6 @@ "bayesian_information_criterion", "sherman_morrison_dot", "woodbury_dot", - "plrednoise_from_wavex", - "pldmnoise_from_dmwavex", "find_optimal_nharms", ] @@ -2576,253 +2571,6 @@ def woodbury_dot( return x_Cinv_y, logdet_C -def _get_wx2pl_lnlike( - model: "pint.models.TimingModel", component_name: str, ignore_fyr: bool = True -) -> float: - from pint.models.noise_model import powerlaw - from pint import DMconst - - assert component_name in {"WaveX", "DMWaveX", "CMWaveX"} - prefix_dict = {"WaveX": "WX", "DMWaveX": "DMWX", "CMWaveX": "CMWX"} - prefix = prefix_dict[component_name] - - idxs = np.array(model.components[component_name].get_indices()) - - fs = np.array( - [model[f"{prefix}FREQ_{idx:04d}"].quantity.to_value(u.Hz) for idx in idxs] - ) - f0 = np.min(fs) - fyr = (1 / u.year).to_value(u.Hz) - - assert np.allclose( - np.diff(np.diff(fs)), 0 - ), "WaveX/DMWaveX/CMWaveX frequencies must be uniformly spaced for this conversion to work." - - if ignore_fyr: - year_mask = np.abs(((fs - fyr) / f0)) > 0.5 - - idxs = idxs[year_mask] - fs = np.array( - [model[f"{prefix}FREQ_{idx:04d}"].quantity.to_value(u.Hz) for idx in idxs] - ) - f0 = np.min(fs) - - scaling_factor = ( - 1 - if component_name == "WaveX" - else ( - DMconst / (1400 * u.MHz) ** 2 - if component_name == "DMWaveX" - else DMconst / 1400**model.TNCHROMIDX.value - ) - ) - - a = np.array( - [ - (scaling_factor * model[f"{prefix}SIN_{idx:04d}"].quantity).to_value(u.s) - for idx in idxs - ] - ) - da = np.array( - [ - (scaling_factor * model[f"{prefix}SIN_{idx:04d}"].uncertainty).to_value(u.s) - for idx in idxs - ] - ) - b = np.array( - [ - (scaling_factor * model[f"{prefix}COS_{idx:04d}"].quantity).to_value(u.s) - for idx in idxs - ] - ) - db = np.array( - [ - (scaling_factor * model[f"{prefix}COS_{idx:04d}"].uncertainty).to_value(u.s) - for idx in idxs - ] - ) - - def powl_model(params: Tuple[float, float]) -> float: - """Get the powerlaw spectrum for the WaveX frequencies for a given - set of parameters. This calls the powerlaw function used by `PLRedNoise`/`PLDMNoise`/`PLChromNoise`. - """ - gamma, log10_A = params - return (powerlaw(fs, A=10**log10_A, gamma=gamma) * f0) ** 0.5 - - def mlnlike(params: Tuple[float, ...]) -> float: - """Negative of the likelihood function that acts on the - `[DM/CM]WaveX` amplitudes.""" - sigma = powl_model(params) - return 0.5 * float( - np.sum( - (a**2 / (sigma**2 + da**2)) - + (b**2 / (sigma**2 + db**2)) - + np.log(sigma**2 + da**2) - + np.log(sigma**2 + db**2) - ) - ) - - return mlnlike - - -def plrednoise_from_wavex( - model: "pint.models.TimingModel", ignore_fyr: bool = True -) -> "pint.models.TimingModel": - """Convert a `WaveX` representation of red noise to a `PLRedNoise` - representation. This is done by minimizing a likelihood function - that acts on the `WaveX` amplitudes over the powerlaw spectral - parameters. - - Parameters - ---------- - model: pint.models.timing_model.TimingModel - The timing model with a `WaveX` component. - ignore_fyr: bool - Whether to ignore the frequency bin containinf 1 yr^-1 - while fitting for the spectral parameters. - - Returns - ------- - pint.models.timing_model.TimingModel - The timing model with a converted `PLRedNoise` component. - """ - from pint.models.noise_model import PLRedNoise - - mlnlike = _get_wx2pl_lnlike(model, "WaveX", ignore_fyr=ignore_fyr) - - result = minimize(mlnlike, [4, -13], method="Nelder-Mead") - if not result.success: - raise ValueError("Log-likelihood maximization failed to converge.") - - gamma_val, log10_A_val = result.x - - hess = Hessian(mlnlike) - gamma_err, log10_A_err = np.sqrt( - np.diag(np.linalg.pinv(hess((gamma_val, log10_A_val)))) - ) - - tnredc = len(model.components["WaveX"].get_indices()) - - model1 = deepcopy(model) - model1.remove_component("WaveX") - model1.add_component(PLRedNoise()) - model1.TNREDAMP.value = log10_A_val - model1.TNREDGAM.value = gamma_val - model1.TNREDC.value = tnredc - model1.TNREDAMP.uncertainty_value = log10_A_err - model1.TNREDGAM.uncertainty_value = gamma_err - - return model1 - - -def pldmnoise_from_dmwavex( - model: "pint.models.TimingModel", ignore_fyr: bool = False -) -> "pint.models.TimingModel": - """Convert a `DMWaveX` representation of red noise to a `PLDMNoise` - representation. This is done by minimizing a likelihood function - that acts on the `DMWaveX` amplitudes over the powerlaw spectral - parameters. - - Parameters - ---------- - model: pint.models.timing_model.TimingModel - The timing model with a `DMWaveX` component. - - Returns - ------- - pint.models.timing_model.TimingModel - The timing model with a converted `PLDMNoise` component. - """ - from pint.models.noise_model import PLDMNoise - - mlnlike = _get_wx2pl_lnlike(model, "DMWaveX", ignore_fyr=ignore_fyr) - - result = minimize(mlnlike, [4, -13], method="Nelder-Mead") - if not result.success: - raise ValueError("Log-likelihood maximization failed to converge.") - - gamma_val, log10_A_val = result.x - - hess = Hessian(mlnlike) - - H = hess((gamma_val, log10_A_val)) - assert np.all(np.linalg.eigvals(H) > 0), "The Hessian is not positive definite!" - - Hinv = np.linalg.pinv(H) - assert np.all( - np.linalg.eigvals(Hinv) > 0 - ), "The inverse Hessian is not positive definite!" - - gamma_err, log10_A_err = np.sqrt(np.diag(Hinv)) - - tndmc = len(model.components["DMWaveX"].get_indices()) - - model1 = deepcopy(model) - model1.remove_component("DMWaveX") - model1.add_component(PLDMNoise()) - model1.TNDMAMP.value = log10_A_val - model1.TNDMGAM.value = gamma_val - model1.TNDMC.value = tndmc - model1.TNDMAMP.uncertainty_value = log10_A_err - model1.TNDMGAM.uncertainty_value = gamma_err - - return model1 - - -def plchromnoise_from_cmwavex( - model: "pint.models.TimingModel", ignore_fyr: bool = False -) -> "pint.models.TimingModel": - """Convert a `CMWaveX` representation of red noise to a `PLChromNoise` - representation. This is done by minimizing a likelihood function - that acts on the `CMWaveX` amplitudes over the powerlaw spectral - parameters. - - Parameters - ---------- - model: pint.models.timing_model.TimingModel - The timing model with a `CMWaveX` component. - - Returns - ------- - pint.models.timing_model.TimingModel - The timing model with a converted `PLChromNoise` component. - """ - from pint.models.noise_model import PLChromNoise - - mlnlike = _get_wx2pl_lnlike(model, "CMWaveX", ignore_fyr=ignore_fyr) - - result = minimize(mlnlike, [4, -13], method="Nelder-Mead") - if not result.success: - raise ValueError("Log-likelihood maximization failed to converge.") - - gamma_val, log10_A_val = result.x - - hess = Hessian(mlnlike) - - H = hess((gamma_val, log10_A_val)) - assert np.all(np.linalg.eigvals(H) > 0), "The Hessian is not positive definite!" - - Hinv = np.linalg.pinv(H) - assert np.all( - np.linalg.eigvals(Hinv) > 0 - ), "The inverse Hessian is not positive definite!" - - gamma_err, log10_A_err = np.sqrt(np.diag(Hinv)) - - tndmc = len(model.components["CMWaveX"].get_indices()) - - model1 = deepcopy(model) - model1.remove_component("CMWaveX") - model1.add_component(PLChromNoise()) - model1.TNCHROMAMP.value = log10_A_val - model1.TNCHROMGAM.value = gamma_val - model1.TNCHROMC.value = tndmc - model1.TNCHROMAMP.uncertainty_value = log10_A_err - model1.TNCHROMGAM.uncertainty_value = gamma_err - - return model1 - - def find_optimal_nharms( model: "pint.models.TimingModel", toas: "pint.toa.TOAs", From b66d98949e8292677a025429aad2c9b12f10cfc6 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Fri, 9 Aug 2024 13:38:40 +0200 Subject: [PATCH 20/31] fix test --- src/pint/noise_analysis.py | 2 +- tests/test_wx2pl.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pint/noise_analysis.py b/src/pint/noise_analysis.py index 74f63364d..886d9373c 100644 --- a/src/pint/noise_analysis.py +++ b/src/pint/noise_analysis.py @@ -279,7 +279,7 @@ def find_optimal_nharms( aics = np.zeros(np.repeat(nharms_max, len(include_components))) for ii in idxs: - aics[*ii] = compute_aic(model, toas, include_components, ii, chromatic_index) + aics[ii] = compute_aic(model, toas, include_components, ii, chromatic_index) assert all(np.isfinite(aics)), "Infs/NaNs found in AICs!" diff --git a/tests/test_wx2pl.py b/tests/test_wx2pl.py index f61a8e585..ed658a285 100644 --- a/tests/test_wx2pl.py +++ b/tests/test_wx2pl.py @@ -5,8 +5,8 @@ from pint.models.wavex import wavex_setup from pint.simulation import make_fake_toas_uniform from pint.fitter import WLSFitter -from pint.utils import ( - find_optimal_nharms, +from pint.utils import find_optimal_nharms +from pint.noise_analysis import ( plchromnoise_from_cmwavex, plrednoise_from_wavex, pldmnoise_from_dmwavex, From b64d743b22d02b296ad38ae6197fff768b4e8be9 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Fri, 9 Aug 2024 14:50:45 +0200 Subject: [PATCH 21/31] test_prepare_model_for_find_optimal_nharms --- src/pint/models/timing_model.py | 3 +- src/pint/noise_analysis.py | 44 ++++++++++++++-------- tests/test_wx2pl.py | 67 +++++++++++++++++++++++++++++++++ 3 files changed, 98 insertions(+), 16 deletions(-) diff --git a/src/pint/models/timing_model.py b/src/pint/models/timing_model.py index 3c8fd217e..28a8b72ab 100644 --- a/src/pint/models/timing_model.py +++ b/src/pint/models/timing_model.py @@ -33,6 +33,7 @@ import contextlib from collections import OrderedDict, defaultdict from functools import wraps +from typing import Dict from warnings import warn from uncertainties import ufloat @@ -807,7 +808,7 @@ def set_param_uncertainties(self, fitp): p.uncertainty = v if isinstance(v, u.Quantity) else v * p.units @property_exists - def components(self): + def components(self) -> Dict[str, "Component"]: """All the components in a dictionary indexed by name.""" comps = {} for ct in self.component_types: diff --git a/src/pint/noise_analysis.py b/src/pint/noise_analysis.py index 886d9373c..10fadf382 100644 --- a/src/pint/noise_analysis.py +++ b/src/pint/noise_analysis.py @@ -10,6 +10,7 @@ from pint.models.cmwavex import cmwavex_setup from pint.models.dispersion_model import DispersionDM from pint.models.dmwavex import dmwavex_setup +from pint.models.parameter import prefixParameter from pint.models.phase_offset import PhaseOffset from pint.models.timing_model import TimingModel from pint.models.wavex import wavex_setup @@ -295,7 +296,9 @@ def compute_aic( ii: np.ndarray, chromatic_index: float, ): - model1 = prepare_model(model, toas, include_components, ii, chromatic_index) + model1 = prepare_model( + model, toas.get_Tspan(), include_components, ii, chromatic_index + ) from pint.fitter import Fitter @@ -307,37 +310,48 @@ def compute_aic( def prepare_model( model: TimingModel, - toas: TOAs, + Tspan: u.Quantity, include_components: List[str], nharms: np.ndarray, chromatic_index: float, ): model1 = deepcopy(model) - Tspan = toas.get_Tspan() - if "PhaseOffset" not in model1.components: model1.add_component(PhaseOffset()) - model1.PHOFF.frozen = False + model1.PHOFF.frozen = False if "DMWaveX" in include_components: if "DispersionDM" not in model1.components: model1.add_component(DispersionDM()) - model1.DM.frozen = False - if model.DM1.quantity is None: - model.DM1.quantity = 0 * model.DM1.units - model1.DM1.frozen = False + model1["DM"].frozen = False + + if model1["DM1"].quantity is None: + model1["DM1"].quantity = 0 * model1["DM1"].units + model1["DM1"].frozen = False + + if "DM2" not in model1.params: + model1.components["DispersionDM"].add_param(model["DM1"].new_param(2)) + if model1["DM2"].quantity is None: + model1["DM2"].quantity = 0 * model1["DM2"].units + model1["DM2"].frozen = False if "CMWaveX" in include_components: if "ChromaticCM" not in model1.components: model1.add_component(ChromaticCM()) - model1.TNCHROMIDX.value = chromatic_index - - model1.CM.frozen = False - if model.CM1.quantity is None: - model.CM1.quantity = 0 * model.CM1.units - model1.CM1.frozen = False + model1["TNCHROMIDX"].value = chromatic_index + + model1["CM"].frozen = False + if model1["CM1"].quantity is None: + model1["CM1"].quantity = 0 * model1["CM1"].units + model1["CM1"].frozen = False + + if "CM2" not in model1.params: + model1.components["ChromaticCM"].add_param(model1["CM1"].new_param(2)) + if model1["CM2"].quantity is None: + model1["CM2"].quantity = 0 * model1["CM2"].units + model1["CM2"].frozen = False for jj, comp in enumerate(include_components): if comp == "WaveX": diff --git a/tests/test_wx2pl.py b/tests/test_wx2pl.py index ed658a285..0664f83f7 100644 --- a/tests/test_wx2pl.py +++ b/tests/test_wx2pl.py @@ -10,6 +10,7 @@ plchromnoise_from_cmwavex, plrednoise_from_wavex, pldmnoise_from_dmwavex, + prepare_model, ) from io import StringIO @@ -245,3 +246,69 @@ def test_find_optimal_nharms_dmwx(data_dmwx): assert nharm <= 7 assert np.all(aics >= 0) + + +@pytest.fixture +def model_sim3(): + par_sim = """ + PSR SIM3 + RAJ 05:00:00 1 + DECJ 15:00:00 1 + PEPOCH 55000 + F0 100 1 + F1 -1e-15 1 + PHOFF 0 1 + DM 15 1 + TZRMJD 55000 + TZRFRQ 1400 + TZRSITE gbt + UNITS TDB + EPHEM DE440 + CLOCK TT(BIPM2019) + """ + return get_model(StringIO(par_sim)) + + +@pytest.mark.parametrize( + "component_names", + [ + ["WaveX"], + ["DMWaveX"], + ["CMWaveX"], + ["WaveX", "DMWaveX"], + ["WaveX", "CMWaveX"], + ["DMWaveX", "CMWaveX"], + ["WaveX", "DMWaveX", "CMWaveX"], + ], +) +def test_prepare_model_for_find_optimal_nharms(model_sim3, component_names): + m0 = model_sim3 + + m = prepare_model( + model=m0, + Tspan=10 * u.year, + include_components=component_names, + nharms=np.repeat(10, len(component_names)), + chromatic_index=4, + ) + + assert "PHOFF" in m.free_params + + assert ("WaveX" in component_names) == ("WaveX" in m.components) + assert ("DMWaveX" in component_names) == ("DMWaveX" in m.components) + assert ("CMWaveX" in component_names) == ("CMWaveX" in m.components) + + if "WaveX" in component_names: + assert len(m.components["WaveX"].get_indices()) == 10 + + if "DMWaveX" in component_names: + assert not m.DM.frozen and m.DM.quantity is not None + assert not m.DM1.frozen and m.DM1.quantity is not None + assert not m.DM2.frozen and m.DM2.quantity is not None + assert len(m.components["DMWaveX"].get_indices()) == 10 + + if "CMWaveX" in component_names: + assert not m.CM.frozen and m.CM.quantity is not None + assert not m.CM1.frozen and m.CM1.quantity is not None + assert not m.CM2.frozen and m.CM2.quantity is not None + assert len(m.components["CMWaveX"].get_indices()) == 10 From 37877a4191a39a72a75b7e5c92c1e1596a37c7d6 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Fri, 9 Aug 2024 15:05:43 +0200 Subject: [PATCH 22/31] refactor --- src/pint/noise_analysis.py | 62 ++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/src/pint/noise_analysis.py b/src/pint/noise_analysis.py index 10fadf382..365a7e566 100644 --- a/src/pint/noise_analysis.py +++ b/src/pint/noise_analysis.py @@ -321,48 +321,46 @@ def prepare_model( model1.add_component(PhaseOffset()) model1.PHOFF.frozen = False - if "DMWaveX" in include_components: - if "DispersionDM" not in model1.components: - model1.add_component(DispersionDM()) - - model1["DM"].frozen = False - - if model1["DM1"].quantity is None: - model1["DM1"].quantity = 0 * model1["DM1"].units - model1["DM1"].frozen = False - - if "DM2" not in model1.params: - model1.components["DispersionDM"].add_param(model["DM1"].new_param(2)) - if model1["DM2"].quantity is None: - model1["DM2"].quantity = 0 * model1["DM2"].units - model1["DM2"].frozen = False - - if "CMWaveX" in include_components: - if "ChromaticCM" not in model1.components: - model1.add_component(ChromaticCM()) - model1["TNCHROMIDX"].value = chromatic_index - - model1["CM"].frozen = False - if model1["CM1"].quantity is None: - model1["CM1"].quantity = 0 * model1["CM1"].units - model1["CM1"].frozen = False - - if "CM2" not in model1.params: - model1.components["ChromaticCM"].add_param(model1["CM1"].new_param(2)) - if model1["CM2"].quantity is None: - model1["CM2"].quantity = 0 * model1["CM2"].units - model1["CM2"].frozen = False - for jj, comp in enumerate(include_components): if comp == "WaveX": nharms_wx = nharms[jj] if nharms_wx > 0: wavex_setup(model1, Tspan, n_freqs=nharms_wx, freeze_params=False) elif comp == "DMWaveX": + if "DispersionDM" not in model1.components: + model1.add_component(DispersionDM()) + + model1["DM"].frozen = False + + if model1["DM1"].quantity is None: + model1["DM1"].quantity = 0 * model1["DM1"].units + model1["DM1"].frozen = False + + if "DM2" not in model1.params: + model1.components["DispersionDM"].add_param(model["DM1"].new_param(2)) + if model1["DM2"].quantity is None: + model1["DM2"].quantity = 0 * model1["DM2"].units + model1["DM2"].frozen = False + nharms_dwx = nharms[jj] if nharms_dwx > 0: dmwavex_setup(model1, Tspan, n_freqs=nharms_dwx, freeze_params=False) elif comp == "CMWaveX": + if "ChromaticCM" not in model1.components: + model1.add_component(ChromaticCM()) + model1["TNCHROMIDX"].value = chromatic_index + + model1["CM"].frozen = False + if model1["CM1"].quantity is None: + model1["CM1"].quantity = 0 * model1["CM1"].units + model1["CM1"].frozen = False + + if "CM2" not in model1.params: + model1.components["ChromaticCM"].add_param(model1["CM1"].new_param(2)) + if model1["CM2"].quantity is None: + model1["CM2"].quantity = 0 * model1["CM2"].units + model1["CM2"].frozen = False + nharms_cwx = nharms[jj] if nharms_cwx > 0: cmwavex_setup(model1, Tspan, n_freqs=nharms_cwx, freeze_params=False) From a47b4f29b8a578bc1451dd51ea913dc507e2f1e4 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Fri, 9 Aug 2024 15:17:18 +0200 Subject: [PATCH 23/31] fix epochs --- src/pint/noise_analysis.py | 14 +++++++++- tests/test_wx2pl.py | 53 +++++++++++++++++++++----------------- 2 files changed, 43 insertions(+), 24 deletions(-) diff --git a/src/pint/noise_analysis.py b/src/pint/noise_analysis.py index 365a7e566..dd41e9a31 100644 --- a/src/pint/noise_analysis.py +++ b/src/pint/noise_analysis.py @@ -302,7 +302,9 @@ def compute_aic( from pint.fitter import Fitter - ftr = Fitter.auto(toas, model1) + # Downhill fitters don't work well here. + # TODO: Investigate this. + ftr = Fitter.auto(toas, model1, downhill=False) ftr.fit_toas(maxiter=10) return akaike_information_criterion(ftr.model, toas) @@ -317,6 +319,10 @@ def prepare_model( ): model1 = deepcopy(model) + for comp in ["PLRedNoise", "PLDMNoise", "PLCMNoise"]: + if comp in model1.components: + model1.remove_component(comp) + if "PhaseOffset" not in model1.components: model1.add_component(PhaseOffset()) model1.PHOFF.frozen = False @@ -342,6 +348,9 @@ def prepare_model( model1["DM2"].quantity = 0 * model1["DM2"].units model1["DM2"].frozen = False + if model1["DMEPOCH"].quantity is None: + model1["DMEPOCH"].quantity = model1["PEPOCH"].quantity + nharms_dwx = nharms[jj] if nharms_dwx > 0: dmwavex_setup(model1, Tspan, n_freqs=nharms_dwx, freeze_params=False) @@ -361,6 +370,9 @@ def prepare_model( model1["CM2"].quantity = 0 * model1["CM2"].units model1["CM2"].frozen = False + if model1["CMEPOCH"].quantity is None: + model1["CMEPOCH"].quantity = model1["PEPOCH"].quantity + nharms_cwx = nharms[jj] if nharms_cwx > 0: cmwavex_setup(model1, Tspan, n_freqs=nharms_cwx, freeze_params=False) diff --git a/tests/test_wx2pl.py b/tests/test_wx2pl.py index 0664f83f7..56d353ad3 100644 --- a/tests/test_wx2pl.py +++ b/tests/test_wx2pl.py @@ -7,6 +7,7 @@ from pint.fitter import WLSFitter from pint.utils import find_optimal_nharms from pint.noise_analysis import ( + compute_aic, plchromnoise_from_cmwavex, plrednoise_from_wavex, pldmnoise_from_dmwavex, @@ -248,27 +249,6 @@ def test_find_optimal_nharms_dmwx(data_dmwx): assert np.all(aics >= 0) -@pytest.fixture -def model_sim3(): - par_sim = """ - PSR SIM3 - RAJ 05:00:00 1 - DECJ 15:00:00 1 - PEPOCH 55000 - F0 100 1 - F1 -1e-15 1 - PHOFF 0 1 - DM 15 1 - TZRMJD 55000 - TZRFRQ 1400 - TZRSITE gbt - UNITS TDB - EPHEM DE440 - CLOCK TT(BIPM2019) - """ - return get_model(StringIO(par_sim)) - - @pytest.mark.parametrize( "component_names", [ @@ -281,8 +261,8 @@ def model_sim3(): ["WaveX", "DMWaveX", "CMWaveX"], ], ) -def test_prepare_model_for_find_optimal_nharms(model_sim3, component_names): - m0 = model_sim3 +def test_prepare_model_for_find_optimal_nharms(data_dmwx, component_names): + m0, t = data_dmwx m = prepare_model( model=m0, @@ -292,6 +272,8 @@ def test_prepare_model_for_find_optimal_nharms(model_sim3, component_names): chromatic_index=4, ) + assert "PLDMNoise" not in m.components + assert "PHOFF" in m.free_params assert ("WaveX" in component_names) == ("WaveX" in m.components) @@ -312,3 +294,28 @@ def test_prepare_model_for_find_optimal_nharms(model_sim3, component_names): assert not m.CM1.frozen and m.CM1.quantity is not None assert not m.CM2.frozen and m.CM2.quantity is not None assert len(m.components["CMWaveX"].get_indices()) == 10 + + +@pytest.mark.parametrize( + "component_names", + [ + ["WaveX"], + ["DMWaveX"], + ["CMWaveX"], + ["WaveX", "DMWaveX"], + ["WaveX", "CMWaveX"], + ["DMWaveX", "CMWaveX"], + ["WaveX", "DMWaveX", "CMWaveX"], + ], +) +def test_compute_aic(data_dmwx, component_names): + m, t = data_dmwx + assert np.isfinite( + compute_aic( + m, + t, + include_components=component_names, + ii=np.array((8, 9, 10)), + chromatic_index=4, + ) + ) From 03f0de623b121dc32f6b975880cc05f518bd3ffb Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Fri, 9 Aug 2024 15:25:42 +0200 Subject: [PATCH 24/31] test_noise_analysis --- tests/{test_wx2pl.py => test_noise_analysis.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{test_wx2pl.py => test_noise_analysis.py} (100%) diff --git a/tests/test_wx2pl.py b/tests/test_noise_analysis.py similarity index 100% rename from tests/test_wx2pl.py rename to tests/test_noise_analysis.py From 829a5ba39f1372dcfc18c99d77ad8a34b34895c2 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Fri, 9 Aug 2024 15:36:32 +0200 Subject: [PATCH 25/31] don't subtract aic_min --- src/pint/noise_analysis.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/pint/noise_analysis.py b/src/pint/noise_analysis.py index dd41e9a31..baac6939b 100644 --- a/src/pint/noise_analysis.py +++ b/src/pint/noise_analysis.py @@ -284,8 +284,6 @@ def find_optimal_nharms( assert all(np.isfinite(aics)), "Infs/NaNs found in AICs!" - aics -= np.min(aics) - return aics, np.unravel_index(np.argmin(aics), aics.shape) From 5fbdc5f3a019e3dcb607130fd4ad540b52ceb0b2 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Fri, 9 Aug 2024 16:44:25 +0200 Subject: [PATCH 26/31] parallel --- requirements.txt | 1 + src/pint/noise_analysis.py | 96 ++++++++++++++++++++++---------------- 2 files changed, 57 insertions(+), 40 deletions(-) diff --git a/requirements.txt b/requirements.txt index 16f76401e..554392cdf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ uncertainties loguru nestle>=0.2.0 numdifftools +joblib \ No newline at end of file diff --git a/src/pint/noise_analysis.py b/src/pint/noise_analysis.py index baac6939b..5957710ce 100644 --- a/src/pint/noise_analysis.py +++ b/src/pint/noise_analysis.py @@ -1,7 +1,8 @@ from copy import deepcopy -from typing import List, Tuple +from typing import List, Optional, Tuple from itertools import product as cartesian_product +from joblib import Parallel, cpu_count, delayed import numpy as np from astropy import units as u from scipy.optimize import minimize @@ -14,10 +15,10 @@ from pint.models.phase_offset import PhaseOffset from pint.models.timing_model import TimingModel from pint.models.wavex import wavex_setup +from pint.polycos import tqdm from pint.toa import TOAs -from pint.utils import ( - akaike_information_criterion, -) +from pint.utils import akaike_information_criterion +from pint.logging import setup as setup_log def _get_wx2pl_lnlike( @@ -269,6 +270,7 @@ def find_optimal_nharms( include_components: List[str] = ["WaveX", "DMWaveX", "CMWaveX"], nharms_max: int = 45, chromatic_index: float = 4, + num_parallel_jobs: Optional[int] = None, ): assert len(set(include_components).intersection(set(model.components.keys()))) == 0 @@ -278,9 +280,17 @@ def find_optimal_nharms( ) ) - aics = np.zeros(np.repeat(nharms_max, len(include_components))) - for ii in idxs: - aics[ii] = compute_aic(model, toas, include_components, ii, chromatic_index) + if num_parallel_jobs is None: + num_parallel_jobs = cpu_count() + + aics_flat = Parallel(n_jobs=num_parallel_jobs)( + delayed( + lambda ii: compute_aic(model, toas, include_components, ii, chromatic_index) + )(ii) + for ii in idxs + ) + + aics = np.reshape(aics_flat, [nharms_max + 1] * len(include_components)) assert all(np.isfinite(aics)), "Infs/NaNs found in AICs!" @@ -294,6 +304,8 @@ def compute_aic( ii: np.ndarray, chromatic_index: float, ): + setup_log(level="WARNING") + model1 = prepare_model( model, toas.get_Tspan(), include_components, ii, chromatic_index ) @@ -331,48 +343,52 @@ def prepare_model( if nharms_wx > 0: wavex_setup(model1, Tspan, n_freqs=nharms_wx, freeze_params=False) elif comp == "DMWaveX": - if "DispersionDM" not in model1.components: - model1.add_component(DispersionDM()) + nharms_dwx = nharms[jj] + if nharms_dwx > 0: + if "DispersionDM" not in model1.components: + model1.add_component(DispersionDM()) - model1["DM"].frozen = False + model1["DM"].frozen = False - if model1["DM1"].quantity is None: - model1["DM1"].quantity = 0 * model1["DM1"].units - model1["DM1"].frozen = False + if model1["DM1"].quantity is None: + model1["DM1"].quantity = 0 * model1["DM1"].units + model1["DM1"].frozen = False - if "DM2" not in model1.params: - model1.components["DispersionDM"].add_param(model["DM1"].new_param(2)) - if model1["DM2"].quantity is None: - model1["DM2"].quantity = 0 * model1["DM2"].units - model1["DM2"].frozen = False + if "DM2" not in model1.params: + model1.components["DispersionDM"].add_param( + model["DM1"].new_param(2) + ) + if model1["DM2"].quantity is None: + model1["DM2"].quantity = 0 * model1["DM2"].units + model1["DM2"].frozen = False - if model1["DMEPOCH"].quantity is None: - model1["DMEPOCH"].quantity = model1["PEPOCH"].quantity + if model1["DMEPOCH"].quantity is None: + model1["DMEPOCH"].quantity = model1["PEPOCH"].quantity - nharms_dwx = nharms[jj] - if nharms_dwx > 0: dmwavex_setup(model1, Tspan, n_freqs=nharms_dwx, freeze_params=False) elif comp == "CMWaveX": - if "ChromaticCM" not in model1.components: - model1.add_component(ChromaticCM()) - model1["TNCHROMIDX"].value = chromatic_index - - model1["CM"].frozen = False - if model1["CM1"].quantity is None: - model1["CM1"].quantity = 0 * model1["CM1"].units - model1["CM1"].frozen = False - - if "CM2" not in model1.params: - model1.components["ChromaticCM"].add_param(model1["CM1"].new_param(2)) - if model1["CM2"].quantity is None: - model1["CM2"].quantity = 0 * model1["CM2"].units - model1["CM2"].frozen = False - - if model1["CMEPOCH"].quantity is None: - model1["CMEPOCH"].quantity = model1["PEPOCH"].quantity - nharms_cwx = nharms[jj] if nharms_cwx > 0: + if "ChromaticCM" not in model1.components: + model1.add_component(ChromaticCM()) + model1["TNCHROMIDX"].value = chromatic_index + + model1["CM"].frozen = False + if model1["CM1"].quantity is None: + model1["CM1"].quantity = 0 * model1["CM1"].units + model1["CM1"].frozen = False + + if "CM2" not in model1.params: + model1.components["ChromaticCM"].add_param( + model1["CM1"].new_param(2) + ) + if model1["CM2"].quantity is None: + model1["CM2"].quantity = 0 * model1["CM2"].units + model1["CM2"].frozen = False + + if model1["CMEPOCH"].quantity is None: + model1["CMEPOCH"].quantity = model1["PEPOCH"].quantity + cmwavex_setup(model1, Tspan, n_freqs=nharms_cwx, freeze_params=False) else: raise ValueError(f"Unsupported component {comp}.") From 6402740c5c0f9537f893a29796431a026167b28d Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Sun, 11 Aug 2024 19:41:33 +0200 Subject: [PATCH 27/31] black --- src/pint/logging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pint/logging.py b/src/pint/logging.py index 980d1ed77..f92fb4bbf 100644 --- a/src/pint/logging.py +++ b/src/pint/logging.py @@ -134,7 +134,7 @@ class LogFilter: Define some messages that are never seen (e.g., Deprecation Warnings). Others that will only be seen once. Filtering of those is done on the basis of regular expressions. """ - + def __init__( self, onlyonce: Optional[List[str]] = None, From 524fc8f2e14510155820359d6ca841599a20beb2 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Sun, 11 Aug 2024 19:56:09 +0200 Subject: [PATCH 28/31] setup.cfg --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 28dfa2974..c0b74e937 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,6 +44,7 @@ install_requires = loguru nestle>=0.2.0 numdifftools + joblib [options.packages.find] where = src From 1dde7d610775f95416e3faa298a922230259b0d5 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Mon, 12 Aug 2024 09:22:20 +0200 Subject: [PATCH 29/31] float --- src/pint/noise_analysis.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/pint/noise_analysis.py b/src/pint/noise_analysis.py index 5957710ce..0b9dbb811 100644 --- a/src/pint/noise_analysis.py +++ b/src/pint/noise_analysis.py @@ -98,13 +98,11 @@ def mlnlike(params: Tuple[float, ...]) -> float: """Negative of the likelihood function that acts on the `[DM/CM]WaveX` amplitudes.""" sigma = powl_model(params) - return 0.5 * float( - np.sum( - (a**2 / (sigma**2 + da**2)) - + (b**2 / (sigma**2 + db**2)) - + np.log(sigma**2 + da**2) - + np.log(sigma**2 + db**2) - ) + return 0.5 * np.sum( + (a**2 / (sigma**2 + da**2)) + + (b**2 / (sigma**2 + db**2)) + + np.log(sigma**2 + da**2) + + np.log(sigma**2 + db**2) ) return mlnlike From 95702e292d55f03d8e07b5afd1da8bfd48f6d34a Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Mon, 12 Aug 2024 10:26:48 +0200 Subject: [PATCH 30/31] wavex_setup --- src/pint/models/wavex.py | 37 +++++++++++++++++-------------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/src/pint/models/wavex.py b/src/pint/models/wavex.py index f3b47a717..f1b184619 100644 --- a/src/pint/models/wavex.py +++ b/src/pint/models/wavex.py @@ -438,41 +438,38 @@ def wavex_setup( """ from pint.models.wavex import WaveX - if (freqs is None) and (n_freqs is None): - raise ValueError( - "WaveX component base frequencies are not specified. " - "Please input either freqs or n_freqs" - ) - - if (freqs is not None) and (n_freqs is not None): - raise ValueError( - "Both freqs and n_freqs are specified. Only one or the other should be used" - ) + assert (freqs is not None) or ( + n_freqs is not None + ), "WaveX component base frequencies are not specified. Please input either freqs or n_freqs" + assert (freqs is None) or ( + n_freqs is None + ), "Both freqs and n_freqs are specified. Only one or the other should be used" - if n_freqs is not None and n_freqs <= 0: - raise ValueError("Must use a non-zero number of wave frequencies") + assert ( + n_freqs is None or n_freqs > 0 + ), "Must use a non-zero number of wave frequencies" model.add_component(WaveX()) - if isinstance(T_span, u.quantity.Quantity): - T_span.to(u.d) - else: + + if not isinstance(T_span, u.quantity.Quantity): T_span *= u.d nyqist_freq = 1.0 / (2.0 * T_span) if freqs is not None: - if isinstance(freqs, u.quantity.Quantity): - freqs.to(u.d**-1) - else: + if not isinstance(freqs, u.quantity.Quantity): freqs *= u.d**-1 + if len(freqs) == 1: model.WXFREQ_0001.quantity = freqs else: - freqs = np.array(freqs) - freqs.sort() if min(np.diff(freqs)) < nyqist_freq: warnings.warn( "Wave frequency spacing is finer than frequency resolution of data" ) + + freqs = np.array(freqs) + freqs.sort() + model.WXFREQ_0001.quantity = freqs[0] model.components["WaveX"].add_wavex_components(freqs[1:]) From 2a5321d8f7a509851d2ece3a8b6f75e7db0de18b Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Mon, 12 Aug 2024 10:45:51 +0200 Subject: [PATCH 31/31] tests --- src/pint/models/cmwavex.py | 35 +++++++++++++++-------------------- src/pint/models/dmwavex.py | 35 +++++++++++++++-------------------- tests/test_cmwavex.py | 30 +++++++++++++++++++++++++++++- tests/test_dmwavex.py | 28 +++++++++++++++++++++++++++- tests/test_wavex.py | 18 +++++++++++++++++- 5 files changed, 103 insertions(+), 43 deletions(-) diff --git a/src/pint/models/cmwavex.py b/src/pint/models/cmwavex.py index 4a686036f..f307e7659 100644 --- a/src/pint/models/cmwavex.py +++ b/src/pint/models/cmwavex.py @@ -428,41 +428,36 @@ def cmwavex_setup( """ from pint.models.cmwavex import CMWaveX - if (freqs is None) and (n_freqs is None): - raise ValueError( - "CMWaveX component base frequencies are not specified. " - "Please input either freqs or n_freqs" - ) - - if (freqs is not None) and (n_freqs is not None): - raise ValueError( - "Both freqs and n_freqs are specified. Only one or the other should be used" - ) + assert (freqs is not None) or ( + n_freqs is not None + ), "CMWaveX component base frequencies are not specified. Please input either freqs or n_freqs" + assert (freqs is None) or ( + n_freqs is None + ), "Both freqs and n_freqs are specified. Only one or the other should be used" - if n_freqs is not None and n_freqs <= 0: - raise ValueError("Must use a non-zero number of wave frequencies") + assert ( + n_freqs is None or n_freqs > 0 + ), "Must use a non-zero number of wave frequencies" model.add_component(CMWaveX()) - if isinstance(T_span, u.quantity.Quantity): - T_span.to(u.d) - else: + if not isinstance(T_span, u.quantity.Quantity): T_span *= u.d nyqist_freq = 1.0 / (2.0 * T_span) if freqs is not None: - if isinstance(freqs, u.quantity.Quantity): - freqs.to(u.d**-1) - else: + if not isinstance(freqs, u.quantity.Quantity): freqs *= u.d**-1 if len(freqs) == 1: model.CMWXFREQ_0001.quantity = freqs else: - freqs = np.array(freqs) - freqs.sort() if min(np.diff(freqs)) < nyqist_freq: warnings.warn( "CMWaveX frequency spacing is finer than frequency resolution of data" ) + + freqs = np.array(freqs) + freqs.sort() + model.CMWXFREQ_0001.quantity = freqs[0] model.components["CMWaveX"].add_cmwavex_components(freqs[1:]) diff --git a/src/pint/models/dmwavex.py b/src/pint/models/dmwavex.py index fc758680e..39fa5b0e6 100644 --- a/src/pint/models/dmwavex.py +++ b/src/pint/models/dmwavex.py @@ -428,41 +428,36 @@ def dmwavex_setup( """ from pint.models.dmwavex import DMWaveX - if (freqs is None) and (n_freqs is None): - raise ValueError( - "DMWaveX component base frequencies are not specified. " - "Please input either freqs or n_freqs" - ) - - if (freqs is not None) and (n_freqs is not None): - raise ValueError( - "Both freqs and n_freqs are specified. Only one or the other should be used" - ) + assert (freqs is not None) or ( + n_freqs is not None + ), "DMWaveX component base frequencies are not specified. Please input either freqs or n_freqs" + assert (freqs is None) or ( + n_freqs is None + ), "Both freqs and n_freqs are specified. Only one or the other should be used" - if n_freqs is not None and n_freqs <= 0: - raise ValueError("Must use a non-zero number of wave frequencies") + assert ( + n_freqs is None or n_freqs > 0 + ), "Must use a non-zero number of wave frequencies" model.add_component(DMWaveX()) - if isinstance(T_span, u.quantity.Quantity): - T_span.to(u.d) - else: + if not isinstance(T_span, u.quantity.Quantity): T_span *= u.d nyqist_freq = 1.0 / (2.0 * T_span) if freqs is not None: - if isinstance(freqs, u.quantity.Quantity): - freqs.to(u.d**-1) - else: + if not isinstance(freqs, u.quantity.Quantity): freqs *= u.d**-1 if len(freqs) == 1: model.DMWXFREQ_0001.quantity = freqs else: - freqs = np.array(freqs) - freqs.sort() if min(np.diff(freqs)) < nyqist_freq: warnings.warn( "DMWaveX frequency spacing is finer than frequency resolution of data" ) + + freqs = np.array(freqs) + freqs.sort() + model.DMWXFREQ_0001.quantity = freqs[0] model.components["DMWaveX"].add_dmwavex_components(freqs[1:]) diff --git a/tests/test_cmwavex.py b/tests/test_cmwavex.py index a3aba5869..57d0420f7 100644 --- a/tests/test_cmwavex.py +++ b/tests/test_cmwavex.py @@ -31,7 +31,7 @@ def test_cmwavex(): """ m = get_model(StringIO(par)) - with pytest.raises(ValueError): + with pytest.raises(AssertionError): idxs = cmwavex_setup(m, 3600) idxs = cmwavex_setup(m, 3600, n_freqs=5) @@ -175,3 +175,31 @@ def test_add_cmwavex(): cmwxcoses=[0, 0], frozens=[False, False, False], ) + + +def test_cmwavex_setup(): + par1 = """ + PSR B1937+21 + LAMBDA 301.9732445337270 + BETA 42.2967523367957 + PMLAMBDA -0.0175 + PMBETA -0.3971 + PX 0.1515 + POSEPOCH 55321.0000 + F0 641.9282333345536244 1 0.0000000000000132 + F1 -4.330899370129D-14 1 2.149749089617D-22 + PEPOCH 55321.000000 + DM 71.016633 + CM 0.1 + TNCHROMIDX 4 + UNITS TDB + """ + + m = get_model(StringIO(par1)) + cmwavex_setup(m, T_span=3 * u.year, n_freqs=1) + + m = get_model(StringIO(par1)) + cmwavex_setup(m, T_span=3 * u.year, n_freqs=10) + + m = get_model(StringIO(par1)) + cmwavex_setup(m, T_span=3 * u.year, freqs=np.linspace(1, 10, 10) / u.year) diff --git a/tests/test_dmwavex.py b/tests/test_dmwavex.py index 0f74429b7..d0b099e35 100644 --- a/tests/test_dmwavex.py +++ b/tests/test_dmwavex.py @@ -30,7 +30,7 @@ def test_dmwavex(): m = get_model(StringIO(par)) - with pytest.raises(ValueError): + with pytest.raises(AssertionError): idxs = dmwavex_setup(m, 3600) idxs = dmwavex_setup(m, 3600, n_freqs=5) @@ -167,3 +167,29 @@ def test_add_dmwavex(): dmwxcoses=[0, 0], frozens=[False, False, False], ) + + +def test_dmwavex_setup(): + par1 = """ + PSR B1937+21 + LAMBDA 301.9732445337270 + BETA 42.2967523367957 + PMLAMBDA -0.0175 + PMBETA -0.3971 + PX 0.1515 + POSEPOCH 55321.0000 + F0 641.9282333345536244 1 0.0000000000000132 + F1 -4.330899370129D-14 1 2.149749089617D-22 + PEPOCH 55321.000000 + DM 71.016633 + UNITS TDB + """ + + m = get_model(StringIO(par1)) + dmwavex_setup(m, T_span=3 * u.year, n_freqs=1) + + m = get_model(StringIO(par1)) + dmwavex_setup(m, T_span=3 * u.year, n_freqs=10) + + m = get_model(StringIO(par1)) + dmwavex_setup(m, T_span=3 * u.year, freqs=np.linspace(1, 10, 10) / u.year) diff --git a/tests/test_wavex.py b/tests/test_wavex.py index c20ab1205..77030e4f5 100644 --- a/tests/test_wavex.py +++ b/tests/test_wavex.py @@ -10,7 +10,12 @@ from pint.fitter import Fitter from pint.residuals import Residuals from pint.simulation import make_fake_toas_uniform -from pint.models.wavex import WaveX, translate_wave_to_wavex, translate_wavex_to_wave +from pint.models.wavex import ( + WaveX, + translate_wave_to_wavex, + translate_wavex_to_wave, + wavex_setup, +) par1 = """ PSR B1937+21 @@ -397,3 +402,14 @@ def test_wave_wavex_roundtrip_conversion(): rs_wavex_to_wave = Residuals(toas, wavex_to_wave_model) assert np.allclose(rs_wave.resids, rs_wave_to_wavex.resids, atol=1e-3) assert np.allclose(rs_wave.resids, rs_wavex_to_wave.resids, atol=1e-3) + + +def test_wavex_setup(): + m = get_model(StringIO(par1)) + wavex_setup(m, T_span=3 * u.year, n_freqs=1) + + m = get_model(StringIO(par1)) + wavex_setup(m, T_span=3 * u.year, n_freqs=10) + + m = get_model(StringIO(par1)) + wavex_setup(m, T_span=3 * u.year, freqs=np.linspace(1, 10, 10) / u.year)