Skip to content

Commit

Permalink
Cleaning up the test suite
Browse files Browse the repository at this point in the history
  • Loading branch information
poneill129 committed Sep 12, 2023
1 parent bce039c commit 03894b2
Show file tree
Hide file tree
Showing 3 changed files with 359 additions and 428 deletions.
220 changes: 92 additions & 128 deletions src/pint/models/binary_bt.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def remove_range(self, index):
index : float, int, list, np.ndarray
Number or list/array of numbers corresponding to T0X/A1X indices to be removed from model.
"""

if (
isinstance(index, int)
or isinstance(index, float)
Expand All @@ -220,7 +221,17 @@ def remove_range(self, index):
for index in indices:
index_rf = f"{int(index):04d}"
for prefix in ["T0X_", "A1X_", "XR1_", "XR2_"]:
self.remove_param(prefix + index_rf)
if hasattr(self, f"{prefix+index_rf}"):
self.remove_param(prefix + index_rf)
if hasattr(self.binary_instance, "param_pieces"):
if len(self.binary_instance.param_pieces) > 0:
temp_piece_information = []
for item in self.binary_instance.param_pieces:
if item[0] != index_rf:
temp_piece_information.append(item)
self.binary_instance.param_pieces = temp_piece_information
# self.binary_instance.param_pieces = self.binary_instance.param_pieces.remove('index_rf')

self.validate()
self.setup()

Expand All @@ -239,19 +250,31 @@ def add_piecewise_param(self, piece_index, **kwargs):
if key in kwargs:
param = key
paramx = kwargs[key]
if isinstance(paramx, u.quantity.Quantity):
paramx = paramx.value
elif isinstance(paramx, np.float128):
paramx = paramx
else:
raise ValueError(
"Unspported data type '%s'. Ensure the piecewise parameter value is a np.float128 or astropy.quantity.Quantity"
% type(paramx)
)

if key == "T0":
param_unit = u.d
if isinstance(paramx, u.quantity.Quantity):
paramx = paramx.value
elif isinstance(paramx, np.float128):
paramx = paramx
elif isinstance(paramx, Time):
paramx = paramx.mjd
else:
raise ValueError(
"Unspported data type '%s' for piecewise T0. Ensure the piecewise parameter value is a np.float128, Time or astropy.quantity.Quantity"
% type(paramx)
)
elif key == "A1":
param_unit = ls
if isinstance(paramx, u.quantity.Quantity):
paramx = paramx.value
elif isinstance(paramx, np.float64):
paramx = paramx
else:
raise ValueError(
"Unspported data type '%s' for piecewise A1. Ensure the piecewise parameter value is a np.float64 or astropy.quantity.Quantity"
% type(paramx)
)
key_found = True

if not key_found:
Expand Down Expand Up @@ -280,7 +303,7 @@ def add_piecewise_param(self, piece_index, **kwargs):
param_unit = (getattr(self, param)).units
if paramx is None:
paramx = (getattr(self, param)).value
# check if name exits and is currently available
# check if name exists and is currently available

self.add_param(
prefixParameter(
Expand All @@ -294,11 +317,6 @@ def add_piecewise_param(self, piece_index, **kwargs):
)
self.setup()

def lock_groups(self):
self.validate()
self.update_binary_object(None)
self.setup()

def setup(self):
"""Raises
------
Expand Down Expand Up @@ -422,16 +440,18 @@ def validate(self):
raise ValueError(
f"Group index mismatch error. Check the indexes of XR1_/XR2_ parameters in the model"
)
if len(ls_pub) > 0 and len(ls_T0X) > 0:
if len(np.setdiff1d(j_pub, j_T0X)) > 0:
raise ValueError(
f"Group index mismatch error. Check the indexes of T0X groups, make sure they match there are corresponding group ranges (XR1/XR2)"
)
if len(ls_pub) > 0 and len(ls_A1X) > 0:
if len(np.setdiff1d(j_pub, j_A1X)) > 0:
raise ValueError(
f"Group index mismatch error. Check the indexes of A1X groups, make sure they match there are corresponding group ranges (/XR2)"
)
if not len(ls_A1X) > 0:
if len(ls_pub) > 0 and len(ls_T0X) > 0:
if len(np.setdiff1d(j_pub, j_T0X)) > 0:
raise ValueError(
f"Group index mismatch error. Check the indexes of T0X groups, make sure they match there are corresponding group ranges (XR1/XR2)"
)
if not len(ls_T0X) > 0:
if len(ls_pub) > 0 and len(ls_A1X) > 0:
if len(np.setdiff1d(j_pub, j_A1X)) > 0:
raise ValueError(
f"Group index mismatch error. Check the indexes of A1X groups, make sure they match there are corresponding group ranges (/XR2)"
)
lb = [(getattr(self, tup[1])).value for tup in ls_plb]
ub = [(getattr(self, tup[1])).value for tup in ls_pub]

Expand All @@ -444,6 +464,17 @@ def validate(self):
)

def paramx_per_toa(self, param_name, toas):
"""Find the piecewise parameter value each toa will reference during calculations
Parameters
----------
param_name : string
which piecewise parameter to show: 'A1'/'T0'. TODO this should raise an error if not present)
toa : pint.toa.TOA
Returns
-------
u.quantity.Quantity
length(toa) elements are T0X or A1X values to reference for each toa during binary calculations.
"""
condition = {}
tbl = toas.table
XR1_mapping = self.get_prefix_mapping_component("XR1_")
Expand All @@ -455,13 +486,12 @@ def paramx_per_toa(self, param_name, toas):
param_unit = u.d
elif param_name[0:2] == "A1":
paramX_mapping = self.get_prefix_mapping_component("A1X_")
param_unit = u.ls
param_unit = ls
else:
raise AttributeError(
"param '%s' not found. Please choose another. Currently implemented: 'T0' or 'A1' "
% param_name
)

for piece_index in paramX_mapping.keys():
r1 = getattr(self, XR1_mapping[piece_index]).quantity
r2 = getattr(self, XR2_mapping[piece_index]).quantity
Expand All @@ -470,109 +500,43 @@ def paramx_per_toa(self, param_name, toas):
paramx = np.zeros(len(tbl)) * param_unit
for k, v in select_idx.items():
paramx[v] += getattr(self, k).quantity
for i in range(len(paramx)):
if paramx[i] == 0:
paramx[i] = (getattr(self, param_name[0:2])).value * param_unit

return paramx

def get_number_of_groups(self):
"""Get the number of piecewise parameters"""
return len(self.binary_instance.piecewise_parameter_information)

# def get_group_boundaries(self):
# """Get a all pieces' date boundaries.
# Returns
# -------
# list
# np.array
# (length: toas) List of piecewise orbital parameter lower boundaries
# np.array
# (length: toas) List of piecewise orbital parameter upper boundaries
# """
# # asks the object for the number of piecewise groups
# return self.binary_instance.get_group_boundaries()

# def which_group_is_toa_in(self, toa):
# """Find the group a toa belongs to based on the boundaries of groups passed to BT_piecewise
# Parameters
# ----------
# toa : toa
# TOA/TOAs to check which group they're in
# Returns
# -------
# np.array
# str elements, look like ['0000','0001'] for two TOAs where one refences T0X_0000 or T0X_0001.
# """
# return self.binary_instance.toa_belongs_in_group(toa)

# def get_group_indexes(self):
# """Get all the piecewise parameter labels
# Returns
# -------
# np.array
# (length: number of piecewise groups) List of piecewise parameter labels e.g with pieces T0X_0000, T0X_0001, T0X_0003, returns [0,1,3]
# """
# group_indexes = []
# for i in range(0, len(self.binary_instance.piecewise_parameter_information)):
# group_indexes.append(
# self.binary_instance.piecewise_parameter_information[i][0]
# )
# return group_indexes

# def get_T0Xs_associated_with_toas(self, toas):
# """Get a of all the piecewise T0s associated with TOAs
# Parameters
# ----------
# toas :
# Barycentric TOAs
# Returns
# -------
# np.array
# (length: toas) List of piecewise T0X values being used for each TOA
# """
# if hasattr(self.binary_instance, "group_index_array"):
# temporary_storage = self.binary_instance.group_index_array
# self.binary_instance.group_index_array = self.which_group_is_toa_in(toas)
# barycentric_toa = self._parent.get_barycentric_toas(toas)
# T0X_per_toa = self.binary_instance.piecewise_parameter_from_information_array(
# toas
# )[0]
# if temporary_storage is not None:
# self.binary_instance.group_index_array = temporary_storage
# return T0X_per_toa

# def get_A1Xs_associated_with_toas(self, toas):
# """Get a of all the piecewise A1s associated with TOAs
# Parameters
# ----------
# toas :
# Barycentric TOAs
# Returns
# -------
# np.array
# (length: toas) List of piecewise A1X values being used for each TOA
# """
# if hasattr(self.binary_instance, "group_index_array"):
# temporary_storage = self.binary_instance.group_index_array
# self.binary_instance.group_index_array = self.which_group_is_toa_in(toas)
# barycentric_toa = self._parent.get_barycentric_toas(toas)
# A1X_per_toa = self.binary_instance.piecewise_parameter_from_information_array(
# toas
# )[1]
# if temporary_storage is not None:
# self.binary_instance.group_index_array = temporary_storage
# return A1X_per_toa

# def does_toa_reference_piecewise_parameter(self, toas, param):
# """Query whether a TOA/list of TOAs belong(s) to a specific group
# Parameters
# ----------
# toas :
# Barycentric TOAs
# param : str
# Orbital piecewise parameter alias e.g. "T0X_0001" or "A1X_0001"
# Returns
# -------
# np.array
# boolean array (length: toas). True where toa is within piece boundaries corresponding to param
# """
# self.binary_instance.group_index_array = self.which_group_is_toa_in(toas)
# from_in_piece = self.binary_instance.in_piece(param)
# return from_in_piece[0]
def which_group_is_toa_in(self, toa):
"""Find the group a toa belongs to based on the boundaries of groups passed to BT_piecewise
Parameters
----------
Returns
-------
list
str elements, look like ['0000','0001'] for two TOAs where one refences T0X/A1X.
"""
# if isinstance(toa, pint.toa.TOAs):
# pass
# else:
# raise TypeError(f'toa must be a Time or pint.toa.TOAs - not {type(toa)}')

tbl = toa.table
condition = {}
XR1_mapping = self.get_prefix_mapping_component("XR1_")
XR2_mapping = self.get_prefix_mapping_component("XR2_")
if not hasattr(self, "toas_selector"):
self.toas_selector = TOASelect(is_range=True)
boundaries = {}
for piece_index in XR1_mapping.keys():
r1 = getattr(self, XR1_mapping[piece_index]).quantity
r2 = getattr(self, XR2_mapping[piece_index]).quantity
condition[(XR1_mapping[piece_index]).split("_")[-1]] = (r1.mjd, r2.mjd)
select_idx = self.toas_selector.get_select_index(condition, tbl["mjd_float"])
paramx = np.empty(len(tbl), dtype="<U4")
for k, v in select_idx.items():
paramx[v] = k
return paramx.tolist()
Loading

0 comments on commit 03894b2

Please sign in to comment.