Skip to content

Commit

Permalink
locked, unlock
Browse files Browse the repository at this point in the history
  • Loading branch information
abhisrkckl committed Sep 10, 2023
1 parent 70df575 commit 10e490a
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 8 deletions.
8 changes: 7 additions & 1 deletion src/pint/models/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1974,11 +1974,17 @@ def select_toa_mask(self, toas):

def cache_toa_mask(self, toas):
"""Cache the TOA mask for this parameter. This will lead to exceptions and/or
wrong results if the TOAs object is mutated after this.
wrong results if the TOAs object is mutated after this. This can be undone using
the `clear_toa_mask_cache()` method.
"""
self.__cache_mask = self.select_toa_mask(toas)
self.__cache_toas = toas

def clear_toa_mask_cache(self):
"""Clear the TOA mask cache. This will undo a `cache_toa_mask()` call."""
self.__cache_mask = None
self.__cache_toas = None

def compare_key_value(self, other_param):
"""Compare if the key and value are the same with the other parameter.
Expand Down
28 changes: 23 additions & 5 deletions src/pint/models/timing_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,11 @@ def lock(self, toas):
"undefined behavior if the TOAs object is mutated after this."
)

if (toas.get_errors() == 0).any():
raise ValueError(
"Model cannot be locked since some TOAs have zero uncertainties."
)

self.__locked = True
self.__lock_toas = toas

Expand All @@ -507,6 +512,19 @@ def lock(self, toas):
if mpar.key is not None:
mpar.cache_toa_mask(toas)

@property
def locked(self):
return self.__locked

def unlock(self):
self.__locked = False
self.__lock_toas = None

for mp in self.get_params_of_type_top("maskParameter"):
mpar = getattr(self, mp)
if mpar.key is not None:
mpar.clear_toa_mask_cache()

def __getattr__(self, name):
if name in ["components", "component_types", "search_cmp_attr"]:
raise AttributeError
Expand Down Expand Up @@ -1451,15 +1469,15 @@ def scaled_toa_uncertainty(self, toas):
"""
ntoa = toas.ntoas
tbl = toas.table
result = np.zeros(ntoa) * u.us
result = np.zeros(ntoa)

# When there is no noise model.
if len(self.scaled_toa_uncertainty_funcs) == 0:
result += tbl["error"].quantity
return result
return tbl["error"].quantity.to(u.us)

for nf in self.scaled_toa_uncertainty_funcs:
result += nf(toas)
return result
result += nf(toas).to_value(u.us)
return result << u.us

def scaled_dm_uncertainty(self, toas):
"""Get the scaled DM data uncertainties noise models.
Expand Down
11 changes: 9 additions & 2 deletions tests/test_mask_cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pinttestdata import datadir
from pint.models import get_model_and_toas
from pint.residuals import Residuals
import pytest


def test_mask_cache():
Expand All @@ -12,8 +13,14 @@ def test_mask_cache():
chi2_1 = res1.calc_chi2()

model.lock(toas)

res2 = Residuals(toas, model)
chi2_2 = res2.calc_chi2()

assert chi2_1 == chi2_2

model.unlock()
chi2_3 = res1.calc_chi2()
assert chi2_1 == chi2_3

toas.table["error"][0] = 0
with pytest.raises(ValueError):
model.lock(toas)

0 comments on commit 10e490a

Please sign in to comment.