From cda823c32891bf1728339d5f159cc50270543b72 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Wed, 21 Feb 2024 10:32:37 -0500 Subject: [PATCH 1/9] Sketch --- src/dask_awkward/lib/core.py | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index a6111867..9fda855d 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -875,8 +875,27 @@ def __init__( ) -> None: self._dask: HighLevelGraph = dsk self._name: str = name + self._queued = None self._divisions: tuple[int, ...] | tuple[None, ...] = divisions - self._meta: ak.Array = meta + self._base_meta: ak.Array = meta + self._getitem_staged = () + + @property + def _meta(self): + if self._getitem_staged: + newobj = self._getitem_trivial_map_partitions( + self._getitem_staged, execute=True + ) + self._getitem_staged = () + self._meta = newobj._meta + self._dask = newobj.dask + return self._base_meta + + @_meta.setter + def _meta(self, meta): + if self._getitem_staged: + raise ValueError("Cannot set _meta with staged getitems (internal)") + self._base_meta = meta def __dask_graph__(self) -> HighLevelGraph: return self.dask @@ -1045,6 +1064,8 @@ def __reduce__(self): @property def dask(self) -> HighLevelGraph: """High level task graph associated with the collection.""" + if self._getitem_staged is not None: + self._meta return self._dask @property @@ -1204,14 +1225,21 @@ def _getitem_trivial_map_partitions( where: Any, meta: Any | None = None, label: str | None = None, + execute: bool = False, ) -> Any: - if meta is None and self._meta is not None: + import copy + + if not execute: + newobj = copy.copy(self) # shallow/fast + newobj._getitem_staged = where + return newobj + if meta is None and self._base_meta is not None: if isinstance(where, tuple): metad = to_meta(where) - meta = self._meta[metad] + meta = self._base_meta[metad] else: m = to_meta([where])[0] - meta = self._meta[m] + meta = self._base_meta[m] return map_partitions( operator.getitem, self, From b0015d3ac54f6de6a4e42e08951989ed0554a0d6 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Wed, 21 Feb 2024 10:34:02 -0500 Subject: [PATCH 2/9] remove atr --- src/dask_awkward/lib/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index 9fda855d..8a56b48c 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -875,7 +875,6 @@ def __init__( ) -> None: self._dask: HighLevelGraph = dsk self._name: str = name - self._queued = None self._divisions: tuple[int, ...] | tuple[None, ...] = divisions self._base_meta: ak.Array = meta self._getitem_staged = () From 0f0ca22339262a04a5a96bcbc92b7e26fd7b2c8f Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Wed, 21 Feb 2024 10:37:21 -0500 Subject: [PATCH 3/9] chain --- src/dask_awkward/lib/core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index 8a56b48c..c575ab7b 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -1230,7 +1230,8 @@ def _getitem_trivial_map_partitions( if not execute: newobj = copy.copy(self) # shallow/fast - newobj._getitem_staged = where + where = where if isinstance(where, tuple) else (where,) + newobj._getitem_staged = self._getitem_staged + where return newobj if meta is None and self._base_meta is not None: if isinstance(where, tuple): From 711868a417fa4c738bf1c1810165e8aae42474d4 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Fri, 23 Feb 2024 11:46:51 -0500 Subject: [PATCH 4/9] simple cache --- pyproject.toml | 3 +- src/dask_awkward/lib/core.py | 194 ++++++++++++++++------------------- 2 files changed, 89 insertions(+), 108 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9eea6a52..d92511d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -128,7 +128,8 @@ warn_unreachable = true "pyarrow.*", "tlz.*", "uproot.*", - "cloudpickle.*" + "cloudpickle.*", + "cachetools.*" ] ignore_missing_imports = true diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index c575ab7b..dbd32052 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload import awkward as ak +import cachetools import dask.config import numpy as np from awkward._do import remove_structure as ak_do_remove_structure @@ -439,6 +440,7 @@ def __repr__(self) -> str: # pragma: no cover return self.__str__() def __str__(self) -> str: + self._meta # force updating from any staged ops if self.known_value is not None: return ( f"dask.awkward<{key_split(self.name)}, " @@ -854,6 +856,9 @@ def _finalize_array(results: Sequence[Any]) -> Any: raise RuntimeError(msg) +dak_cache = cachetools.LRUCache(maxsize=1000) + + class Array(DaskMethodsMixin, NDArrayOperatorsMixin): """Partitioned, lazy, and parallel Awkward Array Dask collection. @@ -876,25 +881,7 @@ def __init__( self._dask: HighLevelGraph = dsk self._name: str = name self._divisions: tuple[int, ...] | tuple[None, ...] = divisions - self._base_meta: ak.Array = meta - self._getitem_staged = () - - @property - def _meta(self): - if self._getitem_staged: - newobj = self._getitem_trivial_map_partitions( - self._getitem_staged, execute=True - ) - self._getitem_staged = () - self._meta = newobj._meta - self._dask = newobj.dask - return self._base_meta - - @_meta.setter - def _meta(self, meta): - if self._getitem_staged: - raise ValueError("Cannot set _meta with staged getitems (internal)") - self._base_meta = meta + self._meta: ak.Array = meta def __dask_graph__(self) -> HighLevelGraph: return self.dask @@ -1063,8 +1050,6 @@ def __reduce__(self): @property def dask(self) -> HighLevelGraph: """High level task graph associated with the collection.""" - if self._getitem_staged is not None: - self._meta return self._dask @property @@ -1135,7 +1120,7 @@ def mask(self) -> AwkwardMask: @property def fields(self) -> list[str]: """Record field names (if any).""" - return ak.fields(self._meta) + return getattr(self._meta, "fields", None) or [] @property def form(self) -> Form: @@ -1224,22 +1209,14 @@ def _getitem_trivial_map_partitions( where: Any, meta: Any | None = None, label: str | None = None, - execute: bool = False, ) -> Any: - import copy - - if not execute: - newobj = copy.copy(self) # shallow/fast - where = where if isinstance(where, tuple) else (where,) - newobj._getitem_staged = self._getitem_staged + where - return newobj - if meta is None and self._base_meta is not None: + if meta is None and self._meta is not None: if isinstance(where, tuple): metad = to_meta(where) - meta = self._base_meta[metad] + meta = self._meta[metad] else: m = to_meta([where])[0] - meta = self._base_meta[m] + meta = self._meta[m] return map_partitions( operator.getitem, self, @@ -1999,96 +1976,99 @@ def map_partitions( This is effectively the same as `d = c * a` """ - opt_touch_all = kwargs.pop("opt_touch_all", None) - if opt_touch_all is not None: - warnings.warn( - "The opt_touch_all argument does nothing.\n" - "This warning will be removed in a future version of dask-awkward " - "and the function call will likely fail." - ) - - token = token or tokenize(base_fn, *args, meta, **kwargs) + token = token or tokenize( + base_fn, *args, meta is not None and meta.typestr, **kwargs + ) label = hyphenize(label or funcname(base_fn)) name = f"{label}-{token}" - kwarg_flat_deps, kwarg_repacker = unpack_collections(kwargs, traverse=traverse) - flat_deps, _ = unpack_collections(*args, *kwargs.values(), traverse=traverse) - - if len(flat_deps) == 0: - message = ( - "map_partitions expects at least one Dask collection instance, " - "you are passing non-Dask collections to dask-awkward code.\n" - "observed argument types:\n" - ) - for arg in args: - message += f"- {type(arg)}" - raise TypeError(message) + if name in dak_cache: + (hlg, meta, in_divisions, in_npartitions) = dak_cache[name] + else: + opt_touch_all = kwargs.pop("opt_touch_all", None) + if opt_touch_all is not None: + warnings.warn( + "The opt_touch_all argument does nothing.\n" + "This warning will be removed in a future version of dask-awkward " + "and the function call will likely fail." + ) - arg_flat_deps_expanded = [] - arg_repackers = [] - arg_lens_for_repackers = [] - for arg in args: - this_arg_flat_deps, repacker = unpack_collections(arg, traverse=traverse) - if ( - len(this_arg_flat_deps) > 0 - ): # if the deps list is empty this arg does not contain any dask collection, no need to repack! - arg_flat_deps_expanded.extend(this_arg_flat_deps) - arg_repackers.append(repacker) - arg_lens_for_repackers.append(len(this_arg_flat_deps)) - else: - arg_flat_deps_expanded.append(arg) - arg_repackers.append(None) - arg_lens_for_repackers.append(1) - - fn = ArgsKwargsPackedFunction( - base_fn, - arg_repackers, - kwarg_repacker, - arg_lens_for_repackers, - ) + kwarg_flat_deps, kwarg_repacker = unpack_collections(kwargs, traverse=traverse) + flat_deps, _ = unpack_collections(*args, *kwargs.values(), traverse=traverse) - lay = partitionwise_layer( - fn, - name, - *arg_flat_deps_expanded, - *kwarg_flat_deps, - ) + if len(flat_deps) == 0: + message = ( + "map_partitions expects at least one Dask collection instance, " + "you are passing non-Dask collections to dask-awkward code.\n" + "observed argument types:\n" + ) + for arg in args: + message += f"- {type(arg)}" + raise TypeError(message) - if meta is None: - meta = map_meta(fn, *arg_flat_deps_expanded, *kwarg_flat_deps) + arg_flat_deps_expanded = [] + arg_repackers = [] + arg_lens_for_repackers = [] + for arg in args: + this_arg_flat_deps, repacker = unpack_collections(arg, traverse=traverse) + if ( + len(this_arg_flat_deps) > 0 + ): # if the deps list is empty this arg does not contain any dask collection, no need to repack! + arg_flat_deps_expanded.extend(this_arg_flat_deps) + arg_repackers.append(repacker) + arg_lens_for_repackers.append(len(this_arg_flat_deps)) + else: + arg_flat_deps_expanded.append(arg) + arg_repackers.append(None) + arg_lens_for_repackers.append(1) + + fn = ArgsKwargsPackedFunction( + base_fn, + arg_repackers, + kwarg_repacker, + arg_lens_for_repackers, + ) - hlg = HighLevelGraph.from_collections( - name, - lay, - dependencies=flat_deps, - ) + lay = partitionwise_layer( + fn, + name, + *arg_flat_deps_expanded, + *kwarg_flat_deps, + ) - dak_arrays = tuple(filter(lambda x: isinstance(x, Array), flat_deps)) - if len(dak_arrays) == 0: - raise TypeError( - "at least one argument passed to map_partitions " - "should be a dask_awkward.Array collection." + if meta is None: + meta = map_meta(fn, *arg_flat_deps_expanded, *kwarg_flat_deps) + + hlg = HighLevelGraph.from_collections( + name, + lay, + dependencies=flat_deps, ) - in_npartitions = dak_arrays[0].npartitions - in_divisions = dak_arrays[0].divisions + + dak_arrays = tuple(filter(lambda x: isinstance(x, Array), flat_deps)) + if len(dak_arrays) == 0: + raise TypeError( + "at least one argument passed to map_partitions " + "should be a dask_awkward.Array collection." + ) + in_npartitions = dak_arrays[0].npartitions + in_divisions = dak_arrays[0].divisions + + if output_divisions is not None: + if output_divisions == 1: + in_divisions = flat_deps[0].divisions + else: + in_divisions = tuple(map(lambda x: x * output_divisions, in_divisions)) + dak_cache[name] = (hlg, meta, in_divisions, in_npartitions) if output_divisions is not None: - if output_divisions == 1: - new_divisions = flat_deps[0].divisions - else: - new_divisions = tuple(map(lambda x: x * output_divisions, in_divisions)) return new_array_object( hlg, name=name, meta=meta, - divisions=new_divisions, + divisions=in_divisions, ) else: - return new_array_object( - hlg, - name=name, - meta=meta, - npartitions=in_npartitions, - ) + return new_array_object(hlg, name=name, meta=meta, npartitions=in_npartitions) def _chunk_reducer_non_positional( From 14292d7e7ab43e7be75bc1cecbefa3eb661ab270 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Fri, 23 Feb 2024 12:49:35 -0500 Subject: [PATCH 5/9] add cachetools dep --- pyproject.toml | 1 + src/dask_awkward/lib/core.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d92511d2..0f1732a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ classifiers = [ dependencies = [ "awkward >=2.5.1", "dask >=2023.04.0", + "cacheools", "typing_extensions >=4.8.0", ] dynamic = ["version"] diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index dbd32052..5fac7c2c 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -755,7 +755,7 @@ def __reduce__(self): def fields(self) -> list[str]: if self._meta is None: raise TypeError("metadata is missing; cannot determine fields.") - return ak.fields(self._meta) + return getattr(self._meta, "fields", None) or [] @property def layout(self) -> Any: From 5e379ff35e2d9779cb4dc37f70f9bb65040cf48d Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Fri, 23 Feb 2024 13:10:43 -0500 Subject: [PATCH 6/9] remove spare line --- src/dask_awkward/lib/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index 5fac7c2c..057087f8 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -440,7 +440,6 @@ def __repr__(self) -> str: # pragma: no cover return self.__str__() def __str__(self) -> str: - self._meta # force updating from any staged ops if self.known_value is not None: return ( f"dask.awkward<{key_split(self.name)}, " From e5ac2be0c653a44e85a3a19e5d95b7a96b4eac98 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Fri, 23 Feb 2024 13:11:20 -0500 Subject: [PATCH 7/9] sp --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0f1732a9..6c57934a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ classifiers = [ dependencies = [ "awkward >=2.5.1", "dask >=2023.04.0", - "cacheools", + "cachetools", "typing_extensions >=4.8.0", ] dynamic = ["version"] From 9e396118a91ffac84235d56cc87f45b0f70eb3e9 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Sun, 3 Mar 2024 13:31:38 -0500 Subject: [PATCH 8/9] simplify fix --- src/dask_awkward/lib/core.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index ff992f63..5202eabb 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -1924,17 +1924,14 @@ def _map_partitions( token = token or tokenize(fn, *args, meta is not None and meta.typestr, **kwargs) label = hyphenize(label or funcname(fn)) name = f"{label}-{token}" + deps = [a for a in args if is_dask_collection(a)] + [ + v for v in kwargs.values() if is_dask_collection(v) + ] + dak_arrays = tuple(filter(lambda x: isinstance(x, Array), deps)) if name in dak_cache: - (hlg, meta, new_divisions, in_npartitions) = dak_cache[name] + hlg, meta = dak_cache[name] else: - - deps = [a for a in args if is_dask_collection(a)] + [ - v for v in kwargs.values() if is_dask_collection(v) - ] - - dak_arrays = tuple(filter(lambda x: isinstance(x, Array), deps)) - lay = partitionwise_layer( fn, name, @@ -1956,16 +1953,16 @@ def _map_partitions( "at least one argument passed to map_partitions " "should be a dask_awkward.Array collection." ) - in_npartitions = dak_arrays[0].npartitions - in_divisions = dak_arrays[0].divisions - if output_divisions is not None: - if output_divisions == 1: - new_divisions = dak_arrays[0].divisions - else: - new_divisions = tuple(map(lambda x: x * output_divisions, in_divisions)) + dak_cache[name] = hlg, meta + in_npartitions = dak_arrays[0].npartitions + in_divisions = dak_arrays[0].divisions + if output_divisions is not None: + if output_divisions == 1: + new_divisions = dak_arrays[0].divisions else: - new_divisions = in_divisions - dak_cache[name] = (hlg, meta, new_divisions, in_npartitions) + new_divisions = tuple(map(lambda x: x * output_divisions, in_divisions)) + else: + new_divisions = in_divisions if output_divisions is not None: return new_array_object( From 6e5a2ee630202a09ea7e235f2b80937c9d3dfc4a Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Mon, 4 Mar 2024 09:15:10 -0500 Subject: [PATCH 9/9] Don't tokenize on meta --- src/dask_awkward/lib/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index 5202eabb..592d27aa 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -1921,7 +1921,7 @@ def _map_partitions( will not be traversed to extract all dask collections, except those in the first dimension of args or kwargs. """ - token = token or tokenize(fn, *args, meta is not None and meta.typestr, **kwargs) + token = token or tokenize(fn, *args, output_divisions, **kwargs) label = hyphenize(label or funcname(fn)) name = f"{label}-{token}" deps = [a for a in args if is_dask_collection(a)] + [