diff --git a/src/pint/models/__init__.py b/src/pint/models/__init__.py index f5a46d3b9..e08e69965 100644 --- a/src/pint/models/__init__.py +++ b/src/pint/models/__init__.py @@ -20,7 +20,7 @@ # Import all standard model components here from pint.models.astrometry import AstrometryEcliptic, AstrometryEquatorial -from pint.models.binary_bt import BinaryBT +from pint.models.binary_bt import BinaryBT, BinaryBTPiecewise from pint.models.binary_dd import BinaryDD, BinaryDDS, BinaryDDGR from pint.models.binary_ddk import BinaryDDK from pint.models.binary_ell1 import BinaryELL1, BinaryELL1H, BinaryELL1k diff --git a/src/pint/models/binary_bt.py b/src/pint/models/binary_bt.py index 2797ca5cc..6a40ae9f4 100644 --- a/src/pint/models/binary_bt.py +++ b/src/pint/models/binary_bt.py @@ -1,9 +1,22 @@ """The BT (Blandford & Teukolsky) model.""" - +import numpy as np from pint.models.parameter import floatParameter from pint.models.pulsar_binary import PulsarBinary from pint.models.stand_alone_psr_binaries.BT_model import BTmodel -from pint.models.timing_model import MissingParameter +from pint.models.stand_alone_psr_binaries.BT_piecewise import BTpiecewise +from pint.models.timing_model import MissingParameter, TimingModel +import astropy.units as u +from pint import GMsun, Tsun, ls +from astropy.table import Table +from astropy.time import Time +from pint.models.parameter import ( + MJDParameter, + floatParameter, + prefixParameter, + maskParameter, +) + +from pint.toa_select import TOASelect class BinaryBT(PulsarBinary): @@ -66,3 +79,464 @@ def validate(self): if self.GAMMA.value is None: self.GAMMA.value = "0" self.GAMMA.frozen = True + + +"""The BT (Blandford & Teukolsky) model with piecewise orbital parameters. +See Blandford & Teukolsky 1976, ApJ, 205, 580. +""" + + +class BinaryBTPiecewise(PulsarBinary): + """Model implementing the BT model with piecewise orbital parameters A1X and T0X. This model lets the user specify time ranges and fit for a different piecewise orbital parameter in each time range, + This is a PINT pulsar binary BTPiecewise model class, a subclass of PulsarBinary. + It is a wrapper for stand alone BTPiecewise class defined in + ./stand_alone_psr_binary/BT_piecewise.py + The aim for this class is to connect the stand alone binary model with the PINT platform. + BTpiecewise special parameters, where xxxx denotes the 4-digit index of the piece: + T0X_xxxx Piecewise T0 values for piece + A1X_xxxx Piecewise A1 values for piece + XR1_xxxx Lower time boundary of piece + XR2_xxxx Upper time boundary of piece + """ + + register = True + + def __init__(self): + super(BinaryBTPiecewise, self).__init__() + self.binary_model_name = "BT_piecewise" + self.binary_model_class = BTpiecewise + self.add_param( + floatParameter( + name="GAMMA", + value=0.0, + units="second", + description="Time dilation & gravitational redshift", + ) + ) + self.A1_value_funcs = [] + self.T0_value_funcs = [] + self.remove_param("M2") + self.remove_param("SINI") + self.add_group_range(None, None) + self.add_piecewise_param(0, T0=0 * u.d) + self.add_piecewise_param(0, A1=0 * ls) + + def add_group_range( + self, + group_start_mjd, + group_end_mjd, + piece_index=None, + ): + """Add an orbital piecewise parameter group range. If piece_index is not provided a new piece will be added with index equal to the number of pieces plus one. Pieces cannot have the duplicate pieces and cannot have the same index. A pair of consisting of a piecewise A1 and T0 may share an index and will act over the same piece range. + Parameters + ---------- + group_start_mjd : float or astropy.quantity.Quantity or astropy.time.Time + MJD for the piece lower boundary + group_end_mjd : float or astropy.quantity.Quantity or astropy.time.Time + MJD for the piece upper boundary + piece_index : int + Number to label the piece being added. + """ + if group_start_mjd is not None and group_end_mjd is not None: + if isinstance(group_start_mjd, Time): + group_start_mjd = group_start_mjd.mjd + elif isinstance(group_start_mjd, u.quantity.Quantity): + group_start_mjd = group_start_mjd.value + if isinstance(group_end_mjd, Time): + group_end_mjd = group_end_mjd.mjd + elif isinstance(group_end_mjd, u.quantity.Quantity): + group_end_mjd = group_end_mjd.value + + elif group_start_mjd is None or group_end_mjd is None: + if group_start_mjd is None and group_end_mjd is not None: + group_start_mjd = group_end_mjd - 100 + elif group_start_mjd is not None and group_end_mjd is None: + group_end_mjd = group_start_mjd + 100 + else: + group_start_mjd = 50000 + group_end_mjd = 60000 + + if piece_index is None: + dct = self.get_prefix_mapping_component("XR1_") + if len(list(dct.keys())) > 0: + piece_index = np.max(list(dct.keys())) + 1 + else: + piece_index = 0 + + # check the validity of the desired group to add + + if group_end_mjd is not None and group_start_mjd is not None: + if group_end_mjd <= group_start_mjd: + raise ValueError("Starting MJD is greater than ending MJD.") + elif piece_index < 0: + raise ValueError( + f"Invalid index for group: {piece_index} should be greater than or equal to 0" + ) + elif piece_index > 9999: + raise ValueError( + f"Invalid index for group. Cannot index beyond 9999 (yet?)" + ) + + i = f"{int(piece_index):04d}" + self.add_param( + prefixParameter( + name="XR1_{0}".format(i), + units="MJD", + description="Beginning of paramX interval", + parameter_type="MJD", + time_scale="utc", + value=group_start_mjd, + ) + ) + self.add_param( + prefixParameter( + name="XR2_{0}".format(i), + units="MJD", + description="End of paramX interval", + parameter_type="MJD", + time_scale="utc", + value=group_end_mjd, + ) + ) + self.setup() + + def remove_range(self, index): + """Removes all orbital piecewise parameters associated with a given index/list of indices. + Parameters + ---------- + index : float, int, list, np.ndarray + Number or list/array of numbers corresponding to T0X/A1X indices to be removed from model. + """ + + if ( + isinstance(index, int) + or isinstance(index, float) + or isinstance(index, np.int64) + ): + indices = [index] + elif not isinstance(index, list) or not isinstance(index, np.ndarray): + raise TypeError( + f"index must be a float, int, list, or array - not {type(index)}" + ) + for index in indices: + index_rf = f"{int(index):04d}" + for prefix in ["T0X_", "A1X_", "XR1_", "XR2_"]: + if hasattr(self, f"{prefix+index_rf}"): + self.remove_param(prefix + index_rf) + if hasattr(self.binary_instance, "param_pieces"): + if len(self.binary_instance.param_pieces) > 0: + temp_piece_information = [] + for item in self.binary_instance.param_pieces: + if item[0] != index_rf: + temp_piece_information.append(item) + self.binary_instance.param_pieces = temp_piece_information + # self.binary_instance.param_pieces = self.binary_instance.param_pieces.remove('index_rf') + + self.validate() + self.setup() + + def add_piecewise_param(self, piece_index, **kwargs): + """Add an orbital piecewise parameter. + Parameters + ---------- + piece_index : int + Number to label the piece being added. Expected to match a set of piece boundaries. + param : str + Piecewise parameter label e.g. "T0" or "A1". + paramx : np.float128 or astropy.quantity.Quantity + Piecewise parameter value. + """ + for key in ("T0", "A1"): + if key in kwargs: + param = key + paramx = kwargs[key] + + if key == "T0": + param_unit = u.d + if isinstance(paramx, u.quantity.Quantity): + paramx = paramx.value + elif isinstance(paramx, np.float128): + paramx = paramx + elif isinstance(paramx, Time): + paramx = paramx.mjd + else: + raise ValueError( + "Unspported data type '%s' for piecewise T0. Ensure the piecewise parameter value is a np.float128, Time or astropy.quantity.Quantity" + % type(paramx) + ) + elif key == "A1": + param_unit = ls + if isinstance(paramx, u.quantity.Quantity): + paramx = paramx.value + elif isinstance(paramx, np.float64): + paramx = paramx + else: + raise ValueError( + "Unspported data type '%s' for piecewise A1. Ensure the piecewise parameter value is a np.float64 or astropy.quantity.Quantity" + % type(paramx) + ) + key_found = True + + if not key_found: + raise AttributeError( + "No piecewise parameters passed. Use T0 = / A1 = to declare a piecewise variable." + ) + + if piece_index is None: + dct = self.get_prefix_mapping_component(param + "X_") + if len(list(dct.keys())) > 0: + piece_index = np.max(list(dct.keys())) + 1 + else: + piece_index = 0 + elif int(piece_index) in self.get_prefix_mapping_component(param + "X_"): + raise ValueError( + "Index '%s' is already in use in this model. Please choose another." + % piece_index + ) + i = f"{int(piece_index):04d}" + + # handling if None are passed as arguments + if any(i is None for i in [param, param_unit, paramx]): + if param is not None: + # if parameter value or unit unset, set with default according to param + if param_unit is None: + param_unit = (getattr(self, param)).units + if paramx is None: + paramx = (getattr(self, param)).value + # check if name exists and is currently available + + self.add_param( + prefixParameter( + name=param + f"X_{i}", + units=param_unit, + value=paramx, + description="Parameter" + param + "variation", + parameter_type="float", + frozen=False, + ) + ) + self.setup() + + def setup(self): + """Raises + ------ + ValueError + if there are values that have been added without name/ranges associated (should only be raised if add_piecewise_param has been side-stepped with an alternate method) + """ + super().setup() + for bpar in self.params: + self.register_deriv_funcs(self.d_binary_delay_d_xxxx, bpar) + # Setup the model isinstance + self.binary_instance = self.binary_model_class() + # piecewise T0's + T0X_mapping = self.get_prefix_mapping_component("T0X_") + T0Xs = {} + # piecewise A1's (doing piecewise A1's requires more thought and work) + A1X_mapping = self.get_prefix_mapping_component("A1X_") + A1Xs = {} + # piecewise parameter ranges XR1-piece lower bound + XR1_mapping = self.get_prefix_mapping_component("XR1_") + XR1s = {} + # piecewise parameter ranges XR2-piece upper bound + XR2_mapping = self.get_prefix_mapping_component("XR2_") + XR2s = {} + + for index in XR1_mapping.values(): + index = index.split("_")[1] + piece_index = f"{int(index):04d}" + if hasattr(self, f"T0X_{piece_index}"): + if getattr(self, f"T0X_{piece_index}") is not None: + self.binary_instance.add_binary_params( + f"T0X_{piece_index}", getattr(self, f"T0X_{piece_index}") + ) + else: + self.binary_instance.add_binary_params( + f"T0X_{piece_index}", self.T0.value + ) + + if hasattr(self, f"A1X_{piece_index}"): + if hasattr(self, f"A1X_{piece_index}"): + if getattr(self, f"A1X_{piece_index}") is not None: + self.binary_instance.add_binary_params( + f"A1X_{piece_index}", getattr(self, f"A1X_{piece_index}") + ) + else: + self.binary_instance.add_binary_params( + f"A1X_{piece_index}", self.A1.value + ) + + if hasattr(self, f"XR1_{piece_index}"): + if getattr(self, f"XR1_{piece_index}") is not None: + self.binary_instance.add_binary_params( + f"XR1_{piece_index}", getattr(self, f"XR1_{piece_index}") + ) + else: + raise ValueError(f"No date provided to create a group with") + else: + raise ValueError(f"No name provided to create a group with") + + if hasattr(self, f"XR2_{piece_index}"): + if getattr(self, f"XR2_{piece_index}") is not None: + self.binary_instance.add_binary_params( + f"XR2_{piece_index}", getattr(self, f"XR2_{piece_index}") + ) + else: + raise ValueError(f"No date provided to create a group with") + else: + raise ValueError(f"No name provided to create a group with") + + self.update_binary_object(None) + + def validate(self): + """Include catches for overlapping groups. etc + Raises + ------ + ValueError + if there are pieces with no associated boundaries (T0X_0000 does not have a corresponding XR1_0000/XR2_0000) + ValueError + if any boundaries overlap (as it makes TOA assignment to a single group ambiguous). i.e. XR1_0000XR1_0001 + ValueError + if the number of lower and upper bounds don't match (should only be raised if XR1 is defined without XR2 and validate is run or vice versa) + """ + super().validate() + for p in ("T0", "A1"): + if getattr(self, p).value is None: + raise MissingParameter("BT", p, "%s is required for BT" % p) + + # If any *DOT is set, we need T0 + for p in ("PBDOT", "OMDOT", "EDOT", "A1DOT"): + if getattr(self, p).value is None: + getattr(self, p).set("0") + getattr(self, p).frozen = True + + if getattr(self, p).value is not None: + if self.T0.value is None: + raise MissingParameter("BT", "T0", "T0 is required if *DOT is set") + + if self.GAMMA.value is None: + self.GAMMA.set("0") + self.GAMMA.frozen = True + + dct_plb = self.get_prefix_mapping_component("XR1_") + dct_pub = self.get_prefix_mapping_component("XR2_") + dct_T0X = self.get_prefix_mapping_component("T0X_") + dct_A1X = self.get_prefix_mapping_component("A1X_") + if len(dct_plb) > 0 and len(dct_pub) > 0: + ls_plb = list(dct_plb.items()) + ls_pub = list(dct_pub.items()) + ls_T0X = list(dct_T0X.items()) + ls_A1X = list(dct_A1X.items()) + + j_plb = [((tup[1]).split("_"))[1] for tup in ls_plb] + j_pub = [((tup[1]).split("_"))[1] for tup in ls_pub] + j_T0X = [((tup[1]).split("_"))[1] for tup in ls_T0X] + j_A1X = [((tup[1]).split("_"))[1] for tup in ls_A1X] + + if j_plb != j_pub: + raise ValueError( + f"Group boundary mismatch error. Number of detected lower bounds: {j_plb}. Number of detected upper bounds: {j_pub}" + ) + if len(np.setdiff1d(j_plb, j_pub)) > 0: + raise ValueError( + f"Group index mismatch error. Check the indexes of XR1_/XR2_ parameters in the model" + ) + if not len(ls_A1X) > 0: + if len(ls_pub) > 0 and len(ls_T0X) > 0: + if len(np.setdiff1d(j_pub, j_T0X)) > 0: + raise ValueError( + f"Group index mismatch error. Check the indexes of T0X groups, make sure they match there are corresponding group ranges (XR1/XR2)" + ) + if not len(ls_T0X) > 0: + if len(ls_pub) > 0 and len(ls_A1X) > 0: + if len(np.setdiff1d(j_pub, j_A1X)) > 0: + raise ValueError( + f"Group index mismatch error. Check the indexes of A1X groups, make sure they match there are corresponding group ranges (/XR2)" + ) + lb = [(getattr(self, tup[1])).value for tup in ls_plb] + ub = [(getattr(self, tup[1])).value for tup in ls_pub] + + for i in range(len(lb)): + for j in range(len(lb)): + if i != j: + if max(lb[i], lb[j]) < min(ub[i], ub[j]): + raise ValueError( + f"Group boundary overlap detected. Make sure groups are not overlapping" + ) + + def paramx_per_toa(self, param_name, toas): + """Find the piecewise parameter value each toa will reference during calculations + Parameters + ---------- + param_name : string + which piecewise parameter to show: 'A1'/'T0'. TODO this should raise an error if not present) + toa : pint.toa.TOA + Returns + ------- + u.quantity.Quantity + length(toa) elements are T0X or A1X values to reference for each toa during binary calculations. + """ + condition = {} + tbl = toas.table + XR1_mapping = self.get_prefix_mapping_component("XR1_") + XR2_mapping = self.get_prefix_mapping_component("XR2_") + if not hasattr(self, "toas_selector"): + self.toas_selector = TOASelect(is_range=True) + if param_name[0:2] == "T0": + paramX_mapping = self.get_prefix_mapping_component("T0X_") + param_unit = u.d + elif param_name[0:2] == "A1": + paramX_mapping = self.get_prefix_mapping_component("A1X_") + param_unit = ls + else: + raise AttributeError( + "param '%s' not found. Please choose another. Currently implemented: 'T0' or 'A1' " + % param_name + ) + for piece_index in paramX_mapping.keys(): + r1 = getattr(self, XR1_mapping[piece_index]).quantity + r2 = getattr(self, XR2_mapping[piece_index]).quantity + condition[paramX_mapping[piece_index]] = (r1.mjd, r2.mjd) + select_idx = self.toas_selector.get_select_index(condition, tbl["mjd_float"]) + paramx = np.zeros(len(tbl)) * param_unit + for k, v in select_idx.items(): + paramx[v] += getattr(self, k).quantity + for i in range(len(paramx)): + if paramx[i] == 0: + paramx[i] = (getattr(self, param_name[0:2])).value * param_unit + + return paramx + + def get_number_of_groups(self): + """Get the number of piecewise parameters""" + return len(self.binary_instance.piecewise_parameter_information) + + def which_group_is_toa_in(self, toa): + """Find the group a toa belongs to based on the boundaries of groups passed to BT_piecewise + Parameters + ---------- + Returns + ------- + list + str elements, look like ['0000','0001'] for two TOAs where one refences T0X/A1X. + """ + # if isinstance(toa, pint.toa.TOAs): + # pass + # else: + # raise TypeError(f'toa must be a Time or pint.toa.TOAs - not {type(toa)}') + + tbl = toa.table + condition = {} + XR1_mapping = self.get_prefix_mapping_component("XR1_") + XR2_mapping = self.get_prefix_mapping_component("XR2_") + if not hasattr(self, "toas_selector"): + self.toas_selector = TOASelect(is_range=True) + boundaries = {} + for piece_index in XR1_mapping.keys(): + r1 = getattr(self, XR1_mapping[piece_index]).quantity + r2 = getattr(self, XR2_mapping[piece_index]).quantity + condition[(XR1_mapping[piece_index]).split("_")[-1]] = (r1.mjd, r2.mjd) + select_idx = self.toas_selector.get_select_index(condition, tbl["mjd_float"]) + paramx = np.empty(len(tbl), dtype=">import astropy.units as u + >>import numpy as np + + >>binary_model=BTpiecewise() + >>param_dict = {'T0': 50000, 'ECC': 0.2} + >>binary_model.update_input(**param_dict) + + >>t=np.linspace(50001.,60000.,10)*u.d + + Adding binary parameters and piece ranges + >>binary_model.add_binary_params('T0X_0000', 60000*u.d) + >>binary_model.add_binary_params('XR1_0000', 50000*u.d) + >>binary_model.add_binary_params('XR2_0000', 55000*u.d) + + Can add more pieces here... + + Overide default values values if desired + >>updates = {'T0X_0000':60000.*u.d,'XR1_0000':50000.*u.d,'XR2_0000': 55000*u.d} + + update the model with the piecewise parameter value(s) and piece ranges + >>binary_model.update_input(**updates) + + Using pint's get_model and loading this as a timing model and following the method described in ../binary_piecewise.py + sets _t multiple times during pint's residual calculation + for simplicity we're just going to set _t directly though this is not recommended. + >>setattr(binary_model,'_t' ,t) + + #here we call get_tt0 to get the "loaded toas" to interact with the pieces passed to the model earlier + #sets the attribute "T0X_per_toa" and/or "A1X_per_toa", contains the piecewise parameter value that will be referenced + #for each toa future calculations + >>binary_model.get_tt0(t) + #For a piecewise T0, tt0 becomes a piecewise quantity, otherwise it is how it functions in BT_model.py. + + #get_tt0 sets the attribute "T0X_per_toa" and/or "A1X_per_toa". + #contains the piecewise parameter value that will be referenced for each toa future calculations + >>binary_model.T0X_per_toa + + Information about any group can be found with the following: + >>binary_model.piecewise_parameter_information + Order: [[Group index, Piecewise T0, Piecewise A1, Piece lower bound, Piece upper bound]] + + Making sure a binary_model.tt0 exists + >>binary_model._tt0 = binary_model.get_tt0(binary_model._t) + + Obtain piecewise BTdelay() + >>binary_model.BTdelay() + """ + + def __init__(self, axis_store_initial=None, t=None, input_params=None): + self.binary_name = "BT_piecewise" + super(BTpiecewise, self).__init__() + if t is None: + self._t = None + self.axis_store_initial = [] + self.extended_group_range = [] + self.param_pieces = [] + self.d_binarydelay_d_par_funcs = [self.d_BTdelay_d_par] + if t is not None: + self._t = t + if input_params is not None: + if self.T0X is None: + self.update_input(input_params) + self.binary_params = list(self.param_default_value.keys()) + + def set_param_values(self, valDict=None): + super().set_param_values(valDict=valDict) + self.setup_internal_structures(valDict=valDict) + + def setup_internal_structures(self, valDict=None): + # initialise arrays to store T0X/A1X values per toa + self.T0X_arr = [] + self.A1X_arr = [] + # initialise arrays to store piecewise group boundaries + self.lower_group_edge = [] + self.upper_group_edge = [] + # initialise array that will be 5 x n in shape. Where n is the number of pieces required by the model + piecewise_parameter_information = [] + # If there are no updates passed by binary_instance, sets default value (usually overwritten when reading from parfile) + + if valDict is None: + self.T0X_arr = [self.T0] + self.A1X_arr = [self.A1] + self.lower_group_edge = [0] + self.upper_group_edge = [1e9] + self.piecewise_parameter_information = [ + 0, + self.T0, + self.A1, + 0 * u.d, + 1e9 * u.d, + ] + else: + # initialise array used to count the number of pieces. Operates by seaching for "A1X_i/T0X_i" and appending i to the array. + piece_index = [] + # Searches through updates for keys prefixes matching T0X/A1X, can be allowed to be more flexible with param+"X_" provided param is defined earlier. + for key, value in valDict.items(): + if ( + key[0:4] == "T0X_" + or key[0:4] == "A1X_" + and not (key[4:8] in piece_index) + ): + # appends index to array + piece_index.append((key[4:8])) + # makes sure only one instance of each index is present returns order indeces + piece_index = np.unique(piece_index) + # looping through each index in order they are given (0 -> n) + for index in piece_index: + # array to store specific piece i's information in the order [index,T0X,A1X,Group's lower edge, Group's upper edge,] + param_pieces = [] + piece_number = f"{int(index):04d}" + param_pieces.append(piece_number) + string = [ + "T0X_" + index, + "A1X_" + index, + "XR1_" + index, + "XR2_" + index, + ] + + # if string[0] not in param_pieces: + for i in range(0, len(string)): + if string[i] in valDict: + param_pieces.append(valDict[string[i]]) + elif string[i] not in valDict: + attr = string[i][0:2] + + if hasattr(self, attr): + param_pieces.append(getattr(self, attr)) + else: + raise AttributeError( + "Malformed valDict being used, attempting to set an attribute that doesn't exist. Likely a corner case slipping through validate() in binary_piecewise." + ) + # Raises error if range not defined as there is no Piece upper/lower bound in the model. + + piecewise_parameter_information.append(param_pieces) + + self.valDict = valDict + # sorts the array chronologically by lower edge of each group,correctly works for unordered pieces + + self.piecewise_parameter_information = sorted( + piecewise_parameter_information, key=lambda x: x[3] + ) + + # Uses the index for each toa array to create arrays where elements are the A1X/T0X to use with that toa + if len(self.piecewise_parameter_information) > 0: + if self._t is not None: + self.group_index_array = self.toa_belongs_in_group(self._t) + + ( + self.T0X_per_toa, + self.A1X_per_toa, + ) = self.piecewise_parameter_from_information_array(self._t) + + def piecewise_parameter_from_information_array(self, t): + """Creates a list of piecewise orbital parameters to use in calculations. It is the same dimensions as the TOAs loaded in. Each entry is the piecewise parameter value from the group it belongs to. + ---------- + t : Quantity. TOA, not necesserily barycentered + Returns + ------- + list + Quantity (length: t). T0X parameter to use for each TOA in calculations. + Quantity (length: t). A1X parameter to use for each TOA in calculations. + """ + A1X_per_toa = [] + T0X_per_toa = [] + if not hasattr(self, "group_index_array"): + self.group_index_array = self.toa_belongs_in_group(t) + if len(self.group_index_array) != len(t): + self.group_index_array = self.toa_belongs_in_group(t) + # searches the 5 x n array to find the index matching the toa_index + possible_groups = [item[0] for item in self.piecewise_parameter_information] + if len(self.group_index_array) > 1 and len(t) > 1: + for i in self.group_index_array: + if i != -1: + for k, j in enumerate(possible_groups): + if str(i) == j: + group_index = k + T0X_per_toa.append( + self.piecewise_parameter_information[group_index][ + 1 + ].value + ) + + A1X_per_toa.append( + self.piecewise_parameter_information[group_index][ + 2 + ].value + ) + + # if a toa lies between 2 groups, use default T0/A1 values (i.e. toa lies after previous upper bound but before next lower bound) + else: + T0X_per_toa.append(self.T0.value) + A1X_per_toa.append(self.A1.value) + + else: + T0X_per_toa = self.T0.value + A1X_per_toa = self.A1.value + + T0X_per_toa = T0X_per_toa * u.d + A1X_per_toa = A1X_per_toa * ls + + return [T0X_per_toa, A1X_per_toa] + + def toa_belongs_in_group(self, toas): + """Get the piece a TOA belongs to by finding which checking upper/lower edges of each piece. + ---------- + toas : Astropy.quantity.Quantity. + Returns + ------- + list + int (length: t). Group numbers + """ + group_no = [] + gb = self.get_group_boundaries() + + lower_edge = [] + upper_edge = [] + for i in range(len(gb[0])): + lower_edge.append(gb[0][i].value) + upper_edge.append(gb[1][i].value) + + # lower_edge, upper_edge = [self.get_group_boundaries()[:].value],[self.get_group_boundaries()[1].value] + for i in toas.value: + lower_bound = np.searchsorted(np.array(lower_edge), i) - 1 + upper_bound = np.searchsorted(np.array(upper_edge), i) + if lower_bound == upper_bound: + index_no = lower_bound + else: + index_no = -1 + if index_no != -1: + group_no.append(self.piecewise_parameter_information[index_no][0]) + else: + group_no.append(index_no) + return group_no + + def get_group_boundaries(self): + """Get the piecewise group boundaries from the dictionary of piecewise parameter information. + Returns + ------- + list + list (length: number of pieces). Contains all pieces' lower edge. + list (length: number of pieces). Contains all pieces' upper edge. + """ + lower_group_edge = [] + upper_group_edge = [] + if hasattr(self, "piecewise_parameter_information"): + for i in range(0, len(self.piecewise_parameter_information)): + lower_group_edge.append(self.piecewise_parameter_information[i][3]) + upper_group_edge.append(self.piecewise_parameter_information[i][4]) + return [lower_group_edge, upper_group_edge] + + def a1(self): + if len(self.piecewise_parameter_information) > 0: + # defines index for each toa as an array of length = len(self._t) + # Uses the index for each toa array to create arrays where elements are the A1X/T0X to use with that toa + self.A1X_per_toa = self.piecewise_parameter_from_information_array(self.t)[ + 1 + ] + + if hasattr(self, "A1X_per_toa"): + ret = self.A1X_per_toa + self.tt0 * self.A1DOT + else: + ret = self.A1 + self.tt0 * self.A1DOT + return ret + + def get_tt0(self, barycentricTOA): + """Finds (barycentricTOA - T0_x). Where T0_x is the piecewise T0 value, if it exists, correponding to the group the TOA belongs to. If T0_x does not exist, use the global T0 vlaue. + ---------- + Returns + ------- + astropy.quantity.Quantity + time since T0 + """ + if barycentricTOA is None or self.T0 is None: + return None + if len(barycentricTOA) > 1: + # defines index for each toa as an array of length = len(self._t) + # Uses the index for each toa array to create arrays where elements are the A1X/T0X to use with that toa + self.T0X_per_toa = self.piecewise_parameter_from_information_array( + barycentricTOA + )[0] + T0 = self.T0X_per_toa + else: + T0 = self.T0 + if not hasattr(barycentricTOA, "unit") or barycentricTOA.unit == None: + barycentricTOA = barycentricTOA * u.day + tt0 = (barycentricTOA - T0).to("second") + return tt0 + + def d_delayL1_d_par(self, par): + if par not in self.binary_params: + raise ValueError(f"{par} is not in binary parameter list.") + par_obj = getattr(self, par) + index, par_temp = self.in_piece(par) + if par_temp is None: + if hasattr(self, "d_delayL1_d_" + par): + func = getattr(self, "d_delayL1_d_" + par) + return func() * index + else: + if par in self.orbits_cls.orbit_params: + return self.d_delayL1_d_E() * self.d_E_d_par(par) + else: + return np.zeros(len(self.t)) * u.second / par_obj.unit + else: + if hasattr(self, "d_delayL1_d_" + par_temp): + func = getattr(self, "d_delayL1_d_" + par_temp) + return func() * index + else: + if par in self.orbits_cls.orbit_params: + return self.d_delayL1_d_E() * self.d_E_d_par() + else: + return np.zeros(len(self.t)) * u.second / par_obj.unit + + def d_delayL2_d_par(self, par): + if par not in self.binary_params: + raise ValueError(f"{par} is not in binary parameter list.") + par_obj = getattr(self, par) + index, par_temp = self.in_piece(par) + if par_temp is None: + if hasattr(self, "d_delayL2_d_" + par): + func = getattr(self, "d_delayL2_d_" + par) + return func() * index + else: + if par in self.orbits_cls.orbit_params: + return self.d_delayL2_d_E() * self.d_E_d_par(par) + else: + return np.zeros(len(self.t)) * u.second / par_obj.unit + else: + if hasattr(self, "d_delayL2_d_" + par_temp): + func = getattr(self, "d_delayL2_d_" + par_temp) + return func() * index + else: + if par in self.orbits_cls.orbit_params: + return self.d_delayL2_d_E() * self.d_E_d_par() + else: + return np.zeros(len(self.t)) * u.second / par_obj.unit + + def in_piece(self, par): + """Finds which TOAs reference which piecewise binary parameter group using the group_index_array property. + ---------- + par : str + Name of piecewise parameter e.g. 'T0X_0001' or 'A1X_0001' + Returns + ------- + list + boolean list (length: self._t). True where TOA references a given group, False otherwise. + binary piecewise parameter label str. e.g. 'T0X' or 'A1X'. + """ + if "_" in par: + text = par.split("_") + param = text[0] + toa_index = f"{int(text[1]):04d}" + else: + param = par + if hasattr(self, "group_index_array"): + # group_index_array should exist before fitting, constructing the model/residuals should add this(?) + group_indexes = np.array(self.group_index_array) + if param == "T0X": + ret = group_indexes == toa_index + return [ret, "T0X"] + elif param == "A1X": + ret = group_indexes == toa_index + return [ret, "A1X"] + # The toa_index = -1 corresponds to TOAs that don't reference any groups + else: + ret = group_indexes == -1 + return [ret, None] + #'None' corresponds to a parameter without a piecewise counterpart, so will effect all TOAs + else: + return [np.zeros(len(self._t)) + 1, None] + + def d_BTdelay_d_par(self, par): + return self.delayR() * (self.d_delayL2_d_par(par) + self.d_delayL1_d_par(par)) + + def d_delayL1_d_A1X(self): + return np.sin(self.omega()) * (np.cos(self.E()) - self.ecc()) / c.c + + def d_delayL2_d_A1X(self): + return ( + np.cos(self.omega()) * np.sqrt(1 - self.ecc() ** 2) * np.sin(self.E()) / c.c + ) + + def d_delayL1_d_T0X(self): + return self.d_delayL1_d_E() * self.d_E_d_T0X() + + def d_delayL2_d_T0X(self): + return self.d_delayL2_d_E() * self.d_E_d_T0X() + + def d_E_d_T0X(self): + """Analytic derivative + d(E-e*sinE)/dT0 = dM/dT0 + dE/dT0(1-cosE*e)-de/dT0*sinE = dM/dT0 + dE/dT0(1-cosE*e)+eDot*sinE = dM/dT0 + """ + RHS = self.prtl_der("M", "T0") + E = self.E() + EDOT = self.EDOT + ecc = self.ecc() + with u.set_enabled_equivalencies(u.dimensionless_angles()): + return (RHS - EDOT * np.sin(E)) / (1.0 - np.cos(E) * ecc) + + def prtl_der(self, y, x): + """Find the partial derivatives in binary model pdy/pdx + Parameters + ---------- + y : str + Name of variable to be differentiated + x : str + Name of variable the derivative respect to + Returns + ------- + np.array + The derivatives pdy/pdx + """ + if y not in self.binary_params + self.inter_vars: + errorMesg = y + " is not in binary parameter and variables list." + raise ValueError(errorMesg) + + if x not in self.inter_vars + self.binary_params: + errorMesg = x + " is not in binary parameters and variables list." + raise ValueError(errorMesg) + # derivative to itself + if x == y: + return np.longdouble(np.ones(len(self.tt0))) * u.Unit("") + # Get the unit right + yAttr = getattr(self, y) + xAttr = getattr(self, x) + U = [None, None] + for i, attr in enumerate([yAttr, xAttr]): + # If attr is a PINT Parameter class type + if hasattr(attr, "units"): + U[i] = attr.units + # If attr is a Quantity type + elif hasattr(attr, "unit"): + U[i] = attr.unit + # If attr is a method + elif hasattr(attr, "__call__"): + U[i] = attr().unit + else: + raise TypeError(type(attr) + "can not get unit") + yU = U[0] + xU = U[1] + # Call derivtive functions + derU = yU / xU + if hasattr(self, "d_" + y + "_d_" + x): + dername = "d_" + y + "_d_" + x + result = getattr(self, dername)() + elif hasattr(self, "d_" + y + "_d_par"): + dername = "d_" + y + "_d_par" + result = getattr(self, dername)(x) + else: + result = np.longdouble(np.zeros(len(self.tt0))) + if hasattr(result, "unit"): + return result.to(derU, equivalencies=u.dimensionless_angles()) + else: + return result * derU + + def d_M_d_par(self, par): + """derivative for M respect to bianry parameter. + Parameters + ---------- + par : string + parameter name + Returns + ------- + Derivitve of M respect to par + """ + if par not in self.binary_params: + errorMesg = par + " is not in binary parameter list." + raise ValueError(errorMesg) + par_obj = getattr(self, par) + result = self.orbits_cls.d_orbits_d_par(par) + with u.set_enabled_equivalencies(u.dimensionless_angles()): + result = result.to(u.Unit("") / par_obj.unit) + return result diff --git a/tests/test_BT_piecewise.py b/tests/test_BT_piecewise.py new file mode 100644 index 000000000..daf75143f --- /dev/null +++ b/tests/test_BT_piecewise.py @@ -0,0 +1,695 @@ +from pint.models import get_model +import pint.toa +import numpy as np +import pint.fitter +import astropy.units as u +from pint import ls +from copy import deepcopy +import pint.residuals +from astropy.time import Time +import pint.models.stand_alone_psr_binaries.BT_piecewise as BTpiecewise +import matplotlib.pyplot as plt +import unittest +from io import StringIO +from pylab import * +import pytest +from pint.simulation import make_fake_toas_uniform +import pint.logging + +pint.logging.setup(level="ERROR") + + +@pytest.fixture +def model_no_pieces(scope="session"): + # builds a J1023+0038-like model with no pieces + par_base = """ + PSR 1023+0038 + TRACK -3 + EPHEM DE421 + CLOCK TT(BIPM2019) + START 55000. + FINISH 55200. + DILATEFREQ N + RAJ 10:23:47.68719801 + DECJ 0:38:40.84551000 + POSEPOCH 54995. + F0 592. + F1 -2. + PEPOCH 55000. + PLANET_SHAPIRO N + DM 14. + BINARY BT_piecewise + PB 0.2 + PBDOT 0.0 + A1 0.34333468063634737 1 + A1DOT 0.0 + ECC 0.0 + EDOT 0.0 + T0 55000. + TZRMJD 55000. + TZRSITE 1 + """ + model = get_model(StringIO(par_base)) + # lurking bug: T0X_0000/A1X_0000 and boundaries are not automatically deleted on intialisation + model.remove_range(0) + return model + + +@pytest.fixture +def model_BT(): + # builds a J1023+0038-like model with no pieces + par_base = """ + PSR 1023+0038 + TRACK -3 + EPHEM DE421 + CLOCK TT(BIPM2019) + START 55000. + FINISH 55200. + DILATEFREQ N + RAJ 10:23:47.68719801 + DECJ 0:38:40.84551000 + POSEPOCH 54995. + F0 592. + F1 -2. + PEPOCH 55000. + PLANET_SHAPIRO N + DM 14. + BINARY BT + PB 0.2 + PBDOT 0.0 + A1 0.34333468063634737 1 + A1DOT 0.0 + ECC 0.0 + EDOT 0.0 + T0 55000. + TZRMJD 55000. + TZRSITE 1 + """ + model = get_model(StringIO(par_base)) + return model + + +@pytest.fixture() +def build_piecewise_model_with_one_A1_piece(model_no_pieces): + # takes the basic model frame and adds 2 non-ovelerlapping pieces to it + piecewise_model = deepcopy(model_no_pieces) + lower_bound = [55000] + upper_bound = [55100] + piecewise_model.add_group_range(lower_bound[0], upper_bound[0], piece_index=0) + piecewise_model.add_piecewise_param( + A1=piecewise_model.A1.value + 1.0e-3, piece_index=0 + ) + return piecewise_model + + +@pytest.fixture() +def build_piecewise_model_with_one_T0_piece(model_no_pieces): + # takes the basic model frame and adds 2 non-ovelerlapping pieces to it + piecewise_model = deepcopy(model_no_pieces) + lower_bound = [55000] + upper_bound = [55100] + piecewise_model.add_group_range(lower_bound[0], upper_bound[0], piece_index=0) + piecewise_model.add_piecewise_param( + T0=piecewise_model.T0.value + 1.0e-5, piece_index=0 + ) + return piecewise_model + + +# fine function +@pytest.fixture() +def build_piecewise_model_with_two_pieces(model_no_pieces): + # takes the basic model frame and adds 2 non-ovelerlapping pieces to it + piecewise_model = model_no_pieces + lower_bound = [55000, 55100.000000001] + upper_bound = [55100, 55200] + for i in range(len(lower_bound)): + piecewise_model.add_group_range(lower_bound[i], upper_bound[i], piece_index=i) + piecewise_model.add_piecewise_param( + A1=(piecewise_model.A1.value + (i + 1) * 1e-3) * ls, piece_index=i + ) + piecewise_model.add_piecewise_param( + T0=(piecewise_model.T0.value + (i + 1) * 1e-3) * u.d, piece_index=i + ) + return piecewise_model + + +# fine function +@pytest.fixture() +def make_toas_to_go_with_two_piece_model(build_piecewise_model_with_two_pieces): + # makes toas to go with the two non-overlapping, complete coverage model + m_piecewise = build_piecewise_model_with_two_pieces + lower_bound = [55000, 55100.00001] + upper_bound = [55100, 55200] + toas = make_fake_toas_uniform( + lower_bound[0] + 1, upper_bound[1] - 1, 20, m_piecewise + ) # slightly within group edges to make toas unambiguously contained within groups + return toas + + +# fine function +def add_full_coverage_and_non_overlapping_groups_and_make_toas( + model_no_pieces, build_piecewise_model_with_two_pieces +): + # function to build the models for specific edge cases i.e. distinct groups where all toas fit exactly within a group + model = build_piecewise_model_with_two_pieces + toas = make_generic_toas(model, 55001, 55199) + return model, toas + + +# fine function +def add_partial_coverage_groups_and_make_toas(build_piecewise_model_with_two_pieces): + # function to build the models for specific edge cases i.e. if all toas don't fit exactly within any groups + model3 = build_piecewise_model_with_two_pieces + # known bug: if A1X exists it needs a partner T0X otherwise it breaks can freeze T0X for the time being, just needs a little thought + model3.remove_range(0) + # make sure TOAs are within ranges + toas = make_generic_toas(model3, 55001, 55199) + return model3, toas + + +# fine function +def make_generic_toas(model, lower_bound, upper_bound): + # makes toas to go with the edge cases + return make_fake_toas_uniform(lower_bound, upper_bound, 20, model) + + +def add_offset_in_model_parameter(indexes, param, model): + m_piecewise_temp = deepcopy(model) + parameter_string = f"{param}_{int(indexes):04d}" + if hasattr(m_piecewise_temp, parameter_string): + delta = getattr(m_piecewise_temp, parameter_string).value + 1e-5 + getattr(m_piecewise_temp, parameter_string).value = delta + m_piecewise_temp.setup() + else: + parameter_string = param[0:2] + getattr(m_piecewise_temp, parameter_string).value = ( + getattr(m_piecewise_temp, parameter_string).value + 1e-5 + ) + m_piecewise_temp.setup() + return m_piecewise_temp + + +def add_relative_offset_for_derivatives(index, param, model, offset_size, plus=True): + m_piecewise_temp = deepcopy(model) + parameter_string = f"{param}_{int(index):04d}" + offset_size = offset_size.value + if plus is True: + if hasattr(m_piecewise_temp, parameter_string): + delta = getattr(m_piecewise_temp, parameter_string).value + offset_size + getattr(m_piecewise_temp, parameter_string).value = delta + else: + if hasattr(m_piecewise_temp, parameter_string): + delta = getattr(m_piecewise_temp, parameter_string).value - offset_size + getattr(m_piecewise_temp, parameter_string).value = delta + return m_piecewise_temp + + +# fine function +def test_round_trips_to_parfile(model_no_pieces): + # test: see if the model can be reproduced after piecewise parameters have been added, + # checks by comparing parameter keys in both the old and new file. Should have the number of matches = number of parameters + m_piecewise = model_no_pieces + n = 10 + lower_bounds = [ + 55050, + 55101, + 55151, + 55201, + 55251, + 55301, + 55351, + 55401, + 55451, + 55501, + ] + upper_bounds = [ + 55100, + 55150, + 55200, + 55250, + 55300, + 55350, + 55400, + 55450, + 55500, + 55550, + ] + for i in range(0, n): + m_piecewise.add_group_range(lower_bounds[i], upper_bounds[i], piece_index=i) + m_piecewise.add_piecewise_param( + A1=(m_piecewise.A1.value + i) * ls, piece_index=i + ) + m_piecewise.add_piecewise_param( + T0=(m_piecewise.T0.value + i) * u.d, piece_index=i + ) + m3 = get_model(StringIO(m_piecewise.as_parfile())) + param_dict = m_piecewise.get_params_dict(which="all") + copy_param_dict = m3.get_params_dict(which="all") + number_of_keys = 0 + n_keys_identified = 0 + n_values_preserved = 0 + comparison = 0 + for key, value in param_dict.items(): + number_of_keys = number_of_keys + 1 # iterates up to total number of keys + for copy_key, copy_value in copy_param_dict.items(): + if key == copy_key: # search both pars for identical keys + n_keys_identified = n_keys_identified + 1 + if type(value) == type(copy_value): + if value.value == copy_value.value: + n_values_preserved = n_values_preserved + 1 + assert n_keys_identified == number_of_keys + assert n_values_preserved == number_of_keys + + +# fine function +def test_get_number_of_groups(build_piecewise_model_with_two_pieces): + # test to make sure number of groups matches with number of added piecewise parameters + m_piecewise = build_piecewise_model_with_two_pieces + number_of_groups = m_piecewise.get_number_of_groups() + assert number_of_groups == 2 + + +# fine function +def test_group_assignment_toas_unambiguously_within_group( + build_piecewise_model_with_two_pieces, make_toas_to_go_with_two_piece_model +): + # test to see if the group, for one toa per group, that the BT_piecewise.print_toas_in_group functions as intended. + # operates by sorting the toas by MJD compared against a groups upper/lower edge. + # operates with np.searchsorted so for 1 toa per group, each toa should be uniquely indexed after/before the lower/upper edge + model = build_piecewise_model_with_two_pieces + index = model.which_group_is_toa_in(make_toas_to_go_with_two_piece_model) + should_be_ten_toas_in_each_group = [ + np.unique(index, return_counts=True)[1][0], + np.unique(index, return_counts=True)[1][1], + ] + expected_toas_in_each_group = [10, 10] + is_there_ten_toas_per_group = np.testing.assert_array_equal( + should_be_ten_toas_in_each_group, expected_toas_in_each_group + ) + np.testing.assert_array_equal( + should_be_ten_toas_in_each_group, expected_toas_in_each_group + ) + + +# fine function +@pytest.mark.parametrize("param", ["A1X", "T0X"]) +def test_paramX_per_toa_matches_corresponding_model_value( + param, build_piecewise_model_with_two_pieces, make_toas_to_go_with_two_piece_model +): + # Testing the correct piecewise parameters are being assigned to each toa. + # Operates on the piecewise_parameter_from_information_array function. Requires group_index fn to be run so we have an array of length(ntoas), filled with information on which group a toa belongs to. + # Uses this array to apply T0X_i/A1X_i to corresponding indexes from group_index fn call. i.e. for T0X_i,T0X_j,T0X_k values and group_index return: [i,j,k] the output would be [T0X_i,T0X_j,T0X_k] + m_piecewise = build_piecewise_model_with_two_pieces + toa = make_toas_to_go_with_two_piece_model + expected_piece_1 = np.full(int(len(toa)), True) + expected_piece_1[int(len(toa) / 2) :] = False + + expected_piece_2 = np.full(int(len(toa)), True) + expected_piece_2[: int(len(toa) / 2)] = False + + should_toa_reference_piecewise_parameter = [expected_piece_1, expected_piece_2] + if param == "A1X": + paramX_per_toa = m_piecewise.paramx_per_toa("A1", toa) + test_val = [m_piecewise.A1X_0000.value, m_piecewise.A1X_0001.value] + + elif param == "T0X": + paramX_per_toa = m_piecewise.paramx_per_toa("T0", toa) + test_val = [m_piecewise.T0X_0000.value, m_piecewise.T0X_0001.value] + + do_toas_reference_first_piecewise_parameter = np.isclose( + (paramX_per_toa.value - test_val[0]), 0, atol=1e-6, rtol=0 + ) + + do_toas_reference_second_piecewise_parameter = np.isclose( + (paramX_per_toa.value - test_val[1]), 0, atol=1e-6, rtol=0 + ) + + do_toas_reference_piecewise_parameter = [ + do_toas_reference_first_piecewise_parameter, + do_toas_reference_second_piecewise_parameter, + ] + + np.testing.assert_array_equal( + do_toas_reference_piecewise_parameter, should_toa_reference_piecewise_parameter + ) + + +# fine function +def test_problematic_group_indexes_and_ranges(model_no_pieces): + # Test to flag issues with problematic group indexes + # Could fold this with the next test for a mega-check exceptions are raised test + m_piecewise = model_no_pieces + with pytest.raises(ValueError): + m_piecewise.add_group_range( + m_piecewise.START.value, m_piecewise.FINISH.value, piece_index=-1 + ) + with pytest.raises(ValueError): + m_piecewise.add_group_range( + m_piecewise.FINISH.value, m_piecewise.START.value, piece_index=1 + ) + with pytest.raises(ValueError): + m_piecewise.add_group_range( + m_piecewise.START.value, m_piecewise.FINISH.value, piece_index=1 + ) + m_piecewise.add_group_range( + m_piecewise.START.value, m_piecewise.FINISH.value, piece_index=1 + ) + with pytest.raises(ValueError): + m_piecewise.add_group_range( + m_piecewise.START.value, m_piecewise.FINISH.value, piece_index=10000 + ) + + +def test_group_index_matching(model_no_pieces): + m_piecewise = model_no_pieces + with pytest.raises(ValueError): + # should flag mismatching A1 group and boundary indexes + m_piecewise.add_group_range( + m_piecewise.START.value, m_piecewise.FINISH.value, piece_index=1 + ) + m_piecewise.add_piecewise_param(A1=m_piecewise.A1.value * ls, piece_index=2) + # Errors raised in validate, which is run when groups are "locked in" + m_piecewise.setup() + m_piecewise.validate() + with pytest.raises(ValueError): + # should flag mismatching T0 group and boundary indexes + m_piecewise.add_group_range( + m_piecewise.START.value, m_piecewise.FINISH.value, piece_index=1 + ) + m_piecewise.add_piecewise_param(T0=m_piecewise.T0.value * u.d, piece_index=2) + # Errors raised in validate, which is run when groups are "locked in" + m_piecewise.setup() + m_piecewise.validate() + with pytest.raises(ValueError): + # check whether boundaries are overlapping + m_piecewise.add_group_range(55000, 55200, piece_index=1) + m_piecewise.add_piecewise_param(A1=m_piecewise.A1.value * ls, piece_index=1) + m_piecewise.add_piecewise_param(T0=m_piecewise.T0.value * u.d, piece_index=1) + + m_piecewise.add_group_range(55100, 55300, piece_index=2) + m_piecewise.add_piecewise_param(A1=m_piecewise.A1.value * ls, piece_index=2) + m_piecewise.add_piecewise_param(T0=m_piecewise.T0.value * u.d, piece_index=2) + + m_piecewise.setup() + m_piecewise.validate() + with pytest.raises(ValueError): + # check whether boundaries are equal + m_piecewise.add_group_range(55000, 55000, piece_index=1) + m_piecewise.add_piecewise_param(A1=m_piecewise.A1.value * u.d, piece_index=1) + m_piecewise.add_piecewise_param(T0=m_piecewise.T0.value * u.d, piece_index=1) + m_piecewise.setup() + m_piecewise.validate() + + +@pytest.mark.parametrize( + "param, index", [("T0X", 0), ("T0X", 1), ("A1X", 0), ("A1X", 1)] +) +def test_residuals_in_groups_respond_to_changes_in_corresponding_piecewise_parameter( + model_no_pieces, build_piecewise_model_with_two_pieces, param, index +): + m_piecewise, toa = add_full_coverage_and_non_overlapping_groups_and_make_toas( + model_no_pieces, build_piecewise_model_with_two_pieces + ) + rs_value = pint.residuals.Residuals( + toa, m_piecewise, subtract_mean=False + ).resids_value + param_string = f"{param}_{int(index):04d}" + m_piecewise_temp = add_offset_in_model_parameter(index, param, m_piecewise) + if param == "A1X": + paramX_per_toa = m_piecewise.paramx_per_toa("A1", toa) + if param == "T0X": + paramX_per_toa = m_piecewise.paramx_per_toa("T0", toa) + + test_val = [getattr(m_piecewise, param_string).value] + + rs_temp = pint.residuals.Residuals( + toa, m_piecewise_temp, subtract_mean=False + ).resids_value + have_residuals_changed = rs_temp != rs_value + + are_toas_referencing_paramX = np.isclose( + (paramX_per_toa.value - test_val[0]), 0, atol=1e-6, rtol=0 + ) + + should_residuals_change = are_toas_referencing_paramX + + np.testing.assert_array_equal(have_residuals_changed, should_residuals_change) + + +@pytest.mark.parametrize( + "param, index", [("T0X", 0), ("T0X", 1), ("A1X", 0), ("A1X", 1)] +) +def test_d_delay_in_groups_respond_to_changes_in_corresponding_piecewise_parameter( + param, + index, + model_no_pieces, + build_piecewise_model_with_two_pieces, +): + m_piecewise, toa = add_full_coverage_and_non_overlapping_groups_and_make_toas( + model_no_pieces, build_piecewise_model_with_two_pieces + ) + # m_piecewise_temp = add_offset_in_model_parameter(index, param, m_piecewise) + + param_string = f"{param}_{int(index):04d}" + m_piecewise_temp = add_offset_in_model_parameter(index, param_string, m_piecewise) + if param == "A1X": + paramX_per_toa = m_piecewise_temp.paramx_per_toa("A1", toa) + + if param == "T0X": + paramX_per_toa = m_piecewise_temp.paramx_per_toa("T0", toa) + test_val = [getattr(m_piecewise, param_string).value] + are_toas_referencing_paramX = np.isclose( + (paramX_per_toa.value - test_val[0]), 0, atol=1e-6, rtol=0 + ) + + is_d_delay_changing = np.invert( + np.isclose( + m_piecewise_temp.d_binary_delay_d_xxxx(toa, param_string, None).value, + 0, + atol=1e-11, + rtol=0, + ) + ) + should_d_delay_be_changing = are_toas_referencing_paramX + # assert toas that are in the group have some non-zero delay derivative + np.testing.assert_array_equal(is_d_delay_changing, should_d_delay_be_changing) + + +@pytest.mark.parametrize("param, index", [("T0X", 0), ("A1X", 0)]) +def test_derivatives_in_pieces_are_same_as_BT_piecewise_paramx( + param, + index, + model_no_pieces, + model_BT, + build_piecewise_model_with_one_T0_piece, + build_piecewise_model_with_one_A1_piece, +): + if param == "A1X": + m_piecewise = build_piecewise_model_with_one_A1_piece + elif param == "T0X": + m_piecewise = build_piecewise_model_with_one_T0_piece + + m_non_piecewise = model_BT + toas = make_generic_toas(m_non_piecewise, 55001, 55199) + + param_string = f"{param}_{int(index):04d}" + param_q = getattr(m_non_piecewise, param[0:2]) + setattr(param_q, "value", getattr(m_piecewise, param_string).value) + + piecewise_delays = m_piecewise.d_binary_delay_d_xxxx( + toas, param_string, acc_delay=None + ) + non_piecewise_delays = m_non_piecewise.d_binary_delay_d_xxxx( + toas, param[0:2], acc_delay=None + ) + # gets which toas that should be changing + if param == "A1X": + paramX_per_toa = m_piecewise.paramx_per_toa("A1", toas) + + if param == "T0X": + paramX_per_toa = m_piecewise.paramx_per_toa("T0", toas) + test_val = [getattr(m_piecewise, param_string).value] + are_toas_referencing_paramX = np.isclose( + (paramX_per_toa.value - test_val[0]), 0, atol=1e-6, rtol=0 + ) + where_delays_should_change = are_toas_referencing_paramX + # checks the derivatives wrt T0X is the same as the derivative calc'd in the BT model for T0=T0X, for TOAs within that group + np.testing.assert_array_equal( + piecewise_delays[where_delays_should_change], + non_piecewise_delays[where_delays_should_change], + ) + # checks the derivatives wrt T0X are 0 for toas outside of the group + np.testing.assert_array_equal( + piecewise_delays[~where_delays_should_change], + np.zeros(len(piecewise_delays[~where_delays_should_change])), + ) + + +# This test is a bit of a mess, attempting to manipulate multiple models without breaking anything (i.e. model_1 and model_2 should not be affected by changes made to the other) +def test_interacting_with_multiple_models(model_no_pieces): + m_piecewise_1 = deepcopy(model_no_pieces) + m_piecewise_2 = deepcopy(model_no_pieces) + lower_bound = [55000, 55100.00001] + upper_bound = [55100, 55200] + # just check by creating the models and adding pieces we aren't adding things to the other model + m_piecewise_1.add_group_range(lower_bound[0], upper_bound[0], piece_index=0) + m_piecewise_1.add_piecewise_param(T0=m_piecewise_1.T0.value + 1.0e-3, piece_index=0) + m_piecewise_1.add_piecewise_param(A1=m_piecewise_1.A1.value + 1.0e-3, piece_index=0) + m_piecewise_1.setup() + m_piecewise_1.validate() + m_piecewise_2.add_group_range(lower_bound[1], upper_bound[1], piece_index=0) + m_piecewise_2.add_piecewise_param(T0=m_piecewise_2.T0.value + 3.0e-3, piece_index=0) + m_piecewise_2.add_piecewise_param(A1=m_piecewise_2.A1.value + 3.0e-3, piece_index=0) + m_piecewise_2.setup() + m_piecewise_2.validate() + # not yet interlacing function calls, just some extra sanity checks when it comes to loading more than one model that are yet untested + np.testing.assert_allclose(m_piecewise_1.XR1_0000.value, lower_bound[0]) + np.testing.assert_allclose(m_piecewise_1.XR2_0000.value, upper_bound[0]) + np.testing.assert_allclose(m_piecewise_2.XR1_0000.value, lower_bound[1]) + np.testing.assert_allclose(m_piecewise_2.XR2_0000.value, upper_bound[1]) + + np.testing.assert_allclose( + m_piecewise_1.T0X_0000.value, m_piecewise_1.T0.value + 1.0e-3 + ) + np.testing.assert_allclose( + m_piecewise_1.A1X_0000.value, m_piecewise_1.A1.value + 1.0e-3 + ) + np.testing.assert_allclose( + m_piecewise_2.T0X_0000.value, m_piecewise_2.T0.value + 3.0e-3 + ) + np.testing.assert_allclose( + m_piecewise_2.A1X_0000.value, m_piecewise_2.A1.value + 3.0e-3 + ) + + # just some arithmetic tests to see if they respond to changes for now. + # Need to find a way of listing which parameters are updated during common function calls. + # e.g. creating residuals, why do things like tt0 change to len(1) during the calculation? + # This is just designed to try and confuse the reader (its following a set of stable calculations and checking they match at intervals) + param_string_T0X = "T0X_0000" + param_string_A1X = "A1X_0000" + param_m1_T0 = getattr(m_piecewise_1, param_string_T0X) + param_m1_A1 = getattr(m_piecewise_1, param_string_A1X) + param_m2_T0 = getattr(m_piecewise_2, param_string_T0X) + param_m2_A1 = getattr(m_piecewise_2, param_string_A1X) + + setattr(param_m1_T0, "value", getattr(m_piecewise_2, param_string_T0X).value) + setattr(param_m1_A1, "value", getattr(m_piecewise_2, param_string_A1X).value) + setattr(param_m2_T0, "value", getattr(m_piecewise_2, "T0").value) + setattr(param_m2_A1, "value", getattr(m_piecewise_2, "A1").value) + + np.testing.assert_allclose( + m_piecewise_1.T0X_0000.value, m_piecewise_1.T0.value + 3.0e-3 + ) + np.testing.assert_allclose( + m_piecewise_1.A1X_0000.value, m_piecewise_1.A1.value + 3.0e-3 + ) + np.testing.assert_allclose(m_piecewise_2.T0X_0000.value, m_piecewise_2.T0.value) + np.testing.assert_allclose(m_piecewise_2.A1X_0000.value, m_piecewise_2.A1.value) + + setattr( + param_m1_T0, "value", getattr(m_piecewise_2, param_string_T0X).value + 6.0e-3 + ) + setattr( + param_m1_A1, "value", getattr(m_piecewise_2, param_string_A1X).value + 6.0e-3 + ) + setattr(param_m2_T0, "value", getattr(m_piecewise_2, "T0").value + 3.0e-3) + setattr(param_m2_A1, "value", getattr(m_piecewise_2, "A1").value + 3.0e-3) + + np.testing.assert_allclose( + m_piecewise_1.T0X_0000.value, m_piecewise_1.T0.value + 6.0e-3 + ) + np.testing.assert_allclose( + m_piecewise_1.A1X_0000.value, m_piecewise_1.A1.value + 6.0e-3 + ) + np.testing.assert_allclose( + m_piecewise_2.T0X_0000.value, m_piecewise_2.T0.value + 3.0e-3 + ) + np.testing.assert_allclose( + m_piecewise_2.A1X_0000.value, m_piecewise_2.A1.value + 3.0e-3 + ) + + # can add more suggested tests in here + + +# --Place here for future tests-- +# Wants to check the residuals within the group are the same as those from the model used to generate them which has T0 = T0X_0000 (the only group) +# In the residual calculation there is a mean subtraction/round off occurring in changing the binary parameter/something else. +# This means: test_residuals_in_pieces_are_same_as_BT_piecewise_*, would not have exactly equal residuals. +# In fact the TOAS that are expected to replicate the BT_model residuals are systematically delayed by an amount that is larger than the noise of uniform TOAs. +# Looks like the "noise" about this extra delay matches the noise of the uniform TOAs generated by the non-piecewise model +# Suggests: Don't leave TOAs in undeclared groups, try to cover the whole date range so a TOA lies in a group until explored further + + +# --WIP Tests-- +# This is a wip test to evaluate the residuals using TOAs generated from a non-piecewise model using the BT and BT piecewise model +# The test should pass if the residuals within a piece match those residuals produced by a BT model with the same parameter value as declared within the piece +# i.e. Use BT model to create TOAs with flat residuals -> adjust BT model parameter to match the param value declared within a piece of the piecewise model -> get the residuals of the piecewise and BT model for TOAs that exist within the piece, they should match. +# *Should* work but seems to be unable to produce fake TOAs when run through CI tests +# @pytest.mark.parametrize("param, index", [("T0X", 0), ("A1X", 0)]) +# def test_residuals_in_pieces_are_same_as_BT_piecewise_paramx( +# param, +# index, +# model_no_pieces, +# model_BT, +# build_piecewise_model_with_one_T0_piece, +# build_piecewise_model_with_one_A1_piece, +# ): +# if param == "A1X": +# m_piecewise = build_piecewise_model_with_one_A1_piece +# elif param == "T0X": +# m_piecewise = build_piecewise_model_with_one_T0_piece +# m_non_piecewise = model_BT + +# param_string = f"{param}_{int(index):04d}" +# param_q = getattr(m_non_piecewise, param[0:2]) +# setattr(param_q, "value", getattr(m_piecewise, param_string).value) +# toas = make_generic_toas(m_non_piecewise, 55001, 55099) +# rs_piecewise = pint.residuals.Residuals( +# toas, m_piecewise, subtract_mean=True, use_weighted_mean=False +# ).time_resids +# rs_non_piecewise = pint.residuals.Residuals( +# toas, m_non_piecewise, subtract_mean=True, use_weighted_mean=False +# ).time_resids +# np.testing.assert_allclose(rs_piecewise, rs_non_piecewise) +# +# +# This is a wip test to evaluate the TOA group allocation in the absence of "full-group coverage" (i.e. includes data that exists outside of pieces). +# i.e. Use (either BT/piecewise) model to create TOAs with flat residuals -> check the parameter value the TOAs reference during delay calculations (this should equal the global parameter value when there is no piecewise parameter to reference) +# *Should* work but seems to be unable to produce fake TOAs when run through CI tests +# @pytest.mark.parametrize("param", ["A1X", "T0X"]) +# def test_does_toa_lie_in_group_incomplete_group_coverage( +# param, model_no_pieces, build_piecewise_model_with_two_pieces +# ): +# m_piecewise, toa = add_partial_coverage_groups_and_make_toas(model_no_pieces) +# +# expected_out_piece = np.full(int(len(toa)), True) +# expected_out_piece[int(len(toa) / 2) :] = False +# +# expected_in_piece = np.full(int(len(toa)), True) +# expected_in_piece[: int(len(toa) / 2)] = False +# +# should_toa_reference_piecewise_parameter = [expected_in_piece, expected_out_piece] +# if param == "A1X": +# paramX_per_toa = m_piecewise.paramx_per_toa("A1", toa) +# test_val = [m_piecewise.A1.value, m_piecewise.A1X_0001.value] +# +# elif param == "T0X": +# paramX_per_toa = m_piecewise.paramx_per_toa("T0", toa) +# test_val = [m_piecewise.T0.value, m_piecewise.T0X_0001.value] +# +# are_toas_referencing_global_paramX = np.isclose( +# (paramX_per_toa.value - test_val[0]), 0, atol=1e-6, rtol=0 +# ) +# +# are_toas_referencing_piecewise_paramX = np.isclose( +# (paramX_per_toa.value - test_val[1]), 0, atol=1e-6, rtol=0 +# ) +# +# do_toas_reference_piecewise_parameter = [ +# are_toas_referencing_piecewise_paramX, +# are_toas_referencing_global_paramX, +# ] +# +# np.testing.assert_array_equal( +# do_toas_reference_piecewise_parameter, should_toa_reference_piecewise_parameter +# )