Skip to content

Commit

Permalink
Get rid of the - on flags in API
Browse files Browse the repository at this point in the history
  • Loading branch information
aarchiba committed Aug 23, 2021
1 parent 4f8d331 commit 74da7bb
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 50 deletions.
4 changes: 2 additions & 2 deletions src/pint/models/jump.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def add_jump_and_flags(self, toa_flags, flag="gui_jump", flag_value=None):
"""
in_use = set()
for pm in self.jumps:
if pm.flag == "-" + flag:
if pm.flag == flag:
in_use.add(pm.flag_value)
if flag_value is None:
i = 1
Expand All @@ -231,7 +231,7 @@ def add_jump_and_flags(self, toa_flags, flag="gui_jump", flag_value=None):
param = maskParameter(
name="JUMP",
index=i,
flag="-" + flag,
flag=flag,
flag_value=flag_value,
value=0.0,
units="second",
Expand Down
44 changes: 24 additions & 20 deletions src/pint/models/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
time_to_longdouble,
time_to_mjd_string,
)
from pint.toa import FlagDict
from pint.utils import split_prefixed_name

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -1529,14 +1530,7 @@ def validate(flag, flag_value):
elif flag == "tel":
return get_observatory(flag_value).name
else:
if not isinstance(flag_value, str):
raise ValueError(
f"When selecting by {flag} one must supply a "
f"string not {flag_value}"
)
# FlagDict.validate() once #1074 is merged.
if len(flag_value.split()) != 1:
raise ValueError(f"Flag value {repr(flag_value)} cannot occur in TOAs.")
FlagDict.check_allowed_value(flag, flag_value)
return flag_value

@staticmethod
Expand Down Expand Up @@ -1568,11 +1562,7 @@ def flag(self, flag):
)
else:
flag = flag.lower()
if flag not in self.allowed_non_flags and not flag.startswith("-"):
raise ValueError(
f"Flags must be indicated with -, and the only "
f"non-flags allowed are {maskParameter.allowed_non_flags}."
)
FlagDict.check_allowed_key(flag)
self._flag = flag

@property
Expand Down Expand Up @@ -1633,10 +1623,10 @@ def from_parfile_line(self, line):
-----
The accepted formats for most flags::
NAME flag flag_value parameter_value
NAME flag flag_value parameter_value fit_flag
NAME flag flag_value parameter_value fit_flag uncertainty
NAME flag flag_value parameter_value uncertainty
NAME -flag flag_value parameter_value
NAME -flag flag_value parameter_value fit_flag
NAME -flag flag_value parameter_value fit_flag uncertainty
NAME -flag flag_value parameter_value uncertainty
If the flag is one of MJD or FREQ then::
Expand All @@ -1659,13 +1649,23 @@ def from_parfile_line(self, line):
return False

try:
self.flag = k[1]
flag = k[1].lower()
except IndexError:
raise ValueError(
"{}: No flag found on timfile line {!r}".format(self.name, line)
)
if flag in self.allowed_non_flags:
self.flag = flag
elif not flag.startswith("-"):
raise ValueError(
f"Flags must be indicated with -, and the only "
f"non-flags allowed are {maskParameter.allowed_non_flags}."
)
else:
# Strip leading -
self.flag = flag[1:]

if self.flag in self.wants_two_values:
if flag in self.wants_two_values:
flag_value_str = k[2], k[3]
len_flag_v = 2
else:
Expand Down Expand Up @@ -1703,7 +1703,11 @@ def as_parfile_line(self):
name = self.origin_name
else:
name = self.use_alias
line = "%-15s %s " % (name, self.flag)
if self.flag in self.allowed_non_flags:
flag = self.flag
else:
flag = "-" + self.flag
line = "%-15s %s " % (name, flag)
if isinstance(self.flag_value, str):
line += self.flag_value
else:
Expand Down
8 changes: 4 additions & 4 deletions src/pint/models/timing_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,9 +1289,9 @@ def jump_flags_to_params(self, toas, flag="jump"):

new_jumps = []
# check if any TOAs are jumped
jumped = set(
flag_dict[flag] for flag_dict in toas.table["flags"] if flag in flag_dict
)
jumped = set(toas[flag])
if "" in jumped:
jumped.remove("")
if not jumped:
log.info("No jump flags to process from .tim file")
return new_jumps
Expand All @@ -1318,7 +1318,7 @@ def jump_flags_to_params(self, toas, flag="jump"):
param = maskParameter(
name="JUMP",
index=next_free_index,
flag="-" + flag,
flag=flag,
flag_value=j,
value=0.0,
units="second",
Expand Down
12 changes: 6 additions & 6 deletions tests/test_dmefac_dmequad.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_no_efact_noequad(test_toas, test_model):


def test_only_one_efact(test_toas, test_model):
test_model.DMEFAC1.flag = "-fe"
test_model.DMEFAC1.flag = "fe"
test_model.DMEFAC1.flag_value = "Rcvr_800"
test_model.DMEFAC1.value = 10
test_model.setup()
Expand All @@ -50,7 +50,7 @@ def test_only_one_efact(test_toas, test_model):


def test_only_one_equad(test_toas, test_model):
test_model.DMEQUAD1.flag = "-fe"
test_model.DMEQUAD1.flag = "fe"
test_model.DMEQUAD1.flag_value = "YUPPI"
test_model.DMEQUAD1.value = 10
test_model.setup()
Expand All @@ -65,10 +65,10 @@ def test_only_one_equad(test_toas, test_model):


def test_only_one_equad_one_efact_same_backend(test_toas, test_model):
test_model.DMEQUAD1.flag = "-fe"
test_model.DMEQUAD1.flag = "fe"
test_model.DMEQUAD1.flag_value = "Rcvr_800"
test_model.DMEQUAD1.value = 10
test_model.DMEFAC1.flag = "-fe"
test_model.DMEFAC1.flag = "fe"
test_model.DMEFAC1.flag_value = "Rcvr_800"
test_model.DMEFAC1.value = 10
test_model.setup()
Expand All @@ -88,10 +88,10 @@ def test_only_one_equad_one_efact_same_backend(test_toas, test_model):


def test_only_one_equad_one_efact_different_backend(test_toas, test_model):
test_model.DMEQUAD1.flag = "-fe"
test_model.DMEQUAD1.flag = "fe"
test_model.DMEQUAD1.flag_value = "Rcvr_800"
test_model.DMEQUAD1.value = 10
test_model.DMEFAC1.flag = "-fe"
test_model.DMEFAC1.flag = "fe"
test_model.DMEFAC1.flag_value = "YUPPI"
test_model.DMEFAC1.value = 20
test_model.setup()
Expand Down
6 changes: 1 addition & 5 deletions tests/test_flagging_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,10 @@ def test_jump_by_cluster(setup_NGC6440E):
m_copy.add_component(PhaseJump(), validate=False)
cp_copy = m_copy.components["PhaseJump"]
par_copy = p.maskParameter(
name="JUMP", flag="-toacluster", value=0.2, flag_value="41", units=u.s
name="JUMP", flag="toacluster", value=0.2, flag_value="41", units=u.s
)
# this should be identical to the cluster above
cp_copy.add_param(par_copy, setup=True)
assert set(cp.JUMP1.select_toa_mask(setup_NGC6440E.t)) == set(
cp_copy.JUMP1.select_toa_mask(setup_NGC6440E.t)
)


if __name__ == "__main__":
pass
4 changes: 2 additions & 2 deletions tests/test_jump.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_add_jumps_and_flags(setup_NGC6440E):
selected_toa_ind2 = [10, 11, 12]
j2 = cp.add_jump_and_flags(setup_NGC6440E.t.table["flags"][selected_toa_ind2])
jp2 = getattr(cp, j2)
assert jp2.flag == "-gui_jump"
assert jp2.flag == "gui_jump"
assert jp2.flag_value == "2"
# check previous jump flags unaltered
for d in setup_NGC6440E.t.table["flags"][selected_toa_ind]:
Expand Down Expand Up @@ -274,7 +274,7 @@ def test_multiple_jumps_add():
[
pint.models.parameter.maskParameter(
name="JUMP",
flag="-fish",
flag="fish",
flag_value="carp",
units=u.s,
value=7,
Expand Down
8 changes: 4 additions & 4 deletions tests/test_mask_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,12 @@ def test_name_mask(toas):

def test_flag_mask(toas):
with pytest.raises(ValueError):
maskParameter("test2", flag="-fe", flag_value=430)
mp_flag2 = maskParameter("test2", flag="-fe", flag_value="430")
maskParameter("test2", flag="fe", flag_value=430)
mp_flag2 = maskParameter("test2", flag="fe", flag_value="430")
assert mp_flag2.flag_value == "430"
with pytest.raises(ValueError):
maskParameter("test2", flag="fe", flag_value="430")
mp_flag3 = maskParameter("test2", flag="-fe", flag_value="L-wide")
maskParameter("test2", flag="-fe", flag_value="430")
mp_flag3 = maskParameter("test2", flag="fe", flag_value="L-wide")
assert mp_flag3.flag_value == "L-wide"
select_toas = mp_flag3.select_toa_mask(toas)
assert len(select_toas) > 0
Expand Down
11 changes: 6 additions & 5 deletions tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_read_par_line_expected_values():
assert_allclose(test_m.JUMP7.uncertainty_value, 10.5)
assert_allclose(test_m.JUMP6.value, 0.1)
assert_allclose(test_m.JUMP6.uncertainty_value, 10.0)
assert test_m.JUMP12.flag == "-testflag"
assert test_m.JUMP12.flag == "testflag"
assert not test_m.JUMP12.frozen
assert test_m.JUMP12.flag_value == "flagvalue"
assert_allclose(test_m.JUMP12.value, 0.1)
Expand Down Expand Up @@ -179,7 +179,7 @@ def test_read_par_line(self):
self.assertTrue(np.isclose(test_m.JUMP7.uncertainty_value, 10.5))
self.assertTrue(np.isclose(test_m.JUMP6.value, 0.1))
self.assertTrue(np.isclose(test_m.JUMP6.uncertainty_value, 10.0))
self.assertEqual(test_m.JUMP12.flag, "-testflag")
self.assertEqual(test_m.JUMP12.flag, "testflag")
self.assertEqual(test_m.JUMP12.frozen, False)
self.assertEqual(test_m.JUMP12.flag_value, "flagvalue")
self.assertTrue(np.isclose(test_m.JUMP12.value, 0.1))
Expand Down Expand Up @@ -600,16 +600,17 @@ def test_parameter_can_be_pickled(p):
("freq", (1000.0 * u.MHz, 2000.0 * u.MHz)),
("mjd", (57000.0, 58000.0)),
("mjd", [57000.0, 58000.0]),
("-fish", "carp"),
("fish", "carp"),
("freq", (2000.0 * u.MHz, np.inf * u.MHz)),
("freq", np.array([1000, 2000], dtype=np.longdouble) * u.MHz),
]
invalid_settings = [
("tel", (10, 20)),
("freq", "ao"),
("mjd", ([], [])),
("-fish", (1, 2)),
("-fish", ["c", "a", "r", "p"]),
("-fish", "carp"),
("fish", (1, 2)),
("fish", ["c", "a", "r", "p"]),
]


Expand Down
4 changes: 2 additions & 2 deletions tests/test_timing_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,9 +339,9 @@ def test_jump_flags_to_params(timfile_jumps, timfile_nojumps, model_0437):
m.jump_flags_to_params(t)
assert "PhaseJump" in m.components
assert len(m.components["PhaseJump"].jumps) == 2
assert ("-jump", "1") in [
assert ("jump", "1") in [
(j.flag, j.flag_value) for j in m.components["PhaseJump"].jumps
]
assert ("-jump", "2") in [
assert ("jump", "2") in [
(j.flag, j.flag_value) for j in m.components["PhaseJump"].jumps
]

0 comments on commit 74da7bb

Please sign in to comment.