Skip to content

Commit

Permalink
Refactor Object constructor and include eta_range as obj_key postfix
Browse files Browse the repository at this point in the history
  • Loading branch information
danielhundhausen committed Jan 29, 2024
1 parent eb02720 commit 8b1da35
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 36 deletions.
19 changes: 10 additions & 9 deletions menu_tools/object_performance/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,15 @@ def test_objects(self) -> dict[str, Any]:
if not all([":" in x for x in self._cfg["test_objects"]]):
raise ValueError(f"Misconfigured obj:id key in {self.plot_name}!")

test_obj = {
x: {"base_obj": x.split(":")[0], "id": x.split(":")[1], "x_arg": x_arg}
for x, x_arg in self._cfg["test_objects"].items()
}
return self._cfg["test_objects"]

return test_obj
# DEPRECATED
# test_obj = {
# x: {"base_obj": x.split(":")[0], "id": x.split(":")[1], "x_arg": x_arg}
# for x, x_arg in self._cfg["test_objects"].items()
# }

# return test_obj

@property
def matching(self):
Expand Down Expand Up @@ -95,8 +98,6 @@ def ylabel(self):
@property
def test_object_instances(self) -> list:
test_objects = []
for obj in self._cfg["test_objects"]:
nano_obj_name = obj.split(":")[0]
obj_id_name = obj.split(":")[1]
test_objects.append(Object(nano_obj_name, obj_id_name, self.version))
for obj_key in self._cfg["test_objects"]:
test_objects.append(Object(obj_key, self.version))
return test_objects
9 changes: 3 additions & 6 deletions menu_tools/object_performance/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ def _save_json(self, file_name):
if obj_key == "ref":
continue
obj = Object(
nano_obj_name=obj_key.split("_")[0],
obj_id_name=obj_key.split("_")[1],
obj_key,
version=self.version,
)

Expand Down Expand Up @@ -149,8 +148,7 @@ def _plot_efficiency_curve(self):
efficiency, yerr = self.turnon_collection.get_efficiency(obj_key)

obj = Object(
nano_obj_name=obj_key.split("_")[0],
obj_id_name=obj_key.split("_")[1],
obj_key,
version=self.version,
)

Expand Down Expand Up @@ -196,8 +194,7 @@ def _plot_iso_vs_efficiency_curve(self):
iso_vs_eff_hist = self._get_iso_vs_eff_hist(gen_hist_trig[0])

obj = Object(
nano_obj_name=obj_key.split("_")[0],
obj_id_name=obj_key.split("_")[1],
obj_key,
version=self.version,
)

Expand Down
4 changes: 3 additions & 1 deletion menu_tools/object_performance/tests/test_electron_v29.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ def test_isolation_barrel():

for key, val in reference_data.items():
if isinstance(val, dict):
if "tkEle" in key:
test_key = "tkElectron:NoIso:inclusive"
efficiencies_test = np.array(
test_result[key]["efficiency"], dtype=np.float64
test_result[test_key]["efficiency"], dtype=np.float64
)
efficiencies_reference = np.array(val["efficiency"], dtype=np.float64)
print(efficiencies_reference)
Expand Down
11 changes: 4 additions & 7 deletions menu_tools/object_performance/turnon_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@ def _load_test_branches(self) -> None:
"""
Load test objects.
"""
test_objects = self.cfg_plot.test_objects
for test_obj, obj_cfg in test_objects.items():
obj = Object(obj_cfg["base_obj"], obj_cfg["id"], self.cfg_plot.version)
for obj in self.cfg_plot.test_object_instances:
test_array = self._load_array_from_parquet(obj.nano_obj_name)
test_array = ak.with_name(test_array, "Momentum4D")
self.turnon_collection.ak_arrays[str(obj)] = test_array
Expand Down Expand Up @@ -99,10 +97,9 @@ def test_objects(self) -> list[tuple[Object, str]]:
obj_args = []

test_objects = self.cfg_plot.test_objects
for test_obj, obj_cfg in test_objects.items():
obj = Object(obj_cfg["base_obj"], obj_cfg["id"], self.cfg_plot.version)
x_arg = obj_cfg["x_arg"].lower()
obj_args.append((obj, x_arg))
for obj_key, x_arg in test_objects.items():
obj = Object(obj_key, self.cfg_plot.version)
obj_args.append((obj, x_arg.lower()))

return obj_args

Expand Down
8 changes: 3 additions & 5 deletions menu_tools/rate_plots/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,8 @@ def test_objects(self) -> list:
@property
def test_object_instances(self) -> dict[str, dict[str, Object]]:
test_objects: dict[str, dict[str, Object]] = {}
for obj in self._cfg["test_objects"]:
nano_obj_name = obj.split(":")[0]
obj_id_name = obj.split(":")[1]
test_objects[obj] = {}
for obj_key in self._cfg["test_objects"]:
test_objects[obj_key] = {}
for version in self.versions:
test_objects[obj][version] = Object(nano_obj_name, obj_id_name, version)
test_objects[obj_key][version] = Object(obj_key, version)
return test_objects
44 changes: 36 additions & 8 deletions menu_tools/utils/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,50 @@ class Object:
version: version of the menu
"""

def __init__(self, nano_obj_name: str, obj_id_name: str, version: str) -> None:
def __init__(
self,
object_key: str,
version: str,
) -> None:
"""Initializes an Object loading the parameters from the
corresponding config file.
Args:
nano_obj_name: name of the physics object in the l1 ntuples
obj_id_name: name of the l1 object id as defined in `configs`
object_key: object/id specifier of the form l1_object:id[:eta_range]
version: version of the menu
"""
self.nano_obj_name = nano_obj_name
self.obj_id_name = obj_id_name
self.object_key = object_key
self.version = version
self._nano_obj # fail early if no config can be found

def __str__(self) -> str:
return f"{self.nano_obj_name}_{self.obj_id_name}"
return f"{self.nano_obj_name}:{self.obj_id_name}:{self.eta_range}"

@property
def file_ext(self) -> str:
return str(self).replace(":", "_")

@property
def nano_obj_name(self) -> str:
return self.object_key.split(":")[0]

@property
def obj_id_name(self) -> str:
return self.object_key.split(":")[1]

@property
def eta_range(self) -> str:
"""If an eta range other than "inclusive" is specified, a cut to that
range is added to `cuts`.
Returns:
eta_range_key: `barrel`/`endcap`/`overlap`/`forward`/`inclusive`
"""
try:
eta_range_key = self.object_key.split(":")[2]
except IndexError:
eta_range_key = "inclusive"
return eta_range_key

@property
def _nano_obj(self) -> dict[str, dict]:
Expand Down Expand Up @@ -147,8 +175,8 @@ def compute_selection_mask_for_object_cuts(obj: Object, ak_array: ak.Array) -> a


if __name__ == "__main__":
x = Object("tkElectron", "Iso", "V29")
x = Object("caloJet", "default", "V29")
x = Object("tkElectron:Iso", "V29")
x = Object("caloJet:default", "V29")
print(x)
print(x.match_dR)
print(x.plot_label)
Expand Down

0 comments on commit 8b1da35

Please sign in to comment.