From 92d887b4bb96381eccfe746b19588ad2bdbbfc55 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Aug 2024 21:51:16 +0000 Subject: [PATCH 1/7] [pre-commit.ci] pre-commit autoupdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.5.7 → v0.6.1](https://github.com/astral-sh/ruff-pre-commit/compare/v0.5.7...v0.6.1) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 76f38f41..55ef9654 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,7 +24,7 @@ repos: - --target-version=py312 - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.7 + rev: v0.6.1 hooks: - id: ruff From d1e67a4992652e6ce92b0fc77dec214f492b85af Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Tue, 20 Aug 2024 10:01:55 -0400 Subject: [PATCH 2/7] fixes --- docs/examples/20211111.ipynb | 3 +- docs/examples/io-tutorial/io-00-basic.ipynb | 100 +++++++----------- .../examples/io-tutorial/io-01-advanced.ipynb | 5 +- tests/test_core.py | 10 +- 4 files changed, 46 insertions(+), 72 deletions(-) diff --git a/docs/examples/20211111.ipynb b/docs/examples/20211111.ipynb index 43d6e4be..12e2034e 100644 --- a/docs/examples/20211111.ipynb +++ b/docs/examples/20211111.ipynb @@ -21,7 +21,6 @@ "metadata": {}, "outputs": [], "source": [ - "import dask_awkward as dak\n", "import dask_awkward.data as dakd" ] }, @@ -408,7 +407,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.8" + "version": "3.10.9" } }, "nbformat": 4, diff --git a/docs/examples/io-tutorial/io-00-basic.ipynb b/docs/examples/io-tutorial/io-00-basic.ipynb index b3e98b49..5356c60c 100644 --- a/docs/examples/io-tutorial/io-00-basic.ipynb +++ b/docs/examples/io-tutorial/io-00-basic.ipynb @@ -23,47 +23,21 @@ "metadata": {}, "outputs": [ { - "data": { - "text/html": [ - "
package versions:\n",
-       "
\n" - ], - "text/plain": [ - "package versions:\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
awkward:       2.5.2\n",
-       "
\n" - ], - "text/plain": [ - "awkward: \u001b[1;36m2.5\u001b[0m.\u001b[1;36m2\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
dask-awkward:  2024.1.3.dev6+gda37bea\n",
-       "
\n" - ], - "text/plain": [ - "dask-awkward: \u001b[1;36m2024.1\u001b[0m.\u001b[1;36m3.\u001b[0mdev6+gda37bea\n" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stdout", + "output_type": "stream", + "text": [ + "package versions:\n", + "awkward: 2.6.7\n", + "dask-awkward: 2024.3.1.dev50+gb593f87.d20240522\n" + ] } ], "source": [ "from __future__ import annotations\n", + "\n", + "import os\n", + "\n", + "import numpy as np\n", "import awkward\n", "import dask_awkward\n", "print(\"package versions:\")\n", @@ -345,13 +319,13 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "File \u001b[0;32m~/software/repos/dask-awkward/src/dask_awkward/lib/core.py:1508\u001b[0m, in \u001b[0;36mArray.__getattr__\u001b[0;34m(self, attr)\u001b[0m\n\u001b[1;32m 1507\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1508\u001b[0m cls_method \u001b[38;5;241m=\u001b[39m \u001b[43mgetattr_static\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_meta\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattr\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m:\n", - "File \u001b[0;32m~/.pyenv/versions/3.11.7/lib/python3.11/inspect.py:1853\u001b[0m, in \u001b[0;36mgetattr_static\u001b[0;34m(obj, attr, default)\u001b[0m\n\u001b[1;32m 1852\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m default\n\u001b[0;32m-> 1853\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(attr)\n", + "File \u001b[0;32m~/code/dask-awkward/src/dask_awkward/lib/core.py:1578\u001b[0m, in \u001b[0;36mArray.__getattr__\u001b[0;34m(self, attr)\u001b[0m\n\u001b[1;32m 1577\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1578\u001b[0m cls_method \u001b[38;5;241m=\u001b[39m \u001b[43mgetattr_static\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_meta\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattr\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1579\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m:\n", + "File \u001b[0;32m~/conda/envs/py310/lib/python3.10/inspect.py:1777\u001b[0m, in \u001b[0;36mgetattr_static\u001b[0;34m(obj, attr, default)\u001b[0m\n\u001b[1;32m 1776\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m default\n\u001b[0;32m-> 1777\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(attr)\n", "\u001b[0;31mAttributeError\u001b[0m: distance", "\nDuring handling of the above exception, another exception occurred:\n", "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[9], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mdataset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscoring\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdistance\u001b[49m\n", - "File \u001b[0;32m~/software/repos/dask-awkward/src/dask_awkward/lib/core.py:1510\u001b[0m, in \u001b[0;36mArray.__getattr__\u001b[0;34m(self, attr)\u001b[0m\n\u001b[1;32m 1508\u001b[0m cls_method \u001b[38;5;241m=\u001b[39m getattr_static(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_meta, attr)\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m:\n\u001b[0;32m-> 1510\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mattr\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m not in fields.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1511\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1512\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(cls_method, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_dask_get\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", + "File \u001b[0;32m~/code/dask-awkward/src/dask_awkward/lib/core.py:1580\u001b[0m, in \u001b[0;36mArray.__getattr__\u001b[0;34m(self, attr)\u001b[0m\n\u001b[1;32m 1578\u001b[0m cls_method \u001b[38;5;241m=\u001b[39m getattr_static(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_meta, attr)\n\u001b[1;32m 1579\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m:\n\u001b[0;32m-> 1580\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mattr\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m not in fields.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1581\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1582\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(cls_method, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_dask_get\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", "\u001b[0;31mAttributeError\u001b[0m: distance not in fields." ] } @@ -370,7 +344,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 10, "id": "63d11eb1-f822-4fc9-ae0b-c10fb6c8ea32", "metadata": {}, "outputs": [], @@ -390,7 +364,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 11, "id": "ad88a084-6d83-4eb7-a4a8-7befe58543d5", "metadata": {}, "outputs": [], @@ -400,7 +374,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 12, "id": "532776bb-0789-45d0-9bd8-d108d5143f1a", "metadata": { "scrolled": true @@ -412,7 +386,7 @@ "dask.awkward" ] }, - "execution_count": 16, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -431,7 +405,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 13, "id": "fa5d00ee-2ec1-455e-b0e8-4c64f6e8d36a", "metadata": {}, "outputs": [], @@ -449,7 +423,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 14, "id": "a59ea8ad-8ca6-444c-86cd-a4a4d9fc853d", "metadata": {}, "outputs": [], @@ -459,7 +433,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 15, "id": "60ef65b6-793e-40df-b2fa-f8c74b2ee8d0", "metadata": {}, "outputs": [ @@ -469,7 +443,7 @@ "dask.awkward" ] }, - "execution_count": 19, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -505,14 +479,14 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 16, "id": "62fcb593-63d5-4444-9d26-d0e23258f501", "metadata": {}, "outputs": [], "source": [ "dataset = dak.from_parquet(pq_dir)\n", "free_throws = dak.str.match_substring(dataset.scoring.basket, \"freethrow\")\n", - "distances = dataset.scoring.distance[free_throws == False]\n", + "distances = dataset.scoring.distance[np.equal(free_throws, False)]\n", "result = dak.mean(distances, axis=1)" ] }, @@ -526,7 +500,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 17, "id": "840b3ebe-1454-4dca-bee0-50a31f9c0df8", "metadata": { "scrolled": true @@ -535,11 +509,11 @@ { "data": { "text/plain": [ - "{'from-parquet-b7916bd949c3744cf0ec38dea00d0bd6': frozenset({'scoring.basket',\n", + "{'from-parquet-ab79c1929a2f8819e9ef6b725d844f8b': frozenset({'scoring.basket',\n", " 'scoring.distance'})}" ] }, - "execution_count": 21, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -562,20 +536,20 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 18, "id": "13914bf9-1f45-4860-8dc7-ec8eeb746bc0", "metadata": {}, "outputs": [], "source": [ "dataset = dak.from_json(os.path.join(\"data\", \"json\"))\n", "free_throws = dak.str.match_substring(dataset.scoring.basket, \"freethrow\")\n", - "distances = dataset.scoring.distance[free_throws == False]\n", + "distances = dataset.scoring.distance[np.equal(free_throws, False)]\n", "result = dak.mean(distances, axis=1)" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 19, "id": "1bc6e94b-ee5e-42d1-b789-6f80859b1d64", "metadata": { "scrolled": true @@ -584,11 +558,11 @@ { "data": { "text/plain": [ - "{'from-json-files-6eebaf87f3a09a08c1234137dd381b61': frozenset({'scoring.basket',\n", + "{'from-json-files-3542a860e83d7f93e632ec19911d7030': frozenset({'scoring.basket',\n", " 'scoring.distance'})}" ] }, - "execution_count": 23, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -611,7 +585,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 20, "id": "e444fa35-03ee-4292-8730-490dacd145fb", "metadata": {}, "outputs": [], @@ -626,7 +600,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 21, "id": "146a84b4-26ce-45c5-ad16-9c8967b60214", "metadata": {}, "outputs": [ @@ -642,7 +616,7 @@ " 'distance': {'type': 'number'}}}}}}" ] }, - "execution_count": 25, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -661,7 +635,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 22, "id": "f2d75df4-c8a7-4abd-942c-f1e94c124ec7", "metadata": {}, "outputs": [], @@ -697,7 +671,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.10.9" } }, "nbformat": 4, diff --git a/docs/examples/io-tutorial/io-01-advanced.ipynb b/docs/examples/io-tutorial/io-01-advanced.ipynb index 509f37ee..938c96ab 100644 --- a/docs/examples/io-tutorial/io-01-advanced.ipynb +++ b/docs/examples/io-tutorial/io-01-advanced.ipynb @@ -47,6 +47,7 @@ "source": [ "from __future__ import annotations\n", "\n", + "import os\n", "from typing import Any\n", "\n", "import awkward as ak\n", @@ -57,7 +58,7 @@ "class Ignore0ParquetReader(ColumnProjectionMixin):\n", " def __init__(\n", " self,\n", - " form: Form,\n", + " form: ak.forms.Form,\n", " report: bool = False,\n", " allowed_exceptions: tuple[type[BaseException], ...] = (OSError,),\n", " columns: list[str] | None = None,\n", @@ -347,7 +348,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.10.9" } }, "nbformat": 4, diff --git a/tests/test_core.py b/tests/test_core.py index 8dd82b9f..08a87def 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -384,13 +384,13 @@ def test_to_meta(daa: Array) -> None: def test_record_str(daa: Array) -> None: r = daa[0] - assert type(r) == dak.Record + assert type(r) is dak.Record assert str(r) == "dask.awkward" def test_record_to_delayed(daa: Array) -> None: r = daa[0] - assert type(r) == dak.Record + assert type(r) is dak.Record d = r.to_delayed() x = r.compute().tolist() y = d.compute().tolist() @@ -399,7 +399,7 @@ def test_record_to_delayed(daa: Array) -> None: def test_record_fields(daa: Array) -> None: r = daa[0] - assert type(r) == dak.Record + assert type(r) is dak.Record r._meta = None with pytest.raises(TypeError, match="metadata is missing"): assert not r.fields @@ -407,7 +407,7 @@ def test_record_fields(daa: Array) -> None: def test_record_dir(daa: Array) -> None: r = daa["points"][0][0] - assert type(r) == dak.Record + assert type(r) is dak.Record d = dir(r) for f in r.fields: assert f in d @@ -418,7 +418,7 @@ def test_record_dir(daa: Array) -> None: # import pickle # r = daa[0] -# assert type(r) == dak.Record +# assert type(r) is dak.Record # assert isinstance(r._meta, ak.Record) # dumped = pickle.dumps(r) From 838ba31af57302b24d922d4033178e2d63f84deb Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Tue, 20 Aug 2024 10:58:32 -0400 Subject: [PATCH 3/7] Fix lint --- src/dask_awkward/lib/core.py | 56 ++++++++++++++------------------ src/dask_awkward/lib/optimize.py | 24 +++++++------- 2 files changed, 36 insertions(+), 44 deletions(-) diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index e52e43af..19ae2dde 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -2267,50 +2267,42 @@ def non_trivial_reduction( else: prepared_array = array - chunked_fn = _chunk_reducer_non_positional - tree_node_fn = _chunk_reducer_non_positional - concat_fn = _concat_reducer_non_positional - finalize_fn = _finalise_reducer_non_positional - - chunked_kwargs = { - "reducer": reducer, - "is_axis_none": axis is None, - "mask_identity": mask_identity, - } - tree_node_kwargs = { - "reducer": combiner, - "is_axis_none": axis is None, - "mask_identity": mask_identity, - } - - concat_kwargs = {"is_axis_none": axis is None} - finalize_kwargs = { - "reducer": combiner, - "mask_identity": mask_identity, - "keepdims": keepdims, - "is_axis_none": axis is None, - } - from dask_awkward.layers import AwkwardTreeReductionLayer token = token or tokenize( array, reducer, + combiner, label, dtype, split_every, - chunked_kwargs, - tree_node_kwargs, - concat_kwargs, - finalize_kwargs, + axis, + mask_identity, + keepdims, ) name_tree_node = f"{label}-tree-node-{token}" name_finalize = f"{label}-finalize-{token}" - chunked_fn = partial(chunked_fn, **chunked_kwargs) - tree_node_fn = partial(tree_node_fn, **tree_node_kwargs) - concat_fn = partial(concat_fn, **concat_kwargs) - finalize_fn = partial(finalize_fn, **finalize_kwargs) + chunked_fn = partial( + _chunk_reducer_non_positional, + reducer=reducer, + is_axis_none=axis is None, + mask_identity=mask_identity, + ) + tree_node_fn = partial( + _chunk_reducer_non_positional, + reducer=combiner, + is_axis_none=axis is None, + mask_identity=mask_identity, + ) + concat_fn = partial(_concat_reducer_non_positional, is_axis_none=axis is None) + finalize_fn = partial( + _finalise_reducer_non_positional, + reducer=combiner, + is_axis_none=axis is None, + keepdims=keepdims, + mask_identity=mask_identity, + ) if split_every is None: split_every = dask.config.get("awkward.aggregation.split-every", 8) diff --git a/src/dask_awkward/lib/optimize.py b/src/dask_awkward/lib/optimize.py index a1ab9d6a..d8691195 100644 --- a/src/dask_awkward/lib/optimize.py +++ b/src/dask_awkward/lib/optimize.py @@ -335,32 +335,32 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG outkey = chain[-1] layer0 = dsk.layers[chain[0]] outlayer = layers[outkey] - numblocks = [nb[0] for nb in layer0.numblocks.values() if nb[0] is not None][0] # type: ignore + numblocks = [nb[0] for nb in layer0.numblocks.values() if nb[0] is not None][0] deps[outkey] = deps[chain[0]] # type: ignore [deps.pop(ch) for ch in chain[:-1]] # type: ignore - subgraph = layer0.dsk.copy() # type: ignore - indices = list(layer0.indices) # type: ignore + subgraph = layer0.dsk.copy() + indices = list(layer0.indices) parent = chain[0] - outlayer.io_deps = layer0.io_deps # type: ignore + outlayer.io_deps = layer0.io_deps for chain_member in chain[1:]: layer = dsk.layers[chain_member] - for k in layer.io_deps: # type: ignore - outlayer.io_deps[k] = layer.io_deps[k] # type: ignore - func, *args = layer.dsk[chain_member] # type: ignore + for k in layer.io_deps: + outlayer.io_deps[k] = layer.io_deps[k] + func, *args = layer.dsk[chain_member] args2 = _recursive_replace(args, layer, parent, indices) subgraph[chain_member] = (func,) + tuple(args2) parent = chain_member - outlayer.numblocks = {i[0]: (numblocks,) for i in indices if i[1] is not None} # type: ignore - outlayer.dsk = subgraph # type: ignore + outlayer.numblocks = {i[0]: (numblocks,) for i in indices if i[1] is not None} + outlayer.dsk = subgraph if hasattr(outlayer, "_dims"): del outlayer._dims - outlayer.indices = tuple( # type: ignore + outlayer.indices = tuple( (i[0], (".0",) if i[1] is not None else None) for i in indices ) - outlayer.output_indices = (".0",) # type: ignore - outlayer.inputs = getattr(layer0, "inputs", set()) # type: ignore + outlayer.output_indices = (".0",) + outlayer.inputs = getattr(layer0, "inputs", set()) if hasattr(outlayer, "_cached_dict"): del outlayer._cached_dict # reset, since original can be mutated return HighLevelGraph(layers, deps) From a11da62d0d6f834ff4c8a3bed1fc369484a11d86 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Tue, 20 Aug 2024 11:08:08 -0400 Subject: [PATCH 4/7] try again --- src/dask_awkward/lib/optimize.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dask_awkward/lib/optimize.py b/src/dask_awkward/lib/optimize.py index d8691195..d6c1c08a 100644 --- a/src/dask_awkward/lib/optimize.py +++ b/src/dask_awkward/lib/optimize.py @@ -8,7 +8,7 @@ import dask.config from awkward.typetracer import touch_data -from dask.blockwise import fuse_roots, optimize_blockwise +from dask.blockwise import Blockwise, fuse_roots, optimize_blockwise from dask.core import flatten from dask.highlevelgraph import HighLevelGraph from dask.local import get_sync @@ -333,7 +333,7 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG # outputs are the outputs of chain[-1] # .dsk is composed from the .dsk of each layer outkey = chain[-1] - layer0 = dsk.layers[chain[0]] + layer0 = cast(Blockwise, dsk.layers[chain[0]]) outlayer = layers[outkey] numblocks = [nb[0] for nb in layer0.numblocks.values() if nb[0] is not None][0] deps[outkey] = deps[chain[0]] # type: ignore @@ -347,7 +347,7 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG for chain_member in chain[1:]: layer = dsk.layers[chain_member] for k in layer.io_deps: - outlayer.io_deps[k] = layer.io_deps[k] + outlayer.io_deps[k] = layer.io_deps[k] # type: ignore func, *args = layer.dsk[chain_member] args2 = _recursive_replace(args, layer, parent, indices) subgraph[chain_member] = (func,) + tuple(args2) From 38e95c5e83bc211485aa1aab33e84bd557359954 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Wed, 21 Aug 2024 10:18:07 -0400 Subject: [PATCH 5/7] grr --- src/dask_awkward/lib/optimize.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/dask_awkward/lib/optimize.py b/src/dask_awkward/lib/optimize.py index d6c1c08a..3a14e8ae 100644 --- a/src/dask_awkward/lib/optimize.py +++ b/src/dask_awkward/lib/optimize.py @@ -339,28 +339,30 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG deps[outkey] = deps[chain[0]] # type: ignore [deps.pop(ch) for ch in chain[:-1]] # type: ignore - subgraph = layer0.dsk.copy() + subgraph = layer0.dsk.copy() # mypy: ignore indices = list(layer0.indices) parent = chain[0] - outlayer.io_deps = layer0.io_deps + outlayer.io_deps = layer0.io_deps # mypy: ignore for chain_member in chain[1:]: layer = dsk.layers[chain_member] - for k in layer.io_deps: + for k in layer.io_deps: # mypy: ignore outlayer.io_deps[k] = layer.io_deps[k] # type: ignore - func, *args = layer.dsk[chain_member] + func, *args = layer.dsk[chain_member] # mypy: ignore args2 = _recursive_replace(args, layer, parent, indices) subgraph[chain_member] = (func,) + tuple(args2) parent = chain_member - outlayer.numblocks = {i[0]: (numblocks,) for i in indices if i[1] is not None} - outlayer.dsk = subgraph + outlayer.numblocks = { + i[0]: (numblocks,) for i in indices if i[1] is not None + } # mypy: ignore + outlayer.dsk = subgraph # mypy: ignore if hasattr(outlayer, "_dims"): del outlayer._dims - outlayer.indices = tuple( + outlayer.indices = tuple( # mypy: ignore (i[0], (".0",) if i[1] is not None else None) for i in indices ) - outlayer.output_indices = (".0",) - outlayer.inputs = getattr(layer0, "inputs", set()) + outlayer.output_indices = (".0",) # mypy: ignore + outlayer.inputs = getattr(layer0, "inputs", set()) # mypy: ignore if hasattr(outlayer, "_cached_dict"): del outlayer._cached_dict # reset, since original can be mutated return HighLevelGraph(layers, deps) From 51864e87f1772de846a7a342d211650230351eb1 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Wed, 21 Aug 2024 11:01:38 -0400 Subject: [PATCH 6/7] kick From dbd496431448fa8846ab7e3decf1e9587e7889a9 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Wed, 21 Aug 2024 11:13:15 -0400 Subject: [PATCH 7/7] skip whole function --- src/dask_awkward/lib/optimize.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/dask_awkward/lib/optimize.py b/src/dask_awkward/lib/optimize.py index 3a14e8ae..6ad2e132 100644 --- a/src/dask_awkward/lib/optimize.py +++ b/src/dask_awkward/lib/optimize.py @@ -4,7 +4,7 @@ import logging import warnings from collections.abc import Iterable, Mapping, Sequence -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, cast, no_type_check import dask.config from awkward.typetracer import touch_data @@ -244,6 +244,7 @@ def _mock_output(layer): return new_layer +@no_type_check def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelGraph: """Smush chains of blockwise layers into a single layer. @@ -336,8 +337,8 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG layer0 = cast(Blockwise, dsk.layers[chain[0]]) outlayer = layers[outkey] numblocks = [nb[0] for nb in layer0.numblocks.values() if nb[0] is not None][0] - deps[outkey] = deps[chain[0]] # type: ignore - [deps.pop(ch) for ch in chain[:-1]] # type: ignore + deps[outkey] = deps[chain[0]] + [deps.pop(ch) for ch in chain[:-1]] subgraph = layer0.dsk.copy() # mypy: ignore indices = list(layer0.indices) @@ -347,7 +348,7 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG for chain_member in chain[1:]: layer = dsk.layers[chain_member] for k in layer.io_deps: # mypy: ignore - outlayer.io_deps[k] = layer.io_deps[k] # type: ignore + outlayer.io_deps[k] = layer.io_deps[k] func, *args = layer.dsk[chain_member] # mypy: ignore args2 = _recursive_replace(args, layer, parent, indices) subgraph[chain_member] = (func,) + tuple(args2)