Skip to content

Commit

Permalink
Blacken code plus minor fixes to last change
Browse files Browse the repository at this point in the history
  • Loading branch information
poneill129 committed Sep 9, 2023
1 parent f1132d4 commit bce039c
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 122 deletions.
225 changes: 116 additions & 109 deletions src/pint/models/binary_bt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
from pint.toa_select import TOASelect





class BinaryBT(PulsarBinary):
"""Blandford and Teukolsky binary model.
Expand Down Expand Up @@ -88,17 +85,18 @@ def validate(self):
See Blandford & Teukolsky 1976, ApJ, 205, 580.
"""


class BinaryBTPiecewise(PulsarBinary):
"""Model implementing the BT model with piecewise orbital parameters A1X and T0X. This model lets the user specify time ranges and fit for a different piecewise orbital parameter in each time range,
This is a PINT pulsar binary BTPiecewise model class, a subclass of PulsarBinary.
It is a wrapper for stand alone BTPiecewise class defined in
It is a wrapper for stand alone BTPiecewise class defined in
./stand_alone_psr_binary/BT_piecewise.py
The aim for this class is to connect the stand alone binary model with the PINT platform.
BTpiecewise special parameters, where xxxx denotes the 4-digit index of the piece:
T0X_xxxx Piecewise T0 values for piece
A1X_xxxx Piecewise A1 values for piece
XR1_xxxx Lower time boundary of piece
XR2_xxxx Upper time boundary of piece
XR2_xxxx Upper time boundary of piece
"""

register = True
Expand All @@ -119,10 +117,17 @@ def __init__(self):
self.T0_value_funcs = []
self.remove_param("M2")
self.remove_param("SINI")


def add_group_range(self, group_start_mjd, group_end_mjd, piece_index = None, ):
"""Add an orbital piecewise parameter group range. If piece_index is not provided a new piece will be added with index equal to the number of pieces plus one. Pieces cannot have the duplicate pieces and cannot have the same index. A pair of consisting of a piecewise A1 and T0 may share an index and will act over the same piece range.
self.add_group_range(None, None)
self.add_piecewise_param(0, T0=0 * u.d)
self.add_piecewise_param(0, A1=0 * ls)

def add_group_range(
self,
group_start_mjd,
group_end_mjd,
piece_index=None,
):
"""Add an orbital piecewise parameter group range. If piece_index is not provided a new piece will be added with index equal to the number of pieces plus one. Pieces cannot have the duplicate pieces and cannot have the same index. A pair of consisting of a piecewise A1 and T0 may share an index and will act over the same piece range.
Parameters
----------
group_start_mjd : float or astropy.quantity.Quantity or astropy.time.Time
Expand All @@ -133,25 +138,33 @@ def add_group_range(self, group_start_mjd, group_end_mjd, piece_index = None, ):
Number to label the piece being added.
"""
if group_start_mjd is not None and group_end_mjd is not None:
if isinstance(group_start_mjd,Time):
group_start_mjd= group_start_mjd.mjd
elif isinstance(group_start_mjd,u.quantity.Quantity):
group_start_mjd= group_start_mjd.value
if isinstance(group_end_mjd,Time):
if isinstance(group_start_mjd, Time):
group_start_mjd = group_start_mjd.mjd
elif isinstance(group_start_mjd, u.quantity.Quantity):
group_start_mjd = group_start_mjd.value
if isinstance(group_end_mjd, Time):
group_end_mjd = group_end_mjd.mjd
elif isinstance(group_end_mjd,u.quantity.Quantity):
elif isinstance(group_end_mjd, u.quantity.Quantity):
group_end_mjd = group_end_mjd.value

elif group_start_mjd is None or group_end_mjd is None:
if group_start_mjd is None and group_end_mjd is not None:
group_start_mjd = group_end_mjd - 100
elif group_start_mjd is not None and group_end_mjd is None:
group_end_mjd = group_start_mjd + 100
else:
group_start_mjd = 50000
group_end_mjd = 60000

if piece_index is None:
dct = self.get_prefix_mapping_component("XR1_")
if len(list(dct.keys()))>0:
piece_index = np.max(list(dct.keys())) + 1
if len(list(dct.keys())) > 0:
piece_index = np.max(list(dct.keys())) + 1
else:
piece_index = 0


# check the validity of the desired group to add


if group_end_mjd is not None and group_start_mjd is not None:
if group_end_mjd <= group_start_mjd:
raise ValueError("Starting MJD is greater than ending MJD.")
Expand All @@ -163,8 +176,7 @@ def add_group_range(self, group_start_mjd, group_end_mjd, piece_index = None, ):
raise ValueError(
f"Invalid index for group. Cannot index beyond 9999 (yet?)"
)



i = f"{int(piece_index):04d}"
self.add_param(
prefixParameter(
Expand Down Expand Up @@ -212,7 +224,7 @@ def remove_range(self, index):
self.validate()
self.setup()

def add_piecewise_param(self,piece_index,**kwargs ):
def add_piecewise_param(self, piece_index, **kwargs):
"""Add an orbital piecewise parameter.
Parameters
----------
Expand All @@ -223,38 +235,40 @@ def add_piecewise_param(self,piece_index,**kwargs ):
paramx : np.float128 or astropy.quantity.Quantity
Piecewise parameter value.
"""
for key in ('T0','A1'):
for key in ("T0", "A1"):
if key in kwargs:
param = key
paramx = kwargs[key]
if isinstance(paramx,u.quantity.Quantity):
if isinstance(paramx, u.quantity.Quantity):
paramx = paramx.value
elif isinstance(paramx,np.float128):
elif isinstance(paramx, np.float128):
paramx = paramx
else:
raise ValueError(
"Unspported data type '%s'. Ensure the piecewise parameter value is a np.float128 or astropy.quantity.Quantity" % type(paramx)
)
if key == 'T0':
"Unspported data type '%s'. Ensure the piecewise parameter value is a np.float128 or astropy.quantity.Quantity"
% type(paramx)
)
if key == "T0":
param_unit = u.d
elif key =='A1':
elif key == "A1":
param_unit = ls
key_found = True

if not key_found:
raise AttributeError(
"No piecewise parameters passed. Use T0 = / A1 = to declare a piecewise variable."
)
"No piecewise parameters passed. Use T0 = / A1 = to declare a piecewise variable."
)

if piece_index is None:
dct = self.get_prefix_mapping_component(param + "X_")
if len(list(dct.keys()))>0:
piece_index = np.max(list(dct.keys())) + 1
if len(list(dct.keys())) > 0:
piece_index = np.max(list(dct.keys())) + 1
else:
piece_index = 0
elif int(piece_index) in self.get_prefix_mapping_component(param +"X_"):
elif int(piece_index) in self.get_prefix_mapping_component(param + "X_"):
raise ValueError(
"Index '%s' is already in use in this model. Please choose another." % piece_index
"Index '%s' is already in use in this model. Please choose another."
% piece_index
)
i = f"{int(piece_index):04d}"

Expand All @@ -269,18 +283,17 @@ def add_piecewise_param(self,piece_index,**kwargs ):
# check if name exits and is currently available

self.add_param(
prefixParameter(
name=param + f"X_{i}",
units=param_unit,
value=paramx,
description="Parameter" + param + "variation",
parameter_type="float",
frozen=False,
)
)
prefixParameter(
name=param + f"X_{i}",
units=param_unit,
value=paramx,
description="Parameter" + param + "variation",
parameter_type="float",
frozen=False,
)
)
self.setup()


def lock_groups(self):
self.validate()
self.update_binary_object(None)
Expand Down Expand Up @@ -312,51 +325,50 @@ def setup(self):

for index in XR1_mapping.values():
index = index.split("_")[1]
piece_index = f"{int(index):04d}"
if hasattr(self,f"T0X_{piece_index}"):
if getattr(self,f"T0X_{piece_index}") is not None:
self.binary_instance.add_binary_params(f"T0X_{piece_index}",getattr(self,f"T0X_{piece_index}"))
else:
self.binary_instance.add_binary_params(f"T0X_{piece_index}",self.T0.value)


if hasattr(self,f"A1X_{piece_index}"):
if hasattr(self,f"A1X_{piece_index}"):
if getattr(self,f"A1X_{piece_index}") is not None:
self.binary_instance.add_binary_params(f"A1X_{piece_index}",getattr(self,f"A1X_{piece_index}"))
else:
self.binary_instance.add_binary_params(f"A1X_{piece_index}",self.A1.value)


if hasattr(self,f"XR1_{piece_index}"):
if getattr(self,f"XR1_{piece_index}") is not None:
self.binary_instance.add_binary_params(f"XR1_{piece_index}", getattr(self,f"XR1_{piece_index}"))
piece_index = f"{int(index):04d}"
if hasattr(self, f"T0X_{piece_index}"):
if getattr(self, f"T0X_{piece_index}") is not None:
self.binary_instance.add_binary_params(
f"T0X_{piece_index}", getattr(self, f"T0X_{piece_index}")
)
else:
raise ValueError(
f"No date provided to create a group with"
)
self.binary_instance.add_binary_params(
f"T0X_{piece_index}", self.T0.value
)

if hasattr(self, f"A1X_{piece_index}"):
if hasattr(self, f"A1X_{piece_index}"):
if getattr(self, f"A1X_{piece_index}") is not None:
self.binary_instance.add_binary_params(
f"A1X_{piece_index}", getattr(self, f"A1X_{piece_index}")
)
else:
self.binary_instance.add_binary_params(
f"A1X_{piece_index}", self.A1.value
)

if hasattr(self, f"XR1_{piece_index}"):
if getattr(self, f"XR1_{piece_index}") is not None:
self.binary_instance.add_binary_params(
f"XR1_{piece_index}", getattr(self, f"XR1_{piece_index}")
)
else:
raise ValueError(f"No date provided to create a group with")
else:
raise ValueError(
f"No name provided to create a group with"
)



if hasattr(self,f"XR2_{piece_index}"):
if getattr(self,f"XR2_{piece_index}") is not None:
self.binary_instance.add_binary_params(f"XR2_{piece_index}", getattr(self,f"XR2_{piece_index}"))
raise ValueError(f"No name provided to create a group with")

if hasattr(self, f"XR2_{piece_index}"):
if getattr(self, f"XR2_{piece_index}") is not None:
self.binary_instance.add_binary_params(
f"XR2_{piece_index}", getattr(self, f"XR2_{piece_index}")
)
else:
raise ValueError(
f"No date provided to create a group with"
)
raise ValueError(f"No date provided to create a group with")
else:
raise ValueError(
f"No name provided to create a group with"
)
raise ValueError(f"No name provided to create a group with")

self.update_binary_object(None)


def validate(self):
"""Include catches for overlapping groups. etc
Raises
Expand Down Expand Up @@ -431,20 +443,18 @@ def validate(self):
f"Group boundary overlap detected. Make sure groups are not overlapping"
)



def paramx_per_toa(self,param_name,toas):
def paramx_per_toa(self, param_name, toas):
condition = {}
tbl = toas.table
XR1_mapping = self.get_prefix_mapping_component("XR1_")
XR2_mapping = self.get_prefix_mapping_component("XR2_")
if not hasattr(self,"toas_selector"):
self.toas_selector=TOASelect(is_range=True)
if param_name[0:2] == 'T0':
if not hasattr(self, "toas_selector"):
self.toas_selector = TOASelect(is_range=True)
if param_name[0:2] == "T0":
paramX_mapping = self.get_prefix_mapping_component("T0X_")
param_unit = u.d
elif param_name[0:2] == 'A1':
paramX_mapping = self.get_prefix_mapping_component("A1X_")
elif param_name[0:2] == "A1":
paramX_mapping = self.get_prefix_mapping_component("A1X_")
param_unit = u.ls
else:
raise AttributeError(
Expand All @@ -453,21 +463,20 @@ def paramx_per_toa(self,param_name,toas):
)

for piece_index in paramX_mapping.keys():
r1 = getattr(self,XR1_mapping[piece_index]).quantity
r2 = getattr(self,XR2_mapping[piece_index]).quantity
condition[paramX_mapping[piece_index]] = (r1.mjd,r2.mjd)
select_idx = self.toas_selector.get_select_index(condition,tbl["mjd_float"])
paramx = np.zeros(len(tbl))*param_unit
r1 = getattr(self, XR1_mapping[piece_index]).quantity
r2 = getattr(self, XR2_mapping[piece_index]).quantity
condition[paramX_mapping[piece_index]] = (r1.mjd, r2.mjd)
select_idx = self.toas_selector.get_select_index(condition, tbl["mjd_float"])
paramx = np.zeros(len(tbl)) * param_unit
for k, v in select_idx.items():
paramx[v]+=getattr(self,k).quantity
paramx[v] += getattr(self, k).quantity
return paramx



def get_number_of_groups(self):
"""Get the number of piecewise parameters"""
return len(self.binary_instance.piecewise_parameter_information)

#def get_group_boundaries(self):
# def get_group_boundaries(self):
# """Get a all pieces' date boundaries.
# Returns
# -------
Expand All @@ -480,7 +489,7 @@ def get_number_of_groups(self):
# # asks the object for the number of piecewise groups
# return self.binary_instance.get_group_boundaries()

#def which_group_is_toa_in(self, toa):
# def which_group_is_toa_in(self, toa):
# """Find the group a toa belongs to based on the boundaries of groups passed to BT_piecewise
# Parameters
# ----------
Expand All @@ -493,7 +502,7 @@ def get_number_of_groups(self):
# """
# return self.binary_instance.toa_belongs_in_group(toa)

#def get_group_indexes(self):
# def get_group_indexes(self):
# """Get all the piecewise parameter labels
# Returns
# -------
Expand All @@ -507,7 +516,7 @@ def get_number_of_groups(self):
# )
# return group_indexes

#def get_T0Xs_associated_with_toas(self, toas):
# def get_T0Xs_associated_with_toas(self, toas):
# """Get a of all the piecewise T0s associated with TOAs
# Parameters
# ----------
Expand All @@ -529,7 +538,7 @@ def get_number_of_groups(self):
# self.binary_instance.group_index_array = temporary_storage
# return T0X_per_toa

#def get_A1Xs_associated_with_toas(self, toas):
# def get_A1Xs_associated_with_toas(self, toas):
# """Get a of all the piecewise A1s associated with TOAs
# Parameters
# ----------
Expand All @@ -551,7 +560,7 @@ def get_number_of_groups(self):
# self.binary_instance.group_index_array = temporary_storage
# return A1X_per_toa

#def does_toa_reference_piecewise_parameter(self, toas, param):
# def does_toa_reference_piecewise_parameter(self, toas, param):
# """Query whether a TOA/list of TOAs belong(s) to a specific group
# Parameters
# ----------
Expand All @@ -567,5 +576,3 @@ def get_number_of_groups(self):
# self.binary_instance.group_index_array = self.which_group_is_toa_in(toas)
# from_in_piece = self.binary_instance.in_piece(param)
# return from_in_piece[0]


Loading

0 comments on commit bce039c

Please sign in to comment.