diff --git a/src/dask_awkward/layers/layers.py b/src/dask_awkward/layers/layers.py index 89c49c41..92441443 100644 --- a/src/dask_awkward/layers/layers.py +++ b/src/dask_awkward/layers/layers.py @@ -117,6 +117,19 @@ def io_func_implements_report(func: ImplementsIOFunction) -> bool: return hasattr(func, "return_report") +class AwkwardTokenizable: + + def __init__(self, ret_val, parent_name): + self.parent_name = parent_name + self.ret_val = ret_val + + def __dask_tokenize__(self): + return ("AwkwardTokenizable", self.parent_name) + + def __call__(self, *_, **__): + return self.ret_val + + class AwkwardInputLayer(AwkwardBlockwiseLayer): """A layer known to perform IO and produce Awkward arrays @@ -230,7 +243,7 @@ def prepare_for_projection(self) -> tuple[AwkwardInputLayer, TypeTracerReport, T new_input_layer = AwkwardInputLayer( name=self.name, inputs=[None][: int(list(self.numblocks.values())[0][0])], - io_func=lambda *_, **__: new_return, + io_func=AwkwardTokenizable(new_return, self.name), label=self.label, produces_tasks=self.produces_tasks, creation_info=self.creation_info, diff --git a/tests/test_inspect.py b/tests/test_inspect.py index 6ffc7ed9..b4e2651b 100644 --- a/tests/test_inspect.py +++ b/tests/test_inspect.py @@ -97,6 +97,8 @@ def test_basic_root_works(): .fill(events.MET_pt) ) + columns = list(dak.necessary_columns(q1_hist).values())[0] + assert columns == frozenset({"MET_pt"}) dask.compute(q1_hist)