From 8b1da350eec2e81498c97cfd31db1be618fcb09b Mon Sep 17 00:00:00 2001 From: Daniel Hundhausen Date: Mon, 29 Jan 2024 15:39:40 +0100 Subject: [PATCH] Refactor Object constructor and include eta_range as obj_key postfix --- menu_tools/object_performance/config.py | 19 ++++---- menu_tools/object_performance/plotter.py | 9 ++-- .../tests/test_electron_v29.py | 4 +- .../object_performance/turnon_collection.py | 11 ++--- menu_tools/rate_plots/config.py | 8 ++-- menu_tools/utils/objects.py | 44 +++++++++++++++---- 6 files changed, 59 insertions(+), 36 deletions(-) diff --git a/menu_tools/object_performance/config.py b/menu_tools/object_performance/config.py index d5da2b62..a51dd89e 100644 --- a/menu_tools/object_performance/config.py +++ b/menu_tools/object_performance/config.py @@ -53,12 +53,15 @@ def test_objects(self) -> dict[str, Any]: if not all([":" in x for x in self._cfg["test_objects"]]): raise ValueError(f"Misconfigured obj:id key in {self.plot_name}!") - test_obj = { - x: {"base_obj": x.split(":")[0], "id": x.split(":")[1], "x_arg": x_arg} - for x, x_arg in self._cfg["test_objects"].items() - } + return self._cfg["test_objects"] - return test_obj + # DEPRECATED + # test_obj = { + # x: {"base_obj": x.split(":")[0], "id": x.split(":")[1], "x_arg": x_arg} + # for x, x_arg in self._cfg["test_objects"].items() + # } + + # return test_obj @property def matching(self): @@ -95,8 +98,6 @@ def ylabel(self): @property def test_object_instances(self) -> list: test_objects = [] - for obj in self._cfg["test_objects"]: - nano_obj_name = obj.split(":")[0] - obj_id_name = obj.split(":")[1] - test_objects.append(Object(nano_obj_name, obj_id_name, self.version)) + for obj_key in self._cfg["test_objects"]: + test_objects.append(Object(obj_key, self.version)) return test_objects diff --git a/menu_tools/object_performance/plotter.py b/menu_tools/object_performance/plotter.py index d2a50a0b..9a7608b4 100755 --- a/menu_tools/object_performance/plotter.py +++ b/menu_tools/object_performance/plotter.py @@ -81,8 +81,7 @@ def _save_json(self, file_name): if obj_key == "ref": continue obj = Object( - nano_obj_name=obj_key.split("_")[0], - obj_id_name=obj_key.split("_")[1], + obj_key, version=self.version, ) @@ -149,8 +148,7 @@ def _plot_efficiency_curve(self): efficiency, yerr = self.turnon_collection.get_efficiency(obj_key) obj = Object( - nano_obj_name=obj_key.split("_")[0], - obj_id_name=obj_key.split("_")[1], + obj_key, version=self.version, ) @@ -196,8 +194,7 @@ def _plot_iso_vs_efficiency_curve(self): iso_vs_eff_hist = self._get_iso_vs_eff_hist(gen_hist_trig[0]) obj = Object( - nano_obj_name=obj_key.split("_")[0], - obj_id_name=obj_key.split("_")[1], + obj_key, version=self.version, ) diff --git a/menu_tools/object_performance/tests/test_electron_v29.py b/menu_tools/object_performance/tests/test_electron_v29.py index f667102b..2a0788cc 100644 --- a/menu_tools/object_performance/tests/test_electron_v29.py +++ b/menu_tools/object_performance/tests/test_electron_v29.py @@ -35,8 +35,10 @@ def test_isolation_barrel(): for key, val in reference_data.items(): if isinstance(val, dict): + if "tkEle" in key: + test_key = "tkElectron:NoIso:inclusive" efficiencies_test = np.array( - test_result[key]["efficiency"], dtype=np.float64 + test_result[test_key]["efficiency"], dtype=np.float64 ) efficiencies_reference = np.array(val["efficiency"], dtype=np.float64) print(efficiencies_reference) diff --git a/menu_tools/object_performance/turnon_collection.py b/menu_tools/object_performance/turnon_collection.py index b9cb74c0..50e13556 100644 --- a/menu_tools/object_performance/turnon_collection.py +++ b/menu_tools/object_performance/turnon_collection.py @@ -63,9 +63,7 @@ def _load_test_branches(self) -> None: """ Load test objects. """ - test_objects = self.cfg_plot.test_objects - for test_obj, obj_cfg in test_objects.items(): - obj = Object(obj_cfg["base_obj"], obj_cfg["id"], self.cfg_plot.version) + for obj in self.cfg_plot.test_object_instances: test_array = self._load_array_from_parquet(obj.nano_obj_name) test_array = ak.with_name(test_array, "Momentum4D") self.turnon_collection.ak_arrays[str(obj)] = test_array @@ -99,10 +97,9 @@ def test_objects(self) -> list[tuple[Object, str]]: obj_args = [] test_objects = self.cfg_plot.test_objects - for test_obj, obj_cfg in test_objects.items(): - obj = Object(obj_cfg["base_obj"], obj_cfg["id"], self.cfg_plot.version) - x_arg = obj_cfg["x_arg"].lower() - obj_args.append((obj, x_arg)) + for obj_key, x_arg in test_objects.items(): + obj = Object(obj_key, self.cfg_plot.version) + obj_args.append((obj, x_arg.lower())) return obj_args diff --git a/menu_tools/rate_plots/config.py b/menu_tools/rate_plots/config.py index 26d9a9f9..461336ab 100644 --- a/menu_tools/rate_plots/config.py +++ b/menu_tools/rate_plots/config.py @@ -44,10 +44,8 @@ def test_objects(self) -> list: @property def test_object_instances(self) -> dict[str, dict[str, Object]]: test_objects: dict[str, dict[str, Object]] = {} - for obj in self._cfg["test_objects"]: - nano_obj_name = obj.split(":")[0] - obj_id_name = obj.split(":")[1] - test_objects[obj] = {} + for obj_key in self._cfg["test_objects"]: + test_objects[obj_key] = {} for version in self.versions: - test_objects[obj][version] = Object(nano_obj_name, obj_id_name, version) + test_objects[obj_key][version] = Object(obj_key, version) return test_objects diff --git a/menu_tools/utils/objects.py b/menu_tools/utils/objects.py index 37ad6cf6..3ce868fd 100644 --- a/menu_tools/utils/objects.py +++ b/menu_tools/utils/objects.py @@ -17,22 +17,50 @@ class Object: version: version of the menu """ - def __init__(self, nano_obj_name: str, obj_id_name: str, version: str) -> None: + def __init__( + self, + object_key: str, + version: str, + ) -> None: """Initializes an Object loading the parameters from the corresponding config file. Args: - nano_obj_name: name of the physics object in the l1 ntuples - obj_id_name: name of the l1 object id as defined in `configs` + object_key: object/id specifier of the form l1_object:id[:eta_range] version: version of the menu """ - self.nano_obj_name = nano_obj_name - self.obj_id_name = obj_id_name + self.object_key = object_key self.version = version self._nano_obj # fail early if no config can be found def __str__(self) -> str: - return f"{self.nano_obj_name}_{self.obj_id_name}" + return f"{self.nano_obj_name}:{self.obj_id_name}:{self.eta_range}" + + @property + def file_ext(self) -> str: + return str(self).replace(":", "_") + + @property + def nano_obj_name(self) -> str: + return self.object_key.split(":")[0] + + @property + def obj_id_name(self) -> str: + return self.object_key.split(":")[1] + + @property + def eta_range(self) -> str: + """If an eta range other than "inclusive" is specified, a cut to that + range is added to `cuts`. + + Returns: + eta_range_key: `barrel`/`endcap`/`overlap`/`forward`/`inclusive` + """ + try: + eta_range_key = self.object_key.split(":")[2] + except IndexError: + eta_range_key = "inclusive" + return eta_range_key @property def _nano_obj(self) -> dict[str, dict]: @@ -147,8 +175,8 @@ def compute_selection_mask_for_object_cuts(obj: Object, ak_array: ak.Array) -> a if __name__ == "__main__": - x = Object("tkElectron", "Iso", "V29") - x = Object("caloJet", "default", "V29") + x = Object("tkElectron:Iso", "V29") + x = Object("caloJet:default", "V29") print(x) print(x.match_dR) print(x.plot_label)