Skip to content

Commit

Permalink
Merge pull request #470 from dask-contrib/dask-tokenize-inputs
Browse files Browse the repository at this point in the history
fix: projected AwkwardInputLayers have __dask_tokenize__
  • Loading branch information
martindurant authored Feb 13, 2024
2 parents 63c003f + 1aadcfc commit 81861f8
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
15 changes: 14 additions & 1 deletion src/dask_awkward/layers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,19 @@ def io_func_implements_report(func: ImplementsIOFunction) -> bool:
return hasattr(func, "return_report")


class AwkwardTokenizable:

def __init__(self, ret_val, parent_name):
self.parent_name = parent_name
self.ret_val = ret_val

def __dask_tokenize__(self):
return ("AwkwardTokenizable", self.parent_name)

def __call__(self, *_, **__):
return self.ret_val


class AwkwardInputLayer(AwkwardBlockwiseLayer):
"""A layer known to perform IO and produce Awkward arrays
Expand Down Expand Up @@ -230,7 +243,7 @@ def prepare_for_projection(self) -> tuple[AwkwardInputLayer, TypeTracerReport, T
new_input_layer = AwkwardInputLayer(
name=self.name,
inputs=[None][: int(list(self.numblocks.values())[0][0])],
io_func=lambda *_, **__: new_return,
io_func=AwkwardTokenizable(new_return, self.name),
label=self.label,
produces_tasks=self.produces_tasks,
creation_info=self.creation_info,
Expand Down
2 changes: 2 additions & 0 deletions tests/test_inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def test_basic_root_works():
.fill(events.MET_pt)
)

columns = list(dak.necessary_columns(q1_hist).values())[0]
assert columns == frozenset({"MET_pt"})
dask.compute(q1_hist)


Expand Down

0 comments on commit 81861f8

Please sign in to comment.