Skip to content

Commit

Permalink
make more generic to handle new tiers
Browse files Browse the repository at this point in the history
  • Loading branch information
ggmarshall committed Oct 9, 2024
1 parent 981877e commit 2337326
Showing 1 changed file with 48 additions and 43 deletions.
91 changes: 48 additions & 43 deletions src/pygama/evt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@
H5DataLoc = namedtuple(
"H5DataLoc", ("file", "group", "table_fmt"), defaults=3 * (None,)
)

DataInfo = namedtuple(
"DataInfo", ("raw", "tcm", "dsp", "hit", "evt"), defaults=5 * (None,)
)
DataInfo = namedtuple("DataInfo", ("raw"), defaults=1 * (None,))

TCMData = namedtuple("TCMData", ("id", "idx", "cumulative_length"))


def make_files_config(data: dict):
if not isinstance(data, DataInfo):
if not isinstance(data, tuple):
DataInfo = namedtuple(
"DataInfo", tuple(data.keys()), defaults=len(data.keys()) * (None,)
)
return DataInfo(
*[
H5DataLoc(*data[tier]) if tier in data else H5DataLoc()
Expand Down Expand Up @@ -72,7 +72,7 @@ def find_parameters(
idx_ch,
field_list,
) -> dict:
"""Finds and returns parameters from `hit` and `dsp` tiers.
"""Finds and returns parameters from non `tcm`, `evt` tiers.
Parameters
----------
Expand All @@ -83,43 +83,38 @@ def find_parameters(
idx_ch
index array of entries to be read from datainfo.
field_list
list of tuples ``(tier, field)`` to be found in the `hit/dsp` tiers.
list of tuples ``(tier, field)`` to be found in non `tcm`, `evt` tiers.
"""
f = make_files_config(datainfo)

# find fields in either dsp, hit
dsp_flds = [e[1] for e in field_list if e[0] == f.dsp.group]
hit_flds = [e[1] for e in field_list if e[0] == f.hit.group]
final_dict = {}

hit_dict, dsp_dict = {}, {}
for name, tier in f._asdict().items():
if name not in ["tcm", "evt"] and tier.file is not None: # skip other tables
keys = [
k.split("/")[-1]
for k in lh5.ls(tier.file, f"{ch.replace('/', '')}/{tier.group}/")
]
flds = [e[1] for e in field_list if e[0] == name and e[1] in keys]

if len(hit_flds) > 0:
hit_ak = lh5.read_as(
f"{ch.replace('/', '')}/{f.hit.group}/",
f.hit.file,
field_mask=hit_flds,
idx=idx_ch,
library="ak",
)
if len(flds) > 0:
tier_ak = lh5.read_as(
f"{ch.replace('/', '')}/{tier.group}/",
tier.file,
field_mask=flds,
idx=idx_ch,
library="ak",
)

hit_dict = dict(
zip([f"{f.hit.group}_" + e for e in ak.fields(hit_ak)], ak.unzip(hit_ak))
)
tier_dict = dict(
zip(
[f"{name}_" + e for e in ak.fields(tier_ak)],
ak.unzip(tier_ak),
)
)
final_dict = final_dict | tier_dict

if len(dsp_flds) > 0:
dsp_ak = lh5.read_as(
f"{ch.replace('/', '')}/{f.dsp.group}/",
f.dsp.file,
field_mask=dsp_flds,
idx=idx_ch,
library="ak",
)

dsp_dict = dict(
zip([f"{f.dsp.group}_" + e for e in ak.fields(dsp_ak)], ak.unzip(dsp_ak))
)

return hit_dict | dsp_dict
return final_dict


def get_data_at_channel(
Expand Down Expand Up @@ -178,10 +173,14 @@ def get_data_at_channel(

# evaluate expression
# move tier+dots in expression to underscores (e.g. evt.foo -> evt_foo)

new_expr = expr
for name in f._asdict():
if name not in ["tcm", "raw"]:
new_expr = new_expr.replace(f"{name}.", f"{name}_")

res = eval(
expr.replace(f"{f.dsp.group}.", f"{f.dsp.group}_")
.replace(f"{f.hit.group}.", f"{f.hit.group}_")
.replace(f"{f.evt.group}.", ""),
new_expr,
var,
)

Expand Down Expand Up @@ -231,17 +230,23 @@ def get_mask_from_query(

# get sub evt based query condition if needed
if isinstance(query, str):
query_lst = re.findall(r"(hit|dsp).([a-zA-Z_$][\w$]*)", query)
query_lst = re.findall(
rf"({'|'.join(f._asdict().keys())}).([a-zA-Z_$][\w$]*)", query
)
query_var = find_parameters(
datainfo=datainfo,
ch=ch,
idx_ch=idx_ch,
field_list=query_lst,
)

new_query = query
for name in f._asdict():
if name not in ["tcm", "evt"]:
new_query = new_query.replace(f"{name}.", f"{name}_")

limarr = eval(
query.replace(f"{f.dsp.group}.", f"{f.dsp.group}_").replace(
f"{f.hit.group}.", f"{f.hit.group}_"
),
new_query,
query_var,
)

Expand Down

0 comments on commit 2337326

Please sign in to comment.