diff --git a/src/pint/models/binary_bt.py b/src/pint/models/binary_bt.py index 700e87470..cb4f50bbd 100644 --- a/src/pint/models/binary_bt.py +++ b/src/pint/models/binary_bt.py @@ -19,9 +19,6 @@ from pint.toa_select import TOASelect - - - class BinaryBT(PulsarBinary): """Blandford and Teukolsky binary model. @@ -88,17 +85,18 @@ def validate(self): See Blandford & Teukolsky 1976, ApJ, 205, 580. """ + class BinaryBTPiecewise(PulsarBinary): """Model implementing the BT model with piecewise orbital parameters A1X and T0X. This model lets the user specify time ranges and fit for a different piecewise orbital parameter in each time range, This is a PINT pulsar binary BTPiecewise model class, a subclass of PulsarBinary. - It is a wrapper for stand alone BTPiecewise class defined in + It is a wrapper for stand alone BTPiecewise class defined in ./stand_alone_psr_binary/BT_piecewise.py The aim for this class is to connect the stand alone binary model with the PINT platform. BTpiecewise special parameters, where xxxx denotes the 4-digit index of the piece: T0X_xxxx Piecewise T0 values for piece A1X_xxxx Piecewise A1 values for piece XR1_xxxx Lower time boundary of piece - XR2_xxxx Upper time boundary of piece + XR2_xxxx Upper time boundary of piece """ register = True @@ -119,10 +117,17 @@ def __init__(self): self.T0_value_funcs = [] self.remove_param("M2") self.remove_param("SINI") - - - def add_group_range(self, group_start_mjd, group_end_mjd, piece_index = None, ): - """Add an orbital piecewise parameter group range. If piece_index is not provided a new piece will be added with index equal to the number of pieces plus one. Pieces cannot have the duplicate pieces and cannot have the same index. A pair of consisting of a piecewise A1 and T0 may share an index and will act over the same piece range. + self.add_group_range(None, None) + self.add_piecewise_param(0, T0=0 * u.d) + self.add_piecewise_param(0, A1=0 * ls) + + def add_group_range( + self, + group_start_mjd, + group_end_mjd, + piece_index=None, + ): + """Add an orbital piecewise parameter group range. If piece_index is not provided a new piece will be added with index equal to the number of pieces plus one. Pieces cannot have the duplicate pieces and cannot have the same index. A pair of consisting of a piecewise A1 and T0 may share an index and will act over the same piece range. Parameters ---------- group_start_mjd : float or astropy.quantity.Quantity or astropy.time.Time @@ -133,25 +138,33 @@ def add_group_range(self, group_start_mjd, group_end_mjd, piece_index = None, ): Number to label the piece being added. """ if group_start_mjd is not None and group_end_mjd is not None: - if isinstance(group_start_mjd,Time): - group_start_mjd= group_start_mjd.mjd - elif isinstance(group_start_mjd,u.quantity.Quantity): - group_start_mjd= group_start_mjd.value - if isinstance(group_end_mjd,Time): + if isinstance(group_start_mjd, Time): + group_start_mjd = group_start_mjd.mjd + elif isinstance(group_start_mjd, u.quantity.Quantity): + group_start_mjd = group_start_mjd.value + if isinstance(group_end_mjd, Time): group_end_mjd = group_end_mjd.mjd - elif isinstance(group_end_mjd,u.quantity.Quantity): + elif isinstance(group_end_mjd, u.quantity.Quantity): group_end_mjd = group_end_mjd.value + + elif group_start_mjd is None or group_end_mjd is None: + if group_start_mjd is None and group_end_mjd is not None: + group_start_mjd = group_end_mjd - 100 + elif group_start_mjd is not None and group_end_mjd is None: + group_end_mjd = group_start_mjd + 100 + else: + group_start_mjd = 50000 + group_end_mjd = 60000 + if piece_index is None: dct = self.get_prefix_mapping_component("XR1_") - if len(list(dct.keys()))>0: - piece_index = np.max(list(dct.keys())) + 1 + if len(list(dct.keys())) > 0: + piece_index = np.max(list(dct.keys())) + 1 else: piece_index = 0 - # check the validity of the desired group to add - if group_end_mjd is not None and group_start_mjd is not None: if group_end_mjd <= group_start_mjd: raise ValueError("Starting MJD is greater than ending MJD.") @@ -163,8 +176,7 @@ def add_group_range(self, group_start_mjd, group_end_mjd, piece_index = None, ): raise ValueError( f"Invalid index for group. Cannot index beyond 9999 (yet?)" ) - - + i = f"{int(piece_index):04d}" self.add_param( prefixParameter( @@ -212,7 +224,7 @@ def remove_range(self, index): self.validate() self.setup() - def add_piecewise_param(self,piece_index,**kwargs ): + def add_piecewise_param(self, piece_index, **kwargs): """Add an orbital piecewise parameter. Parameters ---------- @@ -223,38 +235,40 @@ def add_piecewise_param(self,piece_index,**kwargs ): paramx : np.float128 or astropy.quantity.Quantity Piecewise parameter value. """ - for key in ('T0','A1'): + for key in ("T0", "A1"): if key in kwargs: param = key paramx = kwargs[key] - if isinstance(paramx,u.quantity.Quantity): + if isinstance(paramx, u.quantity.Quantity): paramx = paramx.value - elif isinstance(paramx,np.float128): + elif isinstance(paramx, np.float128): paramx = paramx else: raise ValueError( - "Unspported data type '%s'. Ensure the piecewise parameter value is a np.float128 or astropy.quantity.Quantity" % type(paramx) - ) - if key == 'T0': + "Unspported data type '%s'. Ensure the piecewise parameter value is a np.float128 or astropy.quantity.Quantity" + % type(paramx) + ) + if key == "T0": param_unit = u.d - elif key =='A1': + elif key == "A1": param_unit = ls key_found = True - + if not key_found: raise AttributeError( - "No piecewise parameters passed. Use T0 = / A1 = to declare a piecewise variable." - ) - + "No piecewise parameters passed. Use T0 = / A1 = to declare a piecewise variable." + ) + if piece_index is None: dct = self.get_prefix_mapping_component(param + "X_") - if len(list(dct.keys()))>0: - piece_index = np.max(list(dct.keys())) + 1 + if len(list(dct.keys())) > 0: + piece_index = np.max(list(dct.keys())) + 1 else: piece_index = 0 - elif int(piece_index) in self.get_prefix_mapping_component(param +"X_"): + elif int(piece_index) in self.get_prefix_mapping_component(param + "X_"): raise ValueError( - "Index '%s' is already in use in this model. Please choose another." % piece_index + "Index '%s' is already in use in this model. Please choose another." + % piece_index ) i = f"{int(piece_index):04d}" @@ -269,18 +283,17 @@ def add_piecewise_param(self,piece_index,**kwargs ): # check if name exits and is currently available self.add_param( - prefixParameter( - name=param + f"X_{i}", - units=param_unit, - value=paramx, - description="Parameter" + param + "variation", - parameter_type="float", - frozen=False, - ) - ) + prefixParameter( + name=param + f"X_{i}", + units=param_unit, + value=paramx, + description="Parameter" + param + "variation", + parameter_type="float", + frozen=False, + ) + ) self.setup() - def lock_groups(self): self.validate() self.update_binary_object(None) @@ -312,51 +325,50 @@ def setup(self): for index in XR1_mapping.values(): index = index.split("_")[1] - piece_index = f"{int(index):04d}" - if hasattr(self,f"T0X_{piece_index}"): - if getattr(self,f"T0X_{piece_index}") is not None: - self.binary_instance.add_binary_params(f"T0X_{piece_index}",getattr(self,f"T0X_{piece_index}")) - else: - self.binary_instance.add_binary_params(f"T0X_{piece_index}",self.T0.value) - - - if hasattr(self,f"A1X_{piece_index}"): - if hasattr(self,f"A1X_{piece_index}"): - if getattr(self,f"A1X_{piece_index}") is not None: - self.binary_instance.add_binary_params(f"A1X_{piece_index}",getattr(self,f"A1X_{piece_index}")) - else: - self.binary_instance.add_binary_params(f"A1X_{piece_index}",self.A1.value) - - - if hasattr(self,f"XR1_{piece_index}"): - if getattr(self,f"XR1_{piece_index}") is not None: - self.binary_instance.add_binary_params(f"XR1_{piece_index}", getattr(self,f"XR1_{piece_index}")) + piece_index = f"{int(index):04d}" + if hasattr(self, f"T0X_{piece_index}"): + if getattr(self, f"T0X_{piece_index}") is not None: + self.binary_instance.add_binary_params( + f"T0X_{piece_index}", getattr(self, f"T0X_{piece_index}") + ) else: - raise ValueError( - f"No date provided to create a group with" - ) + self.binary_instance.add_binary_params( + f"T0X_{piece_index}", self.T0.value + ) + + if hasattr(self, f"A1X_{piece_index}"): + if hasattr(self, f"A1X_{piece_index}"): + if getattr(self, f"A1X_{piece_index}") is not None: + self.binary_instance.add_binary_params( + f"A1X_{piece_index}", getattr(self, f"A1X_{piece_index}") + ) + else: + self.binary_instance.add_binary_params( + f"A1X_{piece_index}", self.A1.value + ) + + if hasattr(self, f"XR1_{piece_index}"): + if getattr(self, f"XR1_{piece_index}") is not None: + self.binary_instance.add_binary_params( + f"XR1_{piece_index}", getattr(self, f"XR1_{piece_index}") + ) + else: + raise ValueError(f"No date provided to create a group with") else: - raise ValueError( - f"No name provided to create a group with" - ) - - - - if hasattr(self,f"XR2_{piece_index}"): - if getattr(self,f"XR2_{piece_index}") is not None: - self.binary_instance.add_binary_params(f"XR2_{piece_index}", getattr(self,f"XR2_{piece_index}")) + raise ValueError(f"No name provided to create a group with") + + if hasattr(self, f"XR2_{piece_index}"): + if getattr(self, f"XR2_{piece_index}") is not None: + self.binary_instance.add_binary_params( + f"XR2_{piece_index}", getattr(self, f"XR2_{piece_index}") + ) else: - raise ValueError( - f"No date provided to create a group with" - ) + raise ValueError(f"No date provided to create a group with") else: - raise ValueError( - f"No name provided to create a group with" - ) + raise ValueError(f"No name provided to create a group with") self.update_binary_object(None) - def validate(self): """Include catches for overlapping groups. etc Raises @@ -431,20 +443,18 @@ def validate(self): f"Group boundary overlap detected. Make sure groups are not overlapping" ) - - - def paramx_per_toa(self,param_name,toas): + def paramx_per_toa(self, param_name, toas): condition = {} tbl = toas.table XR1_mapping = self.get_prefix_mapping_component("XR1_") XR2_mapping = self.get_prefix_mapping_component("XR2_") - if not hasattr(self,"toas_selector"): - self.toas_selector=TOASelect(is_range=True) - if param_name[0:2] == 'T0': + if not hasattr(self, "toas_selector"): + self.toas_selector = TOASelect(is_range=True) + if param_name[0:2] == "T0": paramX_mapping = self.get_prefix_mapping_component("T0X_") param_unit = u.d - elif param_name[0:2] == 'A1': - paramX_mapping = self.get_prefix_mapping_component("A1X_") + elif param_name[0:2] == "A1": + paramX_mapping = self.get_prefix_mapping_component("A1X_") param_unit = u.ls else: raise AttributeError( @@ -453,21 +463,20 @@ def paramx_per_toa(self,param_name,toas): ) for piece_index in paramX_mapping.keys(): - r1 = getattr(self,XR1_mapping[piece_index]).quantity - r2 = getattr(self,XR2_mapping[piece_index]).quantity - condition[paramX_mapping[piece_index]] = (r1.mjd,r2.mjd) - select_idx = self.toas_selector.get_select_index(condition,tbl["mjd_float"]) - paramx = np.zeros(len(tbl))*param_unit + r1 = getattr(self, XR1_mapping[piece_index]).quantity + r2 = getattr(self, XR2_mapping[piece_index]).quantity + condition[paramX_mapping[piece_index]] = (r1.mjd, r2.mjd) + select_idx = self.toas_selector.get_select_index(condition, tbl["mjd_float"]) + paramx = np.zeros(len(tbl)) * param_unit for k, v in select_idx.items(): - paramx[v]+=getattr(self,k).quantity + paramx[v] += getattr(self, k).quantity return paramx - - + def get_number_of_groups(self): """Get the number of piecewise parameters""" return len(self.binary_instance.piecewise_parameter_information) - #def get_group_boundaries(self): + # def get_group_boundaries(self): # """Get a all pieces' date boundaries. # Returns # ------- @@ -480,7 +489,7 @@ def get_number_of_groups(self): # # asks the object for the number of piecewise groups # return self.binary_instance.get_group_boundaries() - #def which_group_is_toa_in(self, toa): + # def which_group_is_toa_in(self, toa): # """Find the group a toa belongs to based on the boundaries of groups passed to BT_piecewise # Parameters # ---------- @@ -493,7 +502,7 @@ def get_number_of_groups(self): # """ # return self.binary_instance.toa_belongs_in_group(toa) - #def get_group_indexes(self): + # def get_group_indexes(self): # """Get all the piecewise parameter labels # Returns # ------- @@ -507,7 +516,7 @@ def get_number_of_groups(self): # ) # return group_indexes - #def get_T0Xs_associated_with_toas(self, toas): + # def get_T0Xs_associated_with_toas(self, toas): # """Get a of all the piecewise T0s associated with TOAs # Parameters # ---------- @@ -529,7 +538,7 @@ def get_number_of_groups(self): # self.binary_instance.group_index_array = temporary_storage # return T0X_per_toa - #def get_A1Xs_associated_with_toas(self, toas): + # def get_A1Xs_associated_with_toas(self, toas): # """Get a of all the piecewise A1s associated with TOAs # Parameters # ---------- @@ -551,7 +560,7 @@ def get_number_of_groups(self): # self.binary_instance.group_index_array = temporary_storage # return A1X_per_toa - #def does_toa_reference_piecewise_parameter(self, toas, param): + # def does_toa_reference_piecewise_parameter(self, toas, param): # """Query whether a TOA/list of TOAs belong(s) to a specific group # Parameters # ---------- @@ -567,5 +576,3 @@ def get_number_of_groups(self): # self.binary_instance.group_index_array = self.which_group_is_toa_in(toas) # from_in_piece = self.binary_instance.in_piece(param) # return from_in_piece[0] - - diff --git a/src/pint/models/stand_alone_psr_binaries/BT_piecewise.py b/src/pint/models/stand_alone_psr_binaries/BT_piecewise.py index 54280208a..d8ce307f8 100644 --- a/src/pint/models/stand_alone_psr_binaries/BT_piecewise.py +++ b/src/pint/models/stand_alone_psr_binaries/BT_piecewise.py @@ -48,8 +48,8 @@ class BTpiecewise(BTmodel): for simplicity we're just going to set _t directly though this is not recommended. >>setattr(binary_model,'_t' ,t) - #here we call get_tt0 to get the "loaded toas" to interact with the pieces passed to the model earlier - #sets the attribute "T0X_per_toa" and/or "A1X_per_toa", contains the piecewise parameter value that will be referenced + #here we call get_tt0 to get the "loaded toas" to interact with the pieces passed to the model earlier + #sets the attribute "T0X_per_toa" and/or "A1X_per_toa", contains the piecewise parameter value that will be referenced #for each toa future calculations >>binary_model.get_tt0(t) #For a piecewise T0, tt0 becomes a piecewise quantity, otherwise it is how it functions in BT_model.py. @@ -57,17 +57,18 @@ class BTpiecewise(BTmodel): #get_tt0 sets the attribute "T0X_per_toa" and/or "A1X_per_toa". #contains the piecewise parameter value that will be referenced for each toa future calculations >>binary_model.T0X_per_toa - + Information about any group can be found with the following: >>binary_model.piecewise_parameter_information Order: [[Group index, Piecewise T0, Piecewise A1, Piece lower bound, Piece upper bound]] - Making sure a binary_model.tt0 exists + Making sure a binary_model.tt0 exists >>binary_model._tt0 = binary_model.get_tt0(binary_model._t) Obtain piecewise BTdelay() >>binary_model.BTdelay() """ + def __init__(self, axis_store_initial=None, t=None, input_params=None): self.binary_name = "BT_piecewise" super(BTpiecewise, self).__init__() @@ -97,7 +98,6 @@ def setup_internal_structures(self, valDict=None): # initialise array that will be 5 x n in shape. Where n is the number of pieces required by the model piecewise_parameter_information = [] # If there are no updates passed by binary_instance, sets default value (usually overwritten when reading from parfile) - print(f"valDict {valDict}") if valDict is None: self.T0X_arr = [self.T0] self.A1X_arr = [self.A1] @@ -136,14 +136,13 @@ def setup_internal_structures(self, valDict=None): "XR1_" + index, "XR2_" + index, ] - if string[0] not in param_pieces: for i in range(0, len(string)): if string[i] in valDict: param_pieces.append(valDict[string[i]]) elif string[i] not in valDict: attr = string[i][0:2] - + if hasattr(self, attr): param_pieces.append(getattr(self, attr)) else: @@ -184,7 +183,7 @@ def piecewise_parameter_from_information_array(self, t): self.group_index_array = self.toa_belongs_in_group(t) # searches the 5 x n array to find the index matching the toa_index possible_groups = [item[0] for item in self.piecewise_parameter_information] - + for i in self.group_index_array: if i != -1: for k, j in enumerate(possible_groups): @@ -219,10 +218,10 @@ def toa_belongs_in_group(self, t): lower_edge = [] upper_edge = [] for i in range(len(gb[0])): - lower_edge.append( gb[0][i].value) - upper_edge.append( gb[1][i].value) - - #lower_edge, upper_edge = [self.get_group_boundaries()[:].value],[self.get_group_boundaries()[1].value] + lower_edge.append(gb[0][i].value) + upper_edge.append(gb[1][i].value) + + # lower_edge, upper_edge = [self.get_group_boundaries()[:].value],[self.get_group_boundaries()[1].value] for i in t.value: lower_bound = np.searchsorted(np.array(lower_edge), i) - 1 upper_bound = np.searchsorted(np.array(upper_edge), i) @@ -480,4 +479,3 @@ def d_M_d_par(self, par): with u.set_enabled_equivalencies(u.dimensionless_angles()): result = result.to(u.Unit("") / par_obj.unit) return result -