From 0e411aef4ebd1981814d10f9b1525a6291f79de5 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 5 Dec 2023 15:02:09 -0800 Subject: [PATCH 01/19] Split out graph Expr code from Dataframe Expr code Closes #142 Supercedes #158 --- dask_expr/_core.py | 642 +++++++++++++++++++++++++++++++++++++++++++++ dask_expr/_expr.py | 639 +------------------------------------------- 2 files changed, 649 insertions(+), 632 deletions(-) create mode 100644 dask_expr/_core.py diff --git a/dask_expr/_core.py b/dask_expr/_core.py new file mode 100644 index 00000000..6b2f6f23 --- /dev/null +++ b/dask_expr/_core.py @@ -0,0 +1,642 @@ +from __future__ import annotations + +import functools +import os +from collections.abc import Generator + +import dask +import pandas as pd +import toolz +from dask.dataframe.core import is_dataframe_like, is_index_like, is_series_like +from dask.utils import funcname, import_required, is_arraylike + +from dask_expr._util import _BackendData, _tokenize_deterministic, _tokenize_partial + + +class Expr: + _parameters = [] + _defaults = {} + + def __init__(self, *args, **kwargs): + operands = list(args) + for parameter in type(self)._parameters[len(operands) :]: + try: + operands.append(kwargs.pop(parameter)) + except KeyError: + operands.append(type(self)._defaults[parameter]) + assert not kwargs, kwargs + self.operands = operands + if self._required_attribute: + dep = next(iter(self.dependencies()))._meta + if not hasattr(dep, self._required_attribute): + # Raise a ValueError instead of AttributeError to + # avoid infinite recursion + raise ValueError(f"{dep} has no attribute {self._required_attribute}") + + @property + def _required_attribute(self) -> str: + # Specify if the first `dependency` must support + # a specific attribute for valid behavior. 
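        # Illustration only (hypothetical subclass, not part of this patch):
        # an expression that calls a string accessor could declare that the
        # meta of its first dependency must expose ``.str``, so a bad input
        # fails fast in ``__init__`` instead of recursing later:
        #
        #     class StrLower(Expr):
        #         _parameters = ["frame"]
        #
        #         @property
        #         def _required_attribute(self):
        #             return "str"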
+ return None + + def __str__(self): + s = ", ".join( + str(param) + "=" + str(operand) + for param, operand in zip(self._parameters, self.operands) + if isinstance(operand, Expr) or operand != self._defaults.get(param) + ) + return f"{type(self).__name__}({s})" + + def __repr__(self): + return str(self) + + def _tree_repr_lines(self, indent=0, recursive=True): + header = funcname(type(self)) + ":" + lines = [] + for i, op in enumerate(self.operands): + if isinstance(op, Expr): + if recursive: + lines.extend(op._tree_repr_lines(2)) + else: + try: + param = self._parameters[i] + default = self._defaults[param] + except (IndexError, KeyError): + param = self._parameters[i] if i < len(self._parameters) else "" + default = "--no-default--" + + if isinstance(op, _BackendData): + op = op._data + + # TODO: this stuff is pandas-specific + if isinstance(op, pd.core.base.PandasObject): + op = "" + elif is_dataframe_like(op): + op = "" + elif is_index_like(op): + op = "" + elif is_series_like(op): + op = "" + elif is_arraylike(op): + op = "" + + if repr(op) != repr(default): + if param: + header += f" {param}={repr(op)}" + else: + header += repr(op) + lines = [header] + lines + lines = [" " * indent + line for line in lines] + + return lines + + def tree_repr(self): + return os.linesep.join(self._tree_repr_lines()) + + def pprint(self): + for line in self._tree_repr_lines(): + print(line) + + def __hash__(self): + return hash(self._name) + + def __reduce__(self): + if dask.config.get("dask-expr-no-serialize", False): + raise RuntimeError(f"Serializing a {type(self)} object") + return type(self), tuple(self.operands) + + def _depth(self): + """Depth of the expression tree + + Returns + ------- + depth: int + """ + if not self.dependencies(): + return 1 + else: + return max(expr._depth() for expr in self.dependencies()) + 1 + + def operand(self, key): + # Access an operand unambiguously + # (e.g. if the key is reserved by a method/property) + return self.operands[type(self)._parameters.index(key)] + + def dependencies(self): + # Dependencies are `Expr` operands only + return [operand for operand in self.operands if isinstance(operand, Expr)] + + def _task(self, index: int): + """The task for the i'th partition + + Parameters + ---------- + index: + The index of the partition of this dataframe + + Examples + -------- + >>> class Add(Expr): + ... def _task(self, i): + ... return (operator.add, (self.left._name, i), (self.right._name, i)) + + Returns + ------- + task: + The Dask task to compute this partition + + See Also + -------- + Expr._layer + """ + raise NotImplementedError( + "Expressions should define either _layer (full dictionary) or _task" + " (single task). This expression type defines neither" + ) + + def _layer(self) -> dict: + """The graph layer added by this expression + + Examples + -------- + >>> class Add(Expr): + ... def _layer(self): + ... return { + ... (self._name, i): (operator.add, (self.left._name, i), (self.right._name, i)) + ... for i in range(self.npartitions) + ... 
} + + Returns + ------- + layer: dict + The Dask task graph added by this expression + + See Also + -------- + Expr._task + Expr.__dask_graph__ + """ + + return {(self._name, i): self._task(i) for i in range(self.npartitions)} + + def rewrite(self, kind: str): + """Rewrite an expression + + This leverages the ``._{kind}_down`` and ``._{kind}_up`` + methods defined on each class + + Returns + ------- + expr: + output expression + changed: + whether or not any change occured + """ + expr = self + down_name = f"_{kind}_down" + up_name = f"_{kind}_up" + while True: + _continue = False + + # Rewrite this node + if down_name in expr.__dir__(): + out = getattr(expr, down_name)() + if out is None: + out = expr + if not isinstance(out, Expr): + return out + if out._name != expr._name: + expr = out + continue + + # Allow children to rewrite their parents + for child in expr.dependencies(): + if up_name in child.__dir__(): + out = getattr(child, up_name)(expr) + if out is None: + out = expr + if not isinstance(out, Expr): + return out + if out is not expr and out._name != expr._name: + expr = out + _continue = True + break + + if _continue: + continue + + # Rewrite all of the children + new_operands = [] + changed = False + for operand in expr.operands: + if isinstance(operand, Expr): + new = operand.rewrite(kind=kind) + if new._name != operand._name: + changed = True + else: + new = operand + new_operands.append(new) + + if changed: + expr = type(expr)(*new_operands) + continue + else: + break + + return expr + + def simplify(self): + """Simplify an expression + + This leverages the ``._simplify_down`` and ``._simplify_up`` + methods defined on each class + + Returns + ------- + expr: + output expression + changed: + whether or not any change occured + """ + return self.rewrite(kind="simplify") + + def _simplify_down(self): + return + + def _simplify_up(self, parent): + return + + def lower_once(self): + expr = self + + # Lower this node + out = expr._lower() + if out is None: + out = expr + if not isinstance(out, Expr): + return out + + # Lower all children + new_operands = [] + changed = False + for operand in out.operands: + if isinstance(operand, Expr): + new = operand.lower_once() + if new._name != operand._name: + changed = True + else: + new = operand + new_operands.append(new) + + if changed: + out = type(out)(*new_operands) + + return out + + def lower_completely(self) -> Expr: + """Lower an expression completely + + This calls the ``lower_once`` method in a loop + until nothing changes. This function does not + apply any other optimizations (like ``simplify``). + + Returns + ------- + expr: + output expression + + See Also + -------- + Expr.lower_once + Expr._lower + """ + # Lower until nothing changes + expr = self + while True: + new = expr.lower_once() + if new._name == expr._name: + break + expr = new + return expr + + def _lower(self): + return + + def _remove_operations(self, frame, remove_ops, skip_ops=None): + """Searches for operations that we have to push up again to avoid + the duplication of branches that are doing the same. + + Parameters + ---------- + frame: Expression that we will search. + remove_ops: Ops that we will remove to push up again. + skip_ops: Ops that were introduced and that we want to ignore. + + Returns + ------- + tuple of the new expression and the operations that we removed. 
+ """ + + operations, ops_to_push_up = [], [] + frame_base = frame + combined_ops = remove_ops if skip_ops is None else remove_ops + skip_ops + while isinstance(frame, combined_ops): + # Have to respect ops that were injected while lowering or filters + if isinstance(frame, remove_ops): + ops_to_push_up.append(frame.operands[1]) + frame = frame.frame + break + else: + operations.append((type(frame), frame.operands[1:])) + frame = frame.frame + + if len(ops_to_push_up) > 0: + # Remove the projections but build the remaining things back up + for op_type, operands in reversed(operations): + frame = op_type(frame, *operands) + return frame, ops_to_push_up + else: + return frame_base, [] + + @functools.cached_property + def _name(self): + return ( + funcname(type(self)).lower() + "-" + _tokenize_deterministic(*self.operands) + ) + + @property + def _meta(self): + raise NotImplementedError() + + def __dask_graph__(self): + """Traverse expression tree, collect layers""" + stack = [self] + seen = set() + layers = [] + while stack: + expr = stack.pop() + + if expr._name in seen: + continue + seen.add(expr._name) + + layers.append(expr._layer()) + for operand in expr.dependencies(): + stack.append(operand) + + return toolz.merge(layers) + + def __dask_keys__(self): + return [(self._name, i) for i in range(self.npartitions)] + + def substitute(self, old, new) -> Expr: + """Substitute a specific term within the expression + + Note that replacing non-`Expr` terms may produce + unexpected results, and is not recommended. + Substituting boolean values is not allowed. + + Parameters + ---------- + old: + Old term to find and replace. + new: + New term to replace instances of `old` with. + + Examples + -------- + >>> (df + 10).substitute(10, 20) + df + 20 + """ + + # Check if we are replacing a literal + if isinstance(old, Expr): + substitute_literal = False + if self._name == old._name: + return new + else: + substitute_literal = True + if isinstance(old, bool): + raise TypeError("Arguments to `substitute` cannot be bool.") + + new_exprs = [] + update = False + for operand in self.operands: + if isinstance(operand, Expr): + val = operand.substitute(old, new) + if operand._name != val._name: + update = True + new_exprs.append(val) + elif ( + "Fused" in type(self).__name__ + and isinstance(operand, list) + and all(isinstance(op, Expr) for op in operand) + ): + # Special handling for `Fused`. + # We make no promise to dive through a + # list operand in general, but NEED to + # do so for the `Fused.exprs` operand. + val = [] + for op in operand: + val.append(op.substitute(old, new)) + if val[-1]._name != op._name: + update = True + new_exprs.append(val) + elif ( + substitute_literal + and not isinstance(operand, bool) + and isinstance(operand, type(old)) + and operand == old + ): + new_exprs.append(new) + update = True + else: + new_exprs.append(operand) + + if update: # Only recreate if something changed + return type(self)(*new_exprs) + return self + + def substitute_parameters(self, substitutions: dict) -> Expr: + """Substitute specific `Expr` parameters + + Parameters + ---------- + substitutions: + Mapping of parameter keys to new values. Keys that + are not found in ``self._parameters`` will be ignored. 
+ """ + if not substitutions: + return self + + changed = False + new_operands = [] + for i, operand in enumerate(self.operands): + if i < len(self._parameters) and self._parameters[i] in substitutions: + new_operands.append(substitutions[self._parameters[i]]) + changed = True + else: + new_operands.append(operand) + if changed: + return type(self)(*new_operands) + return self + + def _find_similar_operations(self, root: Expr, ignore: list | None = None): + # Find operations with the same type and operands. + # Parameter keys specified by `ignore` will not be + # included in the operand comparison + alike = [ + op for op in root.find_operations(type(self)) if op._name != self._name + ] + if not alike: + # No other operations of the same type. Early return + return [] + + # Return subset of `alike` with the same "token" + token = _tokenize_partial(self, ignore) + return [item for item in alike if _tokenize_partial(item, ignore) == token] + + def _node_label_args(self): + """Operands to include in the node label by `visualize`""" + return self.dependencies() + + def _to_graphviz( + self, + rankdir="BT", + graph_attr=None, + node_attr=None, + edge_attr=None, + **kwargs, + ): + from dask.dot import label, name + + graphviz = import_required( + "graphviz", + "Drawing dask graphs with the graphviz visualization engine requires the `graphviz` " + "python library and the `graphviz` system library.\n\n" + "Please either conda or pip install as follows:\n\n" + " conda install python-graphviz # either conda install\n" + " python -m pip install graphviz # or pip install and follow installation instructions", + ) + + graph_attr = graph_attr or {} + node_attr = node_attr or {} + edge_attr = edge_attr or {} + + graph_attr["rankdir"] = rankdir + node_attr["shape"] = "box" + node_attr["fontname"] = "helvetica" + + graph_attr.update(kwargs) + g = graphviz.Digraph( + graph_attr=graph_attr, + node_attr=node_attr, + edge_attr=edge_attr, + ) + + stack = [self] + seen = set() + dependencies = {} + while stack: + expr = stack.pop() + + if expr._name in seen: + continue + seen.add(expr._name) + + dependencies[expr] = set(expr.dependencies()) + for dep in expr.dependencies(): + stack.append(dep) + + cache = {} + for expr in dependencies: + expr_name = name(expr) + attrs = {} + + # Make node label + deps = [ + funcname(type(dep)) if isinstance(dep, Expr) else str(dep) + for dep in expr._node_label_args() + ] + _label = funcname(type(expr)) + if deps: + _label = f"{_label}({', '.join(deps)})" if deps else _label + node_label = label(_label, cache=cache) + + attrs.setdefault("label", str(node_label)) + attrs.setdefault("fontsize", "20") + g.node(expr_name, **attrs) + + for expr, deps in dependencies.items(): + expr_name = name(expr) + for dep in deps: + dep_name = name(dep) + g.edge(dep_name, expr_name) + + return g + + def visualize(self, filename="dask-expr.svg", format=None, **kwargs): + """ + Visualize the expression graph. + Requires ``graphviz`` to be installed. + + Parameters + ---------- + filename : str or None, optional + The name of the file to write to disk. If the provided `filename` + doesn't include an extension, '.png' will be used by default. + If `filename` is None, no file will be written, and the graph is + rendered in the Jupyter notebook only. + format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional + Format in which to write output file. Default is 'svg'. + **kwargs + Additional keyword arguments to forward to ``to_graphviz``. 
+ """ + from dask.dot import graphviz_to_file + + g = self._to_graphviz(**kwargs) + graphviz_to_file(g, filename, format) + return g + + def walk(self) -> Generator[Expr]: + """Iterate through all expressions in the tree + + Returns + ------- + nodes + Generator of Expr instances in the graph. + Ordering is a depth-first search of the expression tree + """ + stack = [self] + seen = set() + while stack: + node = stack.pop() + if node._name in seen: + continue + seen.add(node._name) + + for dep in node.dependencies(): + stack.append(dep) + + yield node + + def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]: + """Search the expression graph for a specific operation type + + Parameters + ---------- + operation + The operation type to search for. + + Returns + ------- + nodes + Generator of `operation` instances. Ordering corresponds + to a depth-first search of the expression graph. + """ + assert ( + isinstance(operation, tuple) + and all(issubclass(e, Expr) for e in operation) + or issubclass(operation, Expr) + ), "`operation` must be`Expr` subclass)" + return (expr for expr in self.walk() if isinstance(expr, operation)) diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index bd62117d..7e84c584 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -3,14 +3,12 @@ import functools import numbers import operator -import os from collections import defaultdict -from collections.abc import Generator, Mapping +from collections.abc import Mapping import dask import numpy as np import pandas as pd -import toolz from dask.base import normalize_token from dask.core import flatten from dask.dataframe import methods @@ -28,47 +26,22 @@ from dask.dataframe.rolling import CombinedOutput, _head_timedelta, overlap_chunk from dask.dataframe.utils import clear_known_categories, drop_by_shallow_copy from dask.typing import no_default -from dask.utils import M, apply, funcname, import_required, is_arraylike +from dask.utils import M, apply, funcname from tlz import merge_sorted, unique -from dask_expr._util import _BackendData, _tokenize_deterministic, _tokenize_partial +from dask_expr import _core as core +from dask_expr._util import _tokenize_deterministic, _tokenize_partial -replacement_rules = [] - -class Expr: +class Expr(core.Expr): """Primary class for all Expressions This mostly includes Dask protocols and various Pandas-like method definitions to make us look more like a DataFrame. """ - _parameters = [] - _defaults = {} _is_length_preserving = False - def __init__(self, *args, **kwargs): - operands = list(args) - for parameter in type(self)._parameters[len(operands) :]: - try: - operands.append(kwargs.pop(parameter)) - except KeyError: - operands.append(type(self)._defaults[parameter]) - assert not kwargs, kwargs - self.operands = operands - if self._required_attribute: - dep = next(iter(self.dependencies()))._meta - if not hasattr(dep, self._required_attribute): - # Raise a ValueError instead of AttributeError to - # avoid infinite recursion - raise ValueError(f"{dep} has no attribute {self._required_attribute}") - - @property - def _required_attribute(self) -> str: - # Specify if the first `dependency` must support - # a specific attribute for valid behavior. 
- return None - @functools.cached_property def ndim(self): meta = self._meta @@ -77,83 +50,12 @@ def ndim(self): except AttributeError: return 0 - def __str__(self): - s = ", ".join( - str(param) + "=" + str(operand) - for param, operand in zip(self._parameters, self.operands) - if isinstance(operand, Expr) or operand != self._defaults.get(param) - ) - return f"{type(self).__name__}({s})" - - def __repr__(self): - return str(self) - - def _tree_repr_lines(self, indent=0, recursive=True): - header = funcname(type(self)) + ":" - lines = [] - for i, op in enumerate(self.operands): - if isinstance(op, Expr): - if recursive: - lines.extend(op._tree_repr_lines(2)) - else: - try: - param = self._parameters[i] - default = self._defaults[param] - except (IndexError, KeyError): - param = self._parameters[i] if i < len(self._parameters) else "" - default = "--no-default--" - - if isinstance(op, _BackendData): - op = op._data - - if isinstance(op, pd.core.base.PandasObject): - op = "" - elif is_dataframe_like(op): - op = "" - elif is_index_like(op): - op = "" - elif is_series_like(op): - op = "" - elif is_arraylike(op): - op = "" - - if repr(op) != repr(default): - if param: - header += f" {param}={repr(op)}" - else: - header += repr(op) - lines = [header] + lines - lines = [" " * indent + line for line in lines] - - return lines - - def tree_repr(self): - return os.linesep.join(self._tree_repr_lines()) - - def pprint(self): - for line in self._tree_repr_lines(): - print(line) + def optimize(self, **kwargs): + return optimize(self, **kwargs) def __hash__(self): return hash(self._name) - def __reduce__(self): - if dask.config.get("dask-expr-no-serialize", False): - raise RuntimeError(f"Serializing a {type(self)} object") - return type(self), tuple(self.operands) - - def _depth(self): - """Depth of the expression tree - - Returns - ------- - depth: int - """ - if not self.dependencies(): - return 1 - else: - return max(expr._depth() for expr in self.dependencies()) + 1 - def __getattr__(self, key): try: return object.__getattribute__(self, key) @@ -183,211 +85,6 @@ def __getattr__(self, key): f"API function. Current API coverage is documented here: {link}." ) - def operand(self, key): - # Access an operand unambiguously - # (e.g. if the key is reserved by a method/property) - return self.operands[type(self)._parameters.index(key)] - - def dependencies(self): - # Dependencies are `Expr` operands only - return [operand for operand in self.operands if isinstance(operand, Expr)] - - def _task(self, index: int): - """The task for the i'th partition - - Parameters - ---------- - index: - The index of the partition of this dataframe - - Examples - -------- - >>> class Add(Expr): - ... def _task(self, i): - ... return (operator.add, (self.left._name, i), (self.right._name, i)) - - Returns - ------- - task: - The Dask task to compute this partition - - See Also - -------- - Expr._layer - """ - raise NotImplementedError( - "Expressions should define either _layer (full dictionary) or _task" - " (single task). This expression type defines neither" - ) - - def _layer(self) -> dict: - """The graph layer added by this expression - - Examples - -------- - >>> class Add(Expr): - ... def _layer(self): - ... return { - ... (self._name, i): (operator.add, (self.left._name, i), (self.right._name, i)) - ... for i in range(self.npartitions) - ... 
} - - Returns - ------- - layer: dict - The Dask task graph added by this expression - - See Also - -------- - Expr._task - Expr.__dask_graph__ - """ - - return {(self._name, i): self._task(i) for i in range(self.npartitions)} - - def rewrite(self, kind: str): - """Rewrite an expression - - This leverages the ``._{kind}_down`` and ``._{kind}_up`` - methods defined on each class - - Returns - ------- - expr: - output expression - changed: - whether or not any change occured - """ - expr = self - down_name = f"_{kind}_down" - up_name = f"_{kind}_up" - while True: - _continue = False - - # Rewrite this node - if down_name in expr.__dir__(): - out = getattr(expr, down_name)() - if out is None: - out = expr - if not isinstance(out, Expr): - return out - if out._name != expr._name: - expr = out - continue - - # Allow children to rewrite their parents - for child in expr.dependencies(): - if up_name in child.__dir__(): - out = getattr(child, up_name)(expr) - if out is None: - out = expr - if not isinstance(out, Expr): - return out - if out is not expr and out._name != expr._name: - expr = out - _continue = True - break - - if _continue: - continue - - # Rewrite all of the children - new_operands = [] - changed = False - for operand in expr.operands: - if isinstance(operand, Expr): - new = operand.rewrite(kind=kind) - if new._name != operand._name: - changed = True - else: - new = operand - new_operands.append(new) - - if changed: - expr = type(expr)(*new_operands) - continue - else: - break - - return expr - - def simplify(self): - """Simplify an expression - - This leverages the ``._simplify_down`` and ``._simplify_up`` - methods defined on each class - - Returns - ------- - expr: - output expression - changed: - whether or not any change occured - """ - return self.rewrite(kind="simplify") - - def _simplify_down(self): - return - - def _simplify_up(self, parent): - return - - def lower_once(self): - expr = self - - # Lower this node - out = expr._lower() - if out is None: - out = expr - if not isinstance(out, Expr): - return out - - # Lower all children - new_operands = [] - changed = False - for operand in out.operands: - if isinstance(operand, Expr): - new = operand.lower_once() - if new._name != operand._name: - changed = True - else: - new = operand - new_operands.append(new) - - if changed: - out = type(out)(*new_operands) - - return out - - def lower_completely(self) -> Expr: - """Lower an expression completely - - This calls the ``lower_once`` method in a loop - until nothing changes. This function does not - apply any other optimizations (like ``simplify``). - - Returns - ------- - expr: - output expression - - See Also - -------- - Expr.lower_once - Expr._lower - """ - # Lower until nothing changes - expr = self - while True: - new = expr.lower_once() - if new._name == expr._name: - break - expr = new - return expr - - def _lower(self): - return - def combine_similar( self, root: Expr | None = None, _cache: dict | None = None ) -> Expr: @@ -504,45 +201,6 @@ def _combine_similar_branches(self, root, remove_ops, skip_ops=None): common = common._simplify_down() or common return common - def _remove_operations(self, frame, remove_ops, skip_ops=None): - """Searches for operations that we have to push up again to avoid - the duplication of branches that are doing the same. - - Parameters - ---------- - frame: Expression that we will search. - remove_ops: Ops that we will remove to push up again. - skip_ops: Ops that were introduced and that we want to ignore. 
- - Returns - ------- - tuple of the new expression and the operations that we removed. - """ - - operations, ops_to_push_up = [], [] - frame_base = frame - combined_ops = remove_ops if skip_ops is None else remove_ops + skip_ops - while isinstance(frame, combined_ops): - # Have to respect ops that were injected while lowering or filters - if isinstance(frame, remove_ops): - ops_to_push_up.append(frame.operands[1]) - frame = frame.frame - break - else: - operations.append((type(frame), frame.operands[1:])) - frame = frame.frame - - if len(ops_to_push_up) > 0: - # Remove the projections but build the remaining things back up - for op_type, operands in reversed(operations): - frame = op_type(frame, *operands) - return frame, ops_to_push_up - else: - return frame_base, [] - - def optimize(self, **kwargs): - return optimize(self, **kwargs) - @property def index(self): return Index(self) @@ -791,12 +449,6 @@ def npartitions(self): else: return len(self.divisions) - 1 - @functools.cached_property - def _name(self): - return ( - funcname(type(self)).lower() + "-" + _tokenize_deterministic(*self.operands) - ) - @property def columns(self) -> list: try: @@ -808,283 +460,6 @@ def columns(self) -> list: def dtypes(self): return self._meta.dtypes - @property - def _meta(self): - raise NotImplementedError() - - def __dask_graph__(self): - """Traverse expression tree, collect layers""" - stack = [self] - seen = set() - layers = [] - while stack: - expr = stack.pop() - - if expr._name in seen: - continue - seen.add(expr._name) - - layers.append(expr._layer()) - for operand in expr.dependencies(): - stack.append(operand) - - return toolz.merge(layers) - - def __dask_keys__(self): - return [(self._name, i) for i in range(self.npartitions)] - - def substitute(self, old, new) -> Expr: - """Substitute a specific term within the expression - - Note that replacing non-`Expr` terms may produce - unexpected results, and is not recommended. - Substituting boolean values is not allowed. - - Parameters - ---------- - old: - Old term to find and replace. - new: - New term to replace instances of `old` with. - - Examples - -------- - >>> (df + 10).substitute(10, 20) - df + 20 - """ - - # Check if we are replacing a literal - if isinstance(old, Expr): - substitute_literal = False - if self._name == old._name: - return new - else: - substitute_literal = True - if isinstance(old, bool): - raise TypeError("Arguments to `substitute` cannot be bool.") - - new_exprs = [] - update = False - for operand in self.operands: - if isinstance(operand, Expr): - val = operand.substitute(old, new) - if operand._name != val._name: - update = True - new_exprs.append(val) - elif ( - isinstance(self, Fused) - and isinstance(operand, list) - and all(isinstance(op, Expr) for op in operand) - ): - # Special handling for `Fused`. - # We make no promise to dive through a - # list operand in general, but NEED to - # do so for the `Fused.exprs` operand. 
- val = [] - for op in operand: - val.append(op.substitute(old, new)) - if val[-1]._name != op._name: - update = True - new_exprs.append(val) - elif ( - substitute_literal - and not isinstance(operand, bool) - and isinstance(operand, type(old)) - and operand == old - ): - new_exprs.append(new) - update = True - else: - new_exprs.append(operand) - - if update: # Only recreate if something changed - return type(self)(*new_exprs) - return self - - def substitute_parameters(self, substitutions: dict) -> Expr: - """Substitute specific `Expr` parameters - - Parameters - ---------- - substitutions: - Mapping of parameter keys to new values. Keys that - are not found in ``self._parameters`` will be ignored. - """ - if not substitutions: - return self - - changed = False - new_operands = [] - for i, operand in enumerate(self.operands): - if i < len(self._parameters) and self._parameters[i] in substitutions: - new_operands.append(substitutions[self._parameters[i]]) - changed = True - else: - new_operands.append(operand) - if changed: - return type(self)(*new_operands) - return self - - def _find_similar_operations(self, root: Expr, ignore: list | None = None): - # Find operations with the same type and operands. - # Parameter keys specified by `ignore` will not be - # included in the operand comparison - alike = [ - op for op in root.find_operations(type(self)) if op._name != self._name - ] - if not alike: - # No other operations of the same type. Early return - return [] - - # Return subset of `alike` with the same "token" - token = _tokenize_partial(self, ignore) - return [item for item in alike if _tokenize_partial(item, ignore) == token] - - def _node_label_args(self): - """Operands to include in the node label by `visualize`""" - return self.dependencies() - - def _to_graphviz( - self, - rankdir="BT", - graph_attr=None, - node_attr=None, - edge_attr=None, - **kwargs, - ): - from dask.dot import label, name - - graphviz = import_required( - "graphviz", - "Drawing dask graphs with the graphviz visualization engine requires the `graphviz` " - "python library and the `graphviz` system library.\n\n" - "Please either conda or pip install as follows:\n\n" - " conda install python-graphviz # either conda install\n" - " python -m pip install graphviz # or pip install and follow installation instructions", - ) - - graph_attr = graph_attr or {} - node_attr = node_attr or {} - edge_attr = edge_attr or {} - - graph_attr["rankdir"] = rankdir - node_attr["shape"] = "box" - node_attr["fontname"] = "helvetica" - - graph_attr.update(kwargs) - g = graphviz.Digraph( - graph_attr=graph_attr, - node_attr=node_attr, - edge_attr=edge_attr, - ) - - stack = [self] - seen = set() - dependencies = {} - while stack: - expr = stack.pop() - - if expr._name in seen: - continue - seen.add(expr._name) - - dependencies[expr] = set(expr.dependencies()) - for dep in expr.dependencies(): - stack.append(dep) - - cache = {} - for expr in dependencies: - expr_name = name(expr) - attrs = {} - - # Make node label - deps = [ - funcname(type(dep)) if isinstance(dep, Expr) else str(dep) - for dep in expr._node_label_args() - ] - _label = funcname(type(expr)) - if deps: - _label = f"{_label}({', '.join(deps)})" if deps else _label - node_label = label(_label, cache=cache) - - attrs.setdefault("label", str(node_label)) - attrs.setdefault("fontsize", "20") - g.node(expr_name, **attrs) - - for expr, deps in dependencies.items(): - expr_name = name(expr) - for dep in deps: - dep_name = name(dep) - g.edge(dep_name, expr_name) - - return g - - 
def visualize(self, filename="dask-expr.svg", format=None, **kwargs): - """ - Visualize the expression graph. - Requires ``graphviz`` to be installed. - - Parameters - ---------- - filename : str or None, optional - The name of the file to write to disk. If the provided `filename` - doesn't include an extension, '.png' will be used by default. - If `filename` is None, no file will be written, and the graph is - rendered in the Jupyter notebook only. - format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional - Format in which to write output file. Default is 'svg'. - **kwargs - Additional keyword arguments to forward to ``to_graphviz``. - """ - from dask.dot import graphviz_to_file - - g = self._to_graphviz(**kwargs) - graphviz_to_file(g, filename, format) - return g - - def walk(self) -> Generator[Expr]: - """Iterate through all expressions in the tree - - Returns - ------- - nodes - Generator of Expr instances in the graph. - Ordering is a depth-first search of the expression tree - """ - stack = [self] - seen = set() - while stack: - node = stack.pop() - if node._name in seen: - continue - seen.add(node._name) - - for dep in node.dependencies(): - stack.append(dep) - - yield node - - def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]: - """Search the expression graph for a specific operation type - - Parameters - ---------- - operation - The operation type to search for. - - Returns - ------- - nodes - Generator of `operation` instances. Ordering corresponds - to a depth-first search of the expression graph. - """ - assert ( - isinstance(operation, tuple) - and all(issubclass(e, Expr) for e in operation) - or issubclass(operation, Expr) - ), "`operation` must be`Expr` subclass)" - return (expr for expr in self.walk() if isinstance(expr, operation)) - class Literal(Expr): """Represent a literal (known) value as an `Expr`""" From a9b806a7496e7ee8575ce9ed4b3df3447973ba98 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 5 Dec 2023 16:59:49 -0800 Subject: [PATCH 02/19] Clean up a few things while implementing arrays --- dask_expr/_core.py | 32 ++++++++++++++++++++++++++++++-- dask_expr/_expr.py | 3 +++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 6b2f6f23..08d21e88 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -368,6 +368,33 @@ def _name(self): def _meta(self): raise NotImplementedError() + def __getattr__(self, key): + try: + return object.__getattribute__(self, key) + except AttributeError as err: + if key == "_meta": + # Avoid a recursive loop if/when `self._meta` + # produces an `AttributeError` + raise RuntimeError( + f"Failed to generate metadata for {self}. " + "This operation may not be supported by the current backend." + ) + + # Allow operands to be accessed as attributes + # as long as the keys are not already reserved + # by existing methods/properties + _parameters = type(self)._parameters + if key in _parameters: + idx = _parameters.index(key) + return self.operands[idx] + + link = "https://github.com/dask-contrib/dask-expr/blob/main/README.md#api-coverage" + raise AttributeError( + f"{err}\n\n" + "This often means that you are attempting to use an unsupported " + f"API function. Current API coverage is documented here: {link}." 
+ ) + def __dask_graph__(self): """Traverse expression tree, collect layers""" stack = [self] @@ -386,8 +413,9 @@ def __dask_graph__(self): return toolz.merge(layers) - def __dask_keys__(self): - return [(self._name, i) for i in range(self.npartitions)] + @property + def dask(self): + return self.__dask_graph__() def substitute(self, old, new) -> Expr: """Substitute a specific term within the expression diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 7e84c584..e73e7dad 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -50,6 +50,9 @@ def ndim(self): except AttributeError: return 0 + def __dask_keys__(self): + return [(self._name, i) for i in range(self.npartitions)] + def optimize(self, **kwargs): return optimize(self, **kwargs) From 5c313aed15d4c2e014e13e410c302f70d41215cb Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 5 Dec 2023 17:01:15 -0800 Subject: [PATCH 03/19] Add trivial Array implementation --- dask_expr/array.py | 124 ++++++++++++++++++++++++++++++++++ dask_expr/tests/test_array.py | 14 ++++ 2 files changed, 138 insertions(+) create mode 100644 dask_expr/array.py create mode 100644 dask_expr/tests/test_array.py diff --git a/dask_expr/array.py b/dask_expr/array.py new file mode 100644 index 00000000..b9afa612 --- /dev/null +++ b/dask_expr/array.py @@ -0,0 +1,124 @@ +import operator +from typing import Union + +import dask.array as da +from dask.base import DaskMethodsMixin, named_schedulers +from dask.utils import cached_cumsum, cached_property +from toolz import reduce + +from dask_expr import _core as core + +T_IntOrNaN = Union[int, float] # Should be Union[int, Literal[np.nan]] + + +class Array(core.Expr, DaskMethodsMixin): + _cached_keys = None + + __dask_scheduler__ = staticmethod( + named_schedulers.get("threads", named_schedulers["sync"]) + ) + __dask_optimize__ = staticmethod(lambda dsk, keys, **kwargs: dsk) + + def __dask_postcompute__(self): + return da.core.finalize, () + + def __dask_postpersist__(self): + return FromGraph, (self._meta, self.chunks, self._name) + + def __array_ufunc__(self, numpy_ufunc, method, *inputs, **kwargs): + raise NotImplementedError() + + @property + def shape(self): + return tuple(sum(c) for c in self.chunks) + + @cached_property + def shape(self) -> tuple[T_IntOrNaN, ...]: + return tuple(cached_cumsum(c, initial_zero=True)[-1] for c in self.chunks) + + @property + def chunksize(self) -> tuple[T_IntOrNaN, ...]: + return tuple(max(c) for c in self.chunks) + + @property + def dtype(self): + if isinstance(self._meta, tuple): + dtype = self._meta[0].dtype + else: + dtype = self._meta.dtype + return dtype + + def __dask_keys__(self): + if self._cached_keys is not None: + return self._cached_keys + + name, chunks, numblocks = self.name, self.chunks, self.numblocks + + def keys(*args): + if not chunks: + return [(name,)] + ind = len(args) + if ind + 1 == len(numblocks): + result = [(name,) + args + (i,) for i in range(numblocks[ind])] + else: + result = [keys(*(args + (i,))) for i in range(numblocks[ind])] + return result + + self._cached_keys = result = keys() + return result + + @cached_property + def numblocks(self): + return tuple(map(len, self.chunks)) + + @cached_property + def npartitions(self): + return reduce(operator.mul, self.numblocks, 1) + + @property + def name(self): + return self._name + + def __hash__(self): + return hash(self._name) + + +class FromArray(Array): + _parameters = ["array", "chunks"] + + @property + def _meta(self): + return self.array[tuple(slice(0, 0) for _ in range(self.array.ndim))] 
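    # Note: the empty slice in ``_meta`` above returns a zero-element view
    # that still carries dtype and ndim, e.g. np.ones((4, 4))[0:0, 0:0] has
    # shape (0, 0) and dtype float64; dask only needs this type information,
    # never the data itself.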
+ + def _layer(self): + dsk = da.core.graph_from_arraylike( + self.array, chunks=self.chunks, shape=self.array.shape, name=self._name + ) + return dict(dsk) # this comes as a legacy HLG for now + + def __str__(self): + return "FromArray(...)" + + +class FromGraph(Array): + _parameters = ["layer", "_meta", "chunks", "_name"] + + @property + def _meta(self): + return self.operand("_meta") + + @property + def chunks(self): + return self.operand("chunks") + + @property + def _name(self): + return self.operand("_name") + + def _layer(self): + return dict(self.operand("layer")) + + +def from_array(x, chunks="auto"): + chunks = da.core.normalize_chunks(chunks, x.shape, dtype=x.dtype) + return FromArray(x, chunks) diff --git a/dask_expr/tests/test_array.py b/dask_expr/tests/test_array.py new file mode 100644 index 00000000..9664c9ea --- /dev/null +++ b/dask_expr/tests/test_array.py @@ -0,0 +1,14 @@ +import numpy as np +from dask.array.utils import assert_eq + +import dask_expr.array as da + + +def test_basic(): + x = np.random.random((10, 10)) + xx = da.from_array(x, chunks=(4, 4)) + xx._meta + xx.chunks + repr(xx) + + assert_eq(x, xx) From f2d023405ff6402be63cfe880b65078bb369c031 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 7 Dec 2023 18:11:23 -0800 Subject: [PATCH 04/19] Add basic rechunk, also move to a directory --- dask_expr/_core.py | 3 + dask_expr/array/__init__.py | 1 + dask_expr/{array.py => array/core.py} | 20 ++- dask_expr/array/rechunk.py | 179 ++++++++++++++++++++++ dask_expr/{ => array}/tests/test_array.py | 12 ++ 5 files changed, 211 insertions(+), 4 deletions(-) create mode 100644 dask_expr/array/__init__.py rename dask_expr/{array.py => array/core.py} (90%) create mode 100644 dask_expr/array/rechunk.py rename dask_expr/{ => array}/tests/test_array.py (50%) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 08d21e88..63de7ad6 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -100,6 +100,9 @@ def pprint(self): def __hash__(self): return hash(self._name) + def __dask_tokenize__(self): + return self._name + def __reduce__(self): if dask.config.get("dask-expr-no-serialize", False): raise RuntimeError(f"Serializing a {type(self)} object") diff --git a/dask_expr/array/__init__.py b/dask_expr/array/__init__.py new file mode 100644 index 00000000..5c04d131 --- /dev/null +++ b/dask_expr/array/__init__.py @@ -0,0 +1 @@ +from dask_expr.array.core import Array, from_array diff --git a/dask_expr/array.py b/dask_expr/array/core.py similarity index 90% rename from dask_expr/array.py rename to dask_expr/array/core.py index b9afa612..1c407341 100644 --- a/dask_expr/array.py +++ b/dask_expr/array/core.py @@ -28,14 +28,14 @@ def __dask_postpersist__(self): def __array_ufunc__(self, numpy_ufunc, method, *inputs, **kwargs): raise NotImplementedError() - @property - def shape(self): - return tuple(sum(c) for c in self.chunks) - @cached_property def shape(self) -> tuple[T_IntOrNaN, ...]: return tuple(cached_cumsum(c, initial_zero=True)[-1] for c in self.chunks) + @property + def ndim(self): + return len(self.shape) + @property def chunksize(self) -> tuple[T_IntOrNaN, ...]: return tuple(max(c) for c in self.chunks) @@ -82,6 +82,18 @@ def name(self): def __hash__(self): return hash(self._name) + def rechunk( + self, + chunks="auto", + threshold=None, + block_size_limit=None, + balance=False, + method=None, + ): + from dask_expr.array.rechunk import Rechunk + + return Rechunk(self, chunks, threshold, block_size_limit, balance, method) + class FromArray(Array): _parameters = 
["array", "chunks"] diff --git a/dask_expr/array/rechunk.py b/dask_expr/array/rechunk.py new file mode 100644 index 00000000..19bf42c7 --- /dev/null +++ b/dask_expr/array/rechunk.py @@ -0,0 +1,179 @@ +import itertools +import operator + +import dask +import numpy as np +import toolz +from dask.array.core import concatenate3 +from dask.array.rechunk import ( + _balance_chunksizes, + _validate_rechunk, + intersect_chunks, + normalize_chunks, + plan_rechunk, + tokenize, + validate_axis, +) +from dask.utils import cached_property + +from dask_expr.array import Array + + +class Rechunk(Array): + _parameters = [ + "array", + "_chunks", + "threshold", + "block_size_limit", + "balance", + "method", + ] + + _defaults = { + "_chunks": "auto", + "threshold": None, + "block_size_limit": None, + "balance": None, + "method": None, + } + + @property + def _meta(self): + return self.array._meta + + @property + def _name(self): + return "rechunk-merge-" + tokenize(*self.operands) + + @cached_property + def chunks(self): + x = self.array + chunks = self.operand("_chunks") + + # don't rechunk if array is empty + if x.ndim > 0 and all(s == 0 for s in x.shape): + return x.chunks + + if isinstance(chunks, dict): + chunks = {validate_axis(c, x.ndim): v for c, v in chunks.items()} + for i in range(x.ndim): + if i not in chunks: + chunks[i] = x.chunks[i] + elif chunks[i] is None: + chunks[i] = x.chunks[i] + if isinstance(chunks, (tuple, list)): + chunks = tuple( + lc if lc is not None else rc for lc, rc in zip(chunks, x.chunks) + ) + chunks = normalize_chunks( + chunks, + x.shape, + limit=self.block_size_limit, + dtype=x.dtype, + previous_chunks=x.chunks, + ) + + if not len(chunks) == x.ndim: + raise ValueError("Provided chunks are not consistent with shape") + + if self.balance: + chunks = tuple(_balance_chunksizes(chunk) for chunk in chunks) + + _validate_rechunk(x.chunks, chunks) + + return chunks + + def _layer(self): + method = self.method or dask.config.get("array.rechunk.method") + if method == "tasks": + steps = plan_rechunk( + self.array.chunks, + self.chunks, + self.array.dtype.itemsize, + self.threshold, + self.block_size_limit, + ) + name = self.array.name + old_chunks = self.array.chunks + layers = [] + for i, c in enumerate(steps): + level = len(steps) - i - 1 + name, old_chunks, layer = _compute_rechunk( + name, old_chunks, c, level, self.name + ) + layers.append(layer) + + return toolz.merge(*layers) + + if method == "p2p": + raise NotImplementedError( + "This shouldn't be hard, but I haven't done it yet, things are in motion over there" + ) + + +def _compute_rechunk(old_name, old_chunks, chunks, level, name): + """Compute the rechunk of *x* to the given *chunks*.""" + # TODO: redo this logic + # if x.size == 0: + # # Special case for empty array, as the algorithm below does not behave correctly + # return empty(x.shape, chunks=chunks, dtype=x.dtype) + + ndim = len(old_chunks) + crossed = intersect_chunks(old_chunks, chunks) + x2 = dict() + intermediates = dict() + # token = tokenize(old_name, chunks) + if level != 0: + merge_name = name.replace("rechunk-merge-", f"rechunk-merge-{level}-") + split_name = name.replace("rechunk-merge-", f"rechunk-split-{level}-") + else: + merge_name = name.replace("rechunk-merge-", "rechunk-merge-") + split_name = name.replace("rechunk-merge-", "rechunk-split-") + split_name_suffixes = itertools.count() + + # Pre-allocate old block references, to allow re-use and reduce the + # graph's memory footprint a bit. 
+ old_blocks = np.empty([len(c) for c in old_chunks], dtype="O") + for index in np.ndindex(old_blocks.shape): + old_blocks[index] = (old_name,) + index + + # Iterate over all new blocks + new_index = itertools.product(*(range(len(c)) for c in chunks)) + + for new_idx, cross1 in zip(new_index, crossed): + key = (merge_name,) + new_idx + old_block_indices = [[cr[i][0] for cr in cross1] for i in range(ndim)] + subdims1 = [len(set(old_block_indices[i])) for i in range(ndim)] + + rec_cat_arg = np.empty(subdims1, dtype="O") + rec_cat_arg_flat = rec_cat_arg.flat + + # Iterate over the old blocks required to build the new block + for rec_cat_index, ind_slices in enumerate(cross1): + old_block_index, slices = zip(*ind_slices) + name = (split_name, next(split_name_suffixes)) + old_index = old_blocks[old_block_index][1:] + if all( + slc.start == 0 and slc.stop == old_chunks[i][ind] + for i, (slc, ind) in enumerate(zip(slices, old_index)) + ): + rec_cat_arg_flat[rec_cat_index] = old_blocks[old_block_index] + else: + intermediates[name] = ( + operator.getitem, + old_blocks[old_block_index], + slices, + ) + rec_cat_arg_flat[rec_cat_index] = name + + assert rec_cat_index == rec_cat_arg.size - 1 + + # New block is formed by concatenation of sliced old blocks + if all(d == 1 for d in rec_cat_arg.shape): + x2[key] = rec_cat_arg.flat[0] + else: + x2[key] = (concatenate3, rec_cat_arg.tolist()) + + del old_blocks, new_index + + return name, chunks, {**x2, **intermediates} diff --git a/dask_expr/tests/test_array.py b/dask_expr/array/tests/test_array.py similarity index 50% rename from dask_expr/tests/test_array.py rename to dask_expr/array/tests/test_array.py index 9664c9ea..1515658a 100644 --- a/dask_expr/tests/test_array.py +++ b/dask_expr/array/tests/test_array.py @@ -12,3 +12,15 @@ def test_basic(): repr(xx) assert_eq(x, xx) + + +def test_rechunk(): + a = np.random.random((10, 10)) + b = da.from_array(a, chunks=(4, 4)) + c = b.rechunk() + assert c.npartitions == 1 + assert_eq(b, c) + + d = b.rechunk((3, 3)) + assert d.npartitions == 16 + assert_eq(d, a) From 288a1ba40a535aec33db92e5c217fb6a26e9229c Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 8 Dec 2023 06:37:13 -0800 Subject: [PATCH 05/19] add Rechunk(Rechunk(...)) simplification --- dask_expr/array/core.py | 3 +++ dask_expr/array/rechunk.py | 5 +++++ dask_expr/array/tests/test_array.py | 10 ++++++++++ 3 files changed, 18 insertions(+) diff --git a/dask_expr/array/core.py b/dask_expr/array/core.py index 1c407341..af869823 100644 --- a/dask_expr/array/core.py +++ b/dask_expr/array/core.py @@ -82,6 +82,9 @@ def name(self): def __hash__(self): return hash(self._name) + def optimize(self): + return self.simplify() + def rechunk( self, chunks="auto", diff --git a/dask_expr/array/rechunk.py b/dask_expr/array/rechunk.py index 19bf42c7..e9548245 100644 --- a/dask_expr/array/rechunk.py +++ b/dask_expr/array/rechunk.py @@ -110,6 +110,11 @@ def _layer(self): "This shouldn't be hard, but I haven't done it yet, things are in motion over there" ) + def _simplify_down(self): + if isinstance(self.array, Rechunk): + # TODO: should maybe or the two balance values + return Rechunk(self.array.array, *self.operands[1:]) + def _compute_rechunk(old_name, old_chunks, chunks, level, name): """Compute the rechunk of *x* to the given *chunks*.""" diff --git a/dask_expr/array/tests/test_array.py b/dask_expr/array/tests/test_array.py index 1515658a..cc3c2ba3 100644 --- a/dask_expr/array/tests/test_array.py +++ b/dask_expr/array/tests/test_array.py @@ -24,3 +24,13 @@ def 
test_rechunk(): d = b.rechunk((3, 3)) assert d.npartitions == 16 assert_eq(d, a) + + +def test_rechunk_optimize(): + a = np.random.random((10, 10)) + b = da.from_array(a, chunks=(4, 4)) + + c = b.rechunk((2, 5)).rechunk((5, 2)) + d = b.rechunk((5, 2)) + + assert c.optimize()._name == d.optimize()._name From d360ca5ac03dd0b3402c409e32b2434262dad1e7 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 8 Dec 2023 10:10:14 -0800 Subject: [PATCH 06/19] Add basic blockwise functionality --- dask_expr/array/blockwise.py | 556 ++++++++++++++++++++++++++++ dask_expr/array/core.py | 15 + dask_expr/array/tests/test_array.py | 14 + 3 files changed, 585 insertions(+) create mode 100644 dask_expr/array/blockwise.py diff --git a/dask_expr/array/blockwise.py b/dask_expr/array/blockwise.py new file mode 100644 index 00000000..96443ee6 --- /dev/null +++ b/dask_expr/array/blockwise.py @@ -0,0 +1,556 @@ +import itertools +import numbers +from collections.abc import Iterable + +import numpy as np +import toolz +from dask.array.core import ( + _enforce_dtype, + apply_infer_dtype, + normalize_arg, + unify_chunks, +) +from dask.array.utils import compute_meta +from dask.base import is_dask_collection, tokenize +from dask.blockwise import blockwise as core_blockwise +from dask.delayed import unpack_collections +from dask.utils import funcname + +from dask_expr.array.core import Array + + +class Blockwise(Array): + _parameters = [ + "func", + "out_ind", + "name", + "token", + "dtype", + "adjust_chunks", + "new_axes", + "align_arrays", + "concatenate", + "_meta_provided", + "kwargs", + ] + _defaults = { + "name": None, + "token": None, + "dtype": None, + "adjust_chunks": None, + "new_axes": None, + "align_arrays": False, # TODO: this should be true, future work + "concatenate": None, + "_meta_provided": None, + "kwargs": None, + } + + @property + def args(self): + return self.operands[len(self._parameters) :] + + @property + def _meta(self): + if self._meta_provided is not None: + return self._meta_provided + else: + return compute_meta(self.func, self.dtype, *self.args[::2], **self.kwargs) + + @property + def chunks(self): + if self.align_arrays: + chunkss, arrays = unify_chunks(*self.args) + else: + arginds = [ + (a, i) for (a, i) in toolz.partition(2, self.args) if i is not None + ] + chunkss = {} + # For each dimension, use the input chunking that has the most blocks; + # this will ensure that broadcasting works as expected, and in + # particular the number of blocks should be correct if the inputs are + # consistent. 
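            # e.g. along an index chunked as (2, 2, 2) in one input and (1,)
            # in another (a broadcast dimension), the (2, 2, 2) chunking wins
            # because it has more blocks.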
+ for arg, ind in arginds: + for c, i in zip(arg.chunks, ind): + if i not in chunkss or len(c) > len(chunkss[i]): + chunkss[i] = c + + for k, v in self.new_axes.items(): + if not isinstance(v, tuple): + v = (v,) + chunkss[k] = v + + chunks = [chunkss[i] for i in self.out_ind] + if self.adjust_chunks: + for i, ind in enumerate(self.out_ind): + if ind in self.adjust_chunks: + if callable(self.adjust_chunks[ind]): + chunks[i] = tuple(map(self.adjust_chunks[ind], chunks[i])) + elif isinstance(self.adjust_chunks[ind], numbers.Integral): + chunks[i] = tuple(self.adjust_chunks[ind] for _ in chunks[i]) + elif isinstance(self.adjust_chunks[ind], (tuple, list)): + if len(self.adjust_chunks[ind]) != len(chunks[i]): + raise ValueError( + f"Dimension {i} has {len(chunks[i])} blocks, adjust_chunks " + f"specified with {len(self.adjust_chunks[ind])} blocks" + ) + chunks[i] = tuple(self.adjust_chunks[ind]) + else: + raise NotImplementedError( + "adjust_chunks values must be callable, int, or tuple" + ) + chunks = tuple(chunks) + return chunks + + @property + def dtype(self): + return self.operand("dtype") + + @property + def _name(self): + if self.operand("name"): + return self.operand("name") + else: + return "{}-{}".format( + self.token or funcname(self.func).strip("_"), + tokenize( + self.func, self.out_ind, self.dtype, *self.args, **self.kwargs + ), + ) + + def _layer(self): + arginds = [(a, i) for (a, i) in toolz.partition(2, self.args)] + + numblocks = {} + dependencies = [] + arrays = [] + + # Normalize arguments + argindsstr = [] + + for arg, ind in arginds: + if ind is None: + arg = normalize_arg(arg) + arg, collections = unpack_collections(arg) + dependencies.extend(collections) + else: + if ( + hasattr(arg, "ndim") + and hasattr(ind, "__len__") + and arg.ndim != len(ind) + ): + raise ValueError( + "Index string %s does not match array dimension %d" + % (ind, arg.ndim) + ) + numblocks[arg.name] = arg.numblocks + arrays.append(arg) + arg = arg.name + argindsstr.extend((arg, ind)) + + # Normalize keyword arguments + kwargs2 = {} + for k, v in self.kwargs.items(): + v = normalize_arg(v) + v, collections = unpack_collections(v) + dependencies.extend(collections) + kwargs2[k] = v + + graph = core_blockwise( + self.func, + self._name, + self.out_ind, + *argindsstr, + numblocks=numblocks, + dependencies=dependencies, + new_axes=self.new_axes, + concatenate=self.concatenate, + **kwargs2, + ) + return dict(graph) + + +def blockwise( + func, + out_ind, + *args, + name=None, + token=None, + dtype=None, + adjust_chunks=None, + new_axes=None, + align_arrays=False, # TODO: this should be true, future work + concatenate=None, + meta=None, + cls=Blockwise, + **kwargs, +): + """Tensor operation: Generalized inner and outer products + + A broad class of blocked algorithms and patterns can be specified with a + concise multi-index notation. The ``blockwise`` function applies an in-memory + function across multiple blocks of multiple inputs in a variety of ways. + Many dask.array operations are special cases of blockwise including + elementwise, broadcasting, reductions, tensordot, and transpose. + + Parameters + ---------- + func : callable + Function to apply to individual tuples of blocks + out_ind : iterable + Block pattern of the output, something like 'ijk' or (1, 2, 3) + *args : sequence of Array, index pairs + You may also pass literal arguments, accompanied by None index + e.g. 
(x, 'ij', y, 'jk', z, 'i', some_literal, None) + **kwargs : dict + Extra keyword arguments to pass to function + dtype : np.dtype + Datatype of resulting array. + concatenate : bool, keyword only + If true concatenate arrays along dummy indices, else provide lists + adjust_chunks : dict + Dictionary mapping index to function to be applied to chunk sizes + new_axes : dict, keyword only + New indexes and their dimension lengths + align_arrays: bool + Whether or not to align chunks along equally sized dimensions when + multiple arrays are provided. This allows for larger chunks in some + arrays to be broken into smaller ones that match chunk sizes in other + arrays such that they are compatible for block function mapping. If + this is false, then an error will be thrown if arrays do not already + have the same number of blocks in each dimension. + + Examples + -------- + 2D embarrassingly parallel operation from two arrays, x, and y. + + >>> import operator, numpy as np, dask.array as da + >>> x = da.from_array([[1, 2], + ... [3, 4]], chunks=(1, 2)) + >>> y = da.from_array([[10, 20], + ... [0, 0]]) + >>> z = blockwise(operator.add, 'ij', x, 'ij', y, 'ij', dtype='f8') + >>> z.compute() + array([[11, 22], + [ 3, 4]]) + + Outer product multiplying a by b, two 1-d vectors + + >>> a = da.from_array([0, 1, 2], chunks=1) + >>> b = da.from_array([10, 50, 100], chunks=1) + >>> z = blockwise(np.outer, 'ij', a, 'i', b, 'j', dtype='f8') + >>> z.compute() + array([[ 0, 0, 0], + [ 10, 50, 100], + [ 20, 100, 200]]) + + z = x.T + + >>> z = blockwise(np.transpose, 'ji', x, 'ij', dtype=x.dtype) + >>> z.compute() + array([[1, 3], + [2, 4]]) + + The transpose case above is illustrative because it does transposition + both on each in-memory block by calling ``np.transpose`` and on the order + of the blocks themselves, by switching the order of the index ``ij -> ji``. + + We can compose these same patterns with more variables and more complex + in-memory functions + + z = X + Y.T + + >>> z = blockwise(lambda x, y: x + y.T, 'ij', x, 'ij', y, 'ji', dtype='f8') + >>> z.compute() + array([[11, 2], + [23, 4]]) + + Any index, like ``i`` missing from the output index is interpreted as a + contraction (note that this differs from Einstein convention; repeated + indices do not imply contraction.) In the case of a contraction the passed + function should expect an iterable of blocks on any array that holds that + index. To receive arrays concatenated along contracted dimensions instead + pass ``concatenate=True``. + + Inner product multiplying a by b, two 1-d vectors + + >>> def sequence_dot(a_blocks, b_blocks): + ... result = 0 + ... for a, b in zip(a_blocks, b_blocks): + ... result += a.dot(b) + ... return result + + >>> z = blockwise(sequence_dot, '', a, 'i', b, 'i', dtype='f8') + >>> z.compute() + 250 + + Add new single-chunk dimensions with the ``new_axes=`` keyword, including + the length of the new dimension. New dimensions will always be in a single + chunk. + + >>> def f(a): + ... return a[:, None] * np.ones((1, 5)) + + >>> z = blockwise(f, 'az', a, 'a', new_axes={'z': 5}, dtype=a.dtype) + + New dimensions can also be multi-chunk by specifying a tuple of chunk + sizes. This has limited utility as is (because the chunks are all the + same), but the resulting graph can be modified to achieve more useful + results (see ``da.map_blocks``). 
+ + >>> z = blockwise(f, 'az', a, 'a', new_axes={'z': (5, 5)}, dtype=x.dtype) + >>> z.chunks + ((1, 1, 1), (5, 5)) + + If the applied function changes the size of each chunk you can specify this + with a ``adjust_chunks={...}`` dictionary holding a function for each index + that modifies the dimension size in that index. + + >>> def double(x): + ... return np.concatenate([x, x]) + + >>> y = blockwise(double, 'ij', x, 'ij', + ... adjust_chunks={'i': lambda n: 2 * n}, dtype=x.dtype) + >>> y.chunks + ((2, 2), (2,)) + + Include literals by indexing with None + + >>> z = blockwise(operator.add, 'ij', x, 'ij', 1234, None, dtype=x.dtype) + >>> z.compute() + array([[1235, 1236], + [1237, 1238]]) + """ + new_axes = new_axes or {} + + # Input Validation + if len(set(out_ind)) != len(out_ind): + raise ValueError( + "Repeated elements not allowed in output index", + [k for k, v in toolz.frequencies(out_ind).items() if v > 1], + ) + new = ( + set(out_ind) + - {a for arg in args[1::2] if arg is not None for a in arg} + - set(new_axes or ()) + ) + if new: + raise ValueError("Unknown dimension", new) + + assert not align_arrays # TODO, need unify_chunks + + return cls( + func, + out_ind, + name, + token, + dtype, + adjust_chunks, + new_axes, + align_arrays, # TODO: this should be true, future work + concatenate, + meta, + kwargs, + *args, + ) + + +class Elemwise(Blockwise): + pass + + +def elemwise(op, *args, out=None, where=True, dtype=None, name=None, **kwargs): + """Apply an elementwise ufunc-like function blockwise across arguments. + + Like numpy ufuncs, broadcasting rules are respected. + + Parameters + ---------- + op : callable + The function to apply. Should be numpy ufunc-like in the parameters + that it accepts. + *args : Any + Arguments to pass to `op`. Non-dask array-like objects are first + converted to dask arrays, then all arrays are broadcast together before + applying the function blockwise across all arguments. Any scalar + arguments are passed as-is following normal numpy ufunc behavior. + out : dask array, optional + If out is a dask.array then this overwrites the contents of that array + with the result. + where : array_like, optional + An optional boolean mask marking locations where the ufunc should be + applied. Can be a scalar, dask array, or any other array-like object. + Mirrors the ``where`` argument to numpy ufuncs, see e.g. ``numpy.add`` + for more information. + dtype : dtype, optional + If provided, overrides the output array dtype. + name : str, optional + A unique key name to use when building the backing dask graph. If not + provided, one will be automatically generated based on the input + arguments. 
+ + Examples + -------- + >>> elemwise(add, x, y) # doctest: +SKIP + >>> elemwise(sin, x) # doctest: +SKIP + >>> elemwise(sin, x, out=dask_array) # doctest: +SKIP + + See Also + -------- + blockwise + """ + if kwargs: + raise TypeError( + f"{op.__name__} does not take the following keyword arguments " + f"{sorted(kwargs)}" + ) + + if out is not None: + raise NotImplementedError() + if where is not True: + raise NotImplementedError() + + args = [np.asarray(a) if isinstance(a, (list, tuple)) else a for a in args] + + shapes = [] + for arg in args: + shape = getattr(arg, "shape", ()) + if any(is_dask_collection(x) for x in shape): + # Want to exclude Delayed shapes and dd.Scalar + shape = () + shapes.append(shape) + if isinstance(where, Array): + shapes.append(where.shape) + if isinstance(out, Array): + shapes.append(out.shape) + + shapes = [s if isinstance(s, Iterable) else () for s in shapes] + out_ndim = len( + broadcast_shapes(*shapes) + ) # Raises ValueError if dimensions mismatch + expr_inds = tuple(range(out_ndim))[::-1] + + if dtype is not None: + need_enforce_dtype = True + else: + # We follow NumPy's rules for dtype promotion, which special cases + # scalars and 0d ndarrays (which it considers equivalent) by using + # their values to compute the result dtype: + # https://github.com/numpy/numpy/issues/6240 + # We don't inspect the values of 0d dask arrays, because these could + # hold potentially very expensive calculations. Instead, we treat + # them just like other arrays, and if necessary cast the result of op + # to match. + vals = [ + np.empty((1,) * max(1, a.ndim), dtype=a.dtype) + if not is_scalar_for_elemwise(a) + else a + for a in args + ] + try: + dtype = apply_infer_dtype(op, vals, {}, "elemwise", suggest_dtype=False) + except Exception: + return NotImplemented + need_enforce_dtype = any( + not is_scalar_for_elemwise(a) and a.ndim == 0 for a in args + ) + + # if not name: + # name = f"{funcname(op)}-{tokenize(op, dtype, *args, where)}" + + blockwise_kwargs = dict(dtype=dtype, token=funcname(op).strip("_")) + + # TODO: add back + # if where is not True: + # blockwise_kwargs["elemwise_where_function"] = op + # op = _elemwise_handle_where + # args.extend([where, out]) + + if need_enforce_dtype: + blockwise_kwargs["enforce_dtype"] = dtype + blockwise_kwargs["enforce_dtype_function"] = op + op = _enforce_dtype + + result = blockwise( + op, + expr_inds, + *toolz.concat( + (a, tuple(range(a.ndim)[::-1]) if not is_scalar_for_elemwise(a) else None) + for a in args + ), + cls=Elemwise, + **blockwise_kwargs, + ) + + # TODO: handle out + # return handle_out(out, result) + return result + + +def broadcast_shapes(*shapes): + """ + Determines output shape from broadcasting arrays. + + Parameters + ---------- + shapes : tuples + The shapes of the arguments. + + Returns + ------- + output_shape : tuple + + Raises + ------ + ValueError + If the input shapes cannot be successfully broadcast together. 
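+
+    Examples
+    --------
+    An illustrative doctest, added for readability and not part of the
+    original change; the exact integer scalar types in the result can vary
+    by NumPy version, hence the skip marker:
+
+    >>> broadcast_shapes((2, 3), (3,), (4, 1, 3))  # doctest: +SKIP
+    (4, 2, 3)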
+ """ + if len(shapes) == 1: + return shapes[0] + out = [] + for sizes in itertools.zip_longest(*map(reversed, shapes), fillvalue=-1): + if np.isnan(sizes).any(): + dim = np.nan + else: + dim = 0 if 0 in sizes else np.max(sizes) + if any(i not in [-1, 0, 1, dim] and not np.isnan(i) for i in sizes): + raise ValueError( + "operands could not be broadcast together with " + "shapes {}".format(" ".join(map(str, shapes))) + ) + out.append(dim) + return tuple(reversed(out)) + + +def is_scalar_for_elemwise(arg): + """ + + >>> is_scalar_for_elemwise(42) + True + >>> is_scalar_for_elemwise('foo') + True + >>> is_scalar_for_elemwise(True) + True + >>> is_scalar_for_elemwise(np.array(42)) + True + >>> is_scalar_for_elemwise([1, 2, 3]) + True + >>> is_scalar_for_elemwise(np.array([1, 2, 3])) + False + >>> is_scalar_for_elemwise(from_array(np.array(0), chunks=())) + False + >>> is_scalar_for_elemwise(np.dtype('i4')) + True + """ + # the second half of shape_condition is essentially just to ensure that + # dask series / frame are treated as scalars in elemwise. + maybe_shape = getattr(arg, "shape", None) + shape_condition = not isinstance(maybe_shape, Iterable) or any( + is_dask_collection(x) for x in maybe_shape + ) + + return ( + np.isscalar(arg) + or shape_condition + or isinstance(arg, np.dtype) + or (isinstance(arg, np.ndarray) and arg.ndim == 0) + ) diff --git a/dask_expr/array/core.py b/dask_expr/array/core.py index af869823..56ef38fe 100644 --- a/dask_expr/array/core.py +++ b/dask_expr/array/core.py @@ -97,6 +97,18 @@ def rechunk( return Rechunk(self, chunks, threshold, block_size_limit, balance, method) + def __add__(self, other): + return elemwise(operator.add, self, other) + + def __radd__(self, other): + return elemwise(operator.add, other, self) + + def __mul__(self, other): + return elemwise(operator.add, self, other) + + def __rmul__(self, other): + return elemwise(operator.mul, other, self) + class FromArray(Array): _parameters = ["array", "chunks"] @@ -137,3 +149,6 @@ def _layer(self): def from_array(x, chunks="auto"): chunks = da.core.normalize_chunks(chunks, x.shape, dtype=x.dtype) return FromArray(x, chunks) + + +from dask_expr.array.blockwise import elemwise diff --git a/dask_expr/array/tests/test_array.py b/dask_expr/array/tests/test_array.py index cc3c2ba3..ab7b2822 100644 --- a/dask_expr/array/tests/test_array.py +++ b/dask_expr/array/tests/test_array.py @@ -34,3 +34,17 @@ def test_rechunk_optimize(): d = b.rechunk((5, 2)) assert c.optimize()._name == d.optimize()._name + + +def test_elemwise(): + a = np.random.random((10, 10)) + b = da.from_array(a, chunks=(4, 4)) + + (b + 1).compute() + assert_eq(a + 1, b + 1) + assert_eq(a + 2 * a, b + 2 * b) + + x = np.random.random(10) + y = da.from_array(x, chunks=(4,)) + + assert_eq(a + x, b + y) From e6385dd04ea2f1f6fbedcd73002233497f525e01 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 8 Dec 2023 10:30:43 -0800 Subject: [PATCH 07/19] add transpose --- dask_expr/array/blockwise.py | 39 ++++++++++++++++++++++++++++- dask_expr/array/core.py | 16 +++++++++++- dask_expr/array/tests/test_array.py | 12 +++++++++ 3 files changed, 65 insertions(+), 2 deletions(-) diff --git a/dask_expr/array/blockwise.py b/dask_expr/array/blockwise.py index 96443ee6..b0691b8c 100644 --- a/dask_expr/array/blockwise.py +++ b/dask_expr/array/blockwise.py @@ -107,7 +107,7 @@ def dtype(self): @property def _name(self): - if self.operand("name"): + if "name" in self._parameters and self.operand("name"): return self.operand("name") else: return "{}-{}".format( 
@@ -554,3 +554,40 @@ def is_scalar_for_elemwise(arg):
         or isinstance(arg, np.dtype)
         or (isinstance(arg, np.ndarray) and arg.ndim == 0)
     )
+
+
+class Transpose(Elemwise):
+    _parameters = ["array", "axes"]
+    func = staticmethod(np.transpose)
+    align_arrays = False
+    adjust_chunks = None
+    concatenate = None
+    token = "transpose"
+
+    @property
+    def new_axes(self):
+        return {}
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def _meta_provided(self):
+        return self.array._meta
+
+    @property
+    def dtype(self):
+        return self._meta.dtype
+
+    @property
+    def out_ind(self):
+        return self.axes
+
+    @property
+    def kwargs(self):
+        return {"axes": self.axes}
+
+    @property
+    def args(self):
+        return (self.array, tuple(range(self.array.ndim)))
diff --git a/dask_expr/array/core.py b/dask_expr/array/core.py
index 56ef38fe..68cd40e8 100644
--- a/dask_expr/array/core.py
+++ b/dask_expr/array/core.py
@@ -97,6 +97,20 @@ def rechunk(
 
         return Rechunk(self, chunks, threshold, block_size_limit, balance, method)
 
+    def transpose(self, axes=None):
+        if axes:
+            if len(axes) != self.ndim:
+                raise ValueError("axes don't match array")
+            axes = tuple(d + self.ndim if d < 0 else d for d in axes)
+        else:
+            axes = tuple(range(self.ndim))[::-1]
+
+        return Transpose(self, axes)
+
+    @property
+    def T(self):
+        return self.transpose()
+
     def __add__(self, other):
         return elemwise(operator.add, self, other)
 
@@ -151,4 +165,4 @@ def from_array(x, chunks="auto"):
     return FromArray(x, chunks)
 
 
-from dask_expr.array.blockwise import elemwise
+from dask_expr.array.blockwise import Transpose, elemwise
diff --git a/dask_expr/array/tests/test_array.py b/dask_expr/array/tests/test_array.py
index ab7b2822..509acd04 100644
--- a/dask_expr/array/tests/test_array.py
+++ b/dask_expr/array/tests/test_array.py
@@ -48,3 +48,15 @@ def test_elemwise():
     y = da.from_array(x, chunks=(4,))
 
     assert_eq(a + x, b + y)
+
+
+def test_transpose():
+    a = np.random.random((10, 20))
+    b = da.from_array(a, chunks=(2, 5))
+
+    assert_eq(a.T, b.T)
+
+    a = np.random.random((10, 1))
+    b = da.from_array(a, chunks=(5, 1))
+    assert_eq(a.T + a, b.T + b)
+    assert_eq(a + a.T, b + b.T)

From b9a51e9c8359a00d506cc830c3844f4dffad7590 Mon Sep 17 00:00:00 2001
From: Matthew Rocklin
Date: Fri, 8 Dec 2023 13:25:57 -0800
Subject: [PATCH 08/19] Simplify blockwise + rechunk + IO

---
 dask_expr/array/blockwise.py        |  7 ++++++
 dask_expr/array/core.py             | 13 ++++++++++--
 dask_expr/array/rechunk.py          | 32 ++++++++++++++++++++++++++++
 dask_expr/array/tests/test_array.py | 33 +++++++++++++++++++++++++++++
 4 files changed, 83 insertions(+), 2 deletions(-)

diff --git a/dask_expr/array/blockwise.py b/dask_expr/array/blockwise.py
index b0691b8c..96f5e296 100644
--- a/dask_expr/array/blockwise.py
+++ b/dask_expr/array/blockwise.py
@@ -591,3 +591,10 @@ def kwargs(self):
     @property
     def args(self):
         return (self.array, tuple(range(self.array.ndim)))
+
+    def _simplify_down(self):
+        if isinstance(self.array, Transpose):
+            axes = tuple(self.array.axes[i] for i in self.axes)
+            return Transpose(self.array.array, axes)
+        if self.axes == tuple(range(self.ndim)):
+            return self.array
diff --git a/dask_expr/array/core.py b/dask_expr/array/core.py
index 68cd40e8..b02b0f89 100644
--- a/dask_expr/array/core.py
+++ b/dask_expr/array/core.py
@@ -124,9 +124,19 @@ def __rmul__(self, other):
         return elemwise(operator.mul, other, self)
 
 
-class FromArray(Array):
+class IO(Array):
+    pass
+
+
+class FromArray(IO):
     _parameters = ["array", "chunks"]
 
+    @property
+    def chunks(self):
+        return da.core.normalize_chunks(
self.operand("chunks"), self.array.shape, dtype=self.array.dtype + ) + @property def _meta(self): return self.array[tuple(slice(0, 0) for _ in range(self.array.ndim))] @@ -161,7 +171,6 @@ def _layer(self): def from_array(x, chunks="auto"): - chunks = da.core.normalize_chunks(chunks, x.shape, dtype=x.dtype) return FromArray(x, chunks) diff --git a/dask_expr/array/rechunk.py b/dask_expr/array/rechunk.py index e9548245..a7f790de 100644 --- a/dask_expr/array/rechunk.py +++ b/dask_expr/array/rechunk.py @@ -1,4 +1,5 @@ import itertools +import numbers import operator import dask @@ -17,6 +18,7 @@ from dask.utils import cached_property from dask_expr.array import Array +from dask_expr.array.core import IO class Rechunk(Array): @@ -114,6 +116,33 @@ def _simplify_down(self): if isinstance(self.array, Rechunk): # TODO: should maybe or the two balance values return Rechunk(self.array.array, *self.operands[1:]) + if isinstance(self.array, Elemwise): + if isinstance(self._chunks, (str, numbers.Number)): + return self.array.substitute( + self.array, + self.array.rechunk(self._chunks), + ) + # TODO: handle subclasses + if type(self.array) == Elemwise and isinstance(self._chunks, tuple): + args = [] + for arg, inds in toolz.partition_all(2, self.array.args): + if inds is None: + args.extend((arg, inds)) + else: + assert isinstance(arg, Array) + idx = tuple(self.array.out_ind.index(i) for i in inds) + chunks = tuple([self._chunks[i] for i in idx]) + arg = arg.rechunk(chunks) + args.extend((arg, inds)) + + return Elemwise(*self.array.operands[: -len(args)], *args) + + if isinstance(self.array, IO) and "chunks" in self.array._parameters: + chunks = tuple( + c if n != 1 else 1 if isinstance(c, numbers.Number) else (1,) + for n, c in zip(self.array.shape, self._chunks) + ) + return self.array.substitute_parameters({"chunks": chunks}) def _compute_rechunk(old_name, old_chunks, chunks, level, name): @@ -182,3 +211,6 @@ def _compute_rechunk(old_name, old_chunks, chunks, level, name): del old_blocks, new_index return name, chunks, {**x2, **intermediates} + + +from dask_expr.array.blockwise import Elemwise diff --git a/dask_expr/array/tests/test_array.py b/dask_expr/array/tests/test_array.py index 509acd04..ccc6c0a9 100644 --- a/dask_expr/array/tests/test_array.py +++ b/dask_expr/array/tests/test_array.py @@ -36,6 +36,37 @@ def test_rechunk_optimize(): assert c.optimize()._name == d.optimize()._name +def test_rechunk_blockwise_optimize(): + a = np.random.random((10, 10)) + b = da.from_array(a, chunks=(4, 4)) + + result = (da.from_array(a, chunks=(4, 4)) + 1).rechunk((5, 5)) + expected = da.from_array(a, chunks=(5, 5)) + 1 + assert result.optimize()._name == expected.optimize()._name + + a = np.random.random((10,)) + aa = da.from_array(a) + b = np.random.random((10, 10)) + bb = da.from_array(b) + + c = (aa + bb).rechunk((5, 2)) + result = c.optimize() + + expected = da.from_array(a, chunks=(2,)) + da.from_array(b, chunks=(5, 2)) + assert result._name == expected._name + + a = np.random.random((10, 1)) + aa = da.from_array(a) + b = np.random.random((10, 10)) + bb = da.from_array(b) + + c = (aa + bb).rechunk((5, 2)) + result = c.optimize() + + expected = da.from_array(a, chunks=(5, 1)) + da.from_array(b, chunks=(5, 2)) + assert result._name == expected._name + + def test_elemwise(): a = np.random.random((10, 10)) b = da.from_array(a, chunks=(4, 4)) @@ -60,3 +91,5 @@ def test_transpose(): b = da.from_array(a, chunks=(5, 1)) assert_eq(a.T + a, b.T + b) assert_eq(a + a.T, b + b.T) + + assert b.T.T.optimize()._name == 
b.optimize()._name From 28fe54fb487fe0084d17b818783015e981d79d10 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 8 Dec 2023 17:26:38 -0800 Subject: [PATCH 09/19] Support xarray rechunking --- dask_expr/array/core.py | 3 +++ dask_expr/array/rechunk.py | 23 ++++++++++++++++------- dask_expr/array/tests/test_array.py | 12 +++++++++++- 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/dask_expr/array/core.py b/dask_expr/array/core.py index 08fff249..7cda84c1 100644 --- a/dask_expr/array/core.py +++ b/dask_expr/array/core.py @@ -28,6 +28,9 @@ def __dask_postpersist__(self): def __array_ufunc__(self, numpy_ufunc, method, *inputs, **kwargs): raise NotImplementedError() + def __array_function__(self, *args, **kwargs): + raise NotImplementedError() + def __getitem__(self, index): from dask.array.slicing import normalize_index diff --git a/dask_expr/array/rechunk.py b/dask_expr/array/rechunk.py index 3b4c8510..48c91bf9 100644 --- a/dask_expr/array/rechunk.py +++ b/dask_expr/array/rechunk.py @@ -126,25 +126,34 @@ def _simplify_down(self): # TODO: this probably doesn't support contractions or expansions # We should probably just abort in those cases for now (or do # chunksize math) - if type(self.array) == Elemwise and isinstance(self._chunks, tuple): + if type(self.array) == Elemwise and isinstance(self._chunks, (dict, tuple)): args = [] for arg, inds in toolz.partition_all(2, self.array.args): if inds is None: args.extend((arg, inds)) else: assert isinstance(arg, Array) - idx = tuple(self.array.out_ind.index(i) for i in inds) - chunks = tuple([self._chunks[i] for i in idx]) + if isinstance(self._chunks, tuple): + idx = tuple(self.array.out_ind.index(i) for i in inds) + chunks = tuple([self._chunks[i] for i in idx]) + elif isinstance(self._chunks, dict): + chunks = { + i: self._chunks[j] + for i, j in zip(self.array.out_ind, inds) + if j in self._chunks + } arg = arg.rechunk(chunks) args.extend((arg, inds)) return Elemwise(*self.array.operands[: -len(args)], *args) if isinstance(self.array, IO) and "chunks" in self.array._parameters: - chunks = tuple( - c if n != 1 else 1 if isinstance(c, numbers.Number) else (1,) - for n, c in zip(self.array.shape, self._chunks) - ) + chunks = self._chunks + if isinstance(chunks, tuple): + chunks = tuple( + c if n != 1 else 1 if isinstance(c, numbers.Number) else (1,) + for n, c in zip(self.array.shape, self._chunks) + ) return self.array.substitute_parameters({"chunks": chunks}) diff --git a/dask_expr/array/tests/test_array.py b/dask_expr/array/tests/test_array.py index 9bab5455..708e64ff 100644 --- a/dask_expr/array/tests/test_array.py +++ b/dask_expr/array/tests/test_array.py @@ -51,7 +51,6 @@ def test_rechunk_blockwise_optimize(): c = (aa + bb).rechunk((5, 2)) result = c.optimize() - expected = da.from_array(a, chunks=(2,)) + da.from_array(b, chunks=(5, 2)) assert result._name == expected._name @@ -112,3 +111,14 @@ def test_slicing_optimization(): assert b[:].optimize()._name == b._name assert b[5:, 4][::2].optimize()._name == b[5::2, 4].optimize()._name + + +def test_xarray(): + import xarray as xr + + a = np.random.random((10, 20)) + b = da.from_array(a) + + x = (xr.DataArray(b, dims=["x", "y"]) + 1).chunk(x=2) + + assert x.data.optimize()._name == (da.from_array(a, chunks={0: 2}) + 1)._name From 673c28f8ec10796f9d34770967cbba1c1f9356bb Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 8 Dec 2023 17:51:01 -0800 Subject: [PATCH 10/19] Some slicing elemwise optimizations --- dask_expr/array/slicing.py | 17 +++++++++++++++++ 
dask_expr/array/tests/test_array.py | 10 ++++++++++ 2 files changed, 27 insertions(+) diff --git a/dask_expr/array/slicing.py b/dask_expr/array/slicing.py index 81c51639..1f54e718 100644 --- a/dask_expr/array/slicing.py +++ b/dask_expr/array/slicing.py @@ -1,3 +1,4 @@ +import toolz from dask.array.optimization import fuse_slice from dask.array.slicing import normalize_slice, slice_array from dask.array.utils import meta_from_array @@ -40,3 +41,19 @@ def _simplify_down(self): fuse_slice(self.array.index, self.index), self.array.array.ndim ), ) + + if isinstance(self.array, Elemwise) and all( + isinstance(i, (slice, list)) for i in self.index + ): # TODO: deal with change in dimensionality + index = self.index + (slice(None),) * (self.ndim - len(self.index)) + args = [] + for arg, ind in toolz.partition(2, self.array.args): + if ind is None: + args.extend((arg, ind)) + else: + idx = tuple(index[self.array.out_ind.index(i)] for i in ind) + args.extend((arg[idx], ind)) + return Elemwise(*self.array.operands[: -len(args)], *args) + + +from dask_expr.array.blockwise import Elemwise diff --git a/dask_expr/array/tests/test_array.py b/dask_expr/array/tests/test_array.py index 708e64ff..022a85ab 100644 --- a/dask_expr/array/tests/test_array.py +++ b/dask_expr/array/tests/test_array.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from dask.array.utils import assert_eq import dask_expr.array as da @@ -112,6 +113,15 @@ def test_slicing_optimization(): assert b[:].optimize()._name == b._name assert b[5:, 4][::2].optimize()._name == b[5::2, 4].optimize()._name + assert (b + 1)[:5].optimize()._name == (b[:5] + 1)._name + + +@pytest.mark.xfail(reason="Blockwise specifies too much about dimension") +def test_slicing_optimization_change_dimensionality(): + a = np.random.random((10, 20)) + b = da.from_array(a, chunks=(2, 5)) + assert (b + 1)[5].optimize()._name == (b[5] + 1)._name + def test_xarray(): import xarray as xr From e37cd20c0814969b9541300c166b93bbc0bf68a8 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 8 Dec 2023 18:03:12 -0800 Subject: [PATCH 11/19] add comment around elemwise future --- dask_expr/array/blockwise.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dask_expr/array/blockwise.py b/dask_expr/array/blockwise.py index 96f5e296..d3c65edd 100644 --- a/dask_expr/array/blockwise.py +++ b/dask_expr/array/blockwise.py @@ -355,6 +355,8 @@ def blockwise( class Elemwise(Blockwise): + # TODO: we should make this have simpler parameters. Optimizations like + # slicing will be more powerful I think. 
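+    # In this commit Elemwise still reuses Blockwise's full parameter list
+    # (func, out_ind, name, token, dtype, adjust_chunks, new_axes, ...); a
+    # later commit in this series ("Build smarter Elemwise class") narrows
+    # the parameters to just the operation, dtype, and name, deriving
+    # out_ind and dtype from the operands instead.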
pass From 67635c0ac6ec98b77282bd263817f0f3eeb5d29c Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 9 Dec 2023 08:26:29 -0800 Subject: [PATCH 12/19] Add simplified random implementation --- dask_expr/array/__init__.py | 1 + dask_expr/array/random.py | 1084 +++++++++++++++++++++++++++ dask_expr/array/tests/test_array.py | 5 + 3 files changed, 1090 insertions(+) create mode 100644 dask_expr/array/random.py diff --git a/dask_expr/array/__init__.py b/dask_expr/array/__init__.py index 5c04d131..a8f55d83 100644 --- a/dask_expr/array/__init__.py +++ b/dask_expr/array/__init__.py @@ -1 +1,2 @@ +from dask_expr.array import random from dask_expr.array.core import Array, from_array diff --git a/dask_expr/array/random.py b/dask_expr/array/random.py new file mode 100644 index 00000000..6eff3d0a --- /dev/null +++ b/dask_expr/array/random.py @@ -0,0 +1,1084 @@ +from __future__ import annotations + +import contextlib +import importlib +import numbers +from itertools import product +from numbers import Integral +from threading import Lock + +import numpy as np +from dask.array.backends import array_creation_dispatch +from dask.array.core import asarray, broadcast_shapes, normalize_chunks +from dask.array.creation import arange +from dask.array.utils import asarray_safe +from dask.base import tokenize +from dask.highlevelgraph import HighLevelGraph +from dask.utils import cached_property, derived_from, random_state_data, typename + +from dask_expr.array.core import IO, Array + + +class Generator: + """ + Container for the BitGenerators. + + ``Generator`` exposes a number of methods for generating random + numbers drawn from a variety of probability distributions and serves + as a replacement for ``RandomState``. The main difference between the + two is that ``Generator`` relies on an additional ``BitGenerator`` to + manage state and generate the random bits, which are then transformed + into random values from useful distributions. The default ``BitGenerator`` + used by ``Generator`` is ``PCG64``. The ``BitGenerator`` can be changed + by passing an instantiated ``BitGenerator`` to ``Generator``. + + The function :func:`dask.array.random.default_rng` is the recommended way + to instantiate a ``Generator``. + + .. warning:: + + No Compatibility Guarantee. + + ``Generator`` does not provide a version compatibility guarantee. In + particular, as better algorithms evolve the bit stream may change. + + Parameters + ---------- + bit_generator : BitGenerator + BitGenerator to use as the core generator. + + Notes + ----- + In addition to the distribution-specific arguments, each ``Generator`` + method takes a keyword argument `size` that defaults to ``None``. If + `size` is ``None``, then a single value is generated and returned. If + `size` is an integer, then a 1-D array filled with generated values is + returned. If `size` is a tuple, then an array with that shape is + filled and returned. + + The Python stdlib module `random` contains pseudo-random number generator + with a number of methods that are similar to the ones available in + ``Generator``. It uses Mersenne Twister, and this bit generator can + be accessed using ``MT19937``. ``Generator``, besides being + Dask-aware, has the advantage that it provides a much larger number + of probability distributions to choose from. + + All ``Generator`` methods are identical to ``np.random.Generator`` except + that they also take a `chunks=` keyword argument. + + ``Generator`` does not guarantee parity in the generated numbers + with any third party library. 
In particular, numbers generated by + `Dask` and `NumPy` will differ even if they use the same seed. + + Examples + -------- + >>> from numpy.random import PCG64 + >>> from dask.array.random import Generator + >>> rng = Generator(PCG64()) + >>> rng.standard_normal().compute() # doctest: +SKIP + array(0.44595957) # random + + See Also + -------- + default_rng : Recommended constructor for `Generator`. + np.random.Generator + """ + + def __init__(self, bit_generator): + self._bit_generator = bit_generator + + def __str__(self): + _str = self.__class__.__name__ + _str += "(" + self._bit_generator.__class__.__name__ + ")" + return _str + + @property + def _backend_name(self): + # Assumes typename(self._RandomState) starts with an + # array-library name (e.g. "numpy" or "cupy") + return typename(self._bit_generator).split(".")[0] + + @property + def _backend(self): + # Assumes `self._backend_name` is an importable + # array-library name (e.g. "numpy" or "cupy") + return importlib.import_module(self._backend_name) + + @derived_from(np.random.Generator, skipblocks=1) + def beta(self, a, b, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "beta", a, b, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.Generator, skipblocks=1) + def binomial(self, n, p, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "binomial", n, p, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.Generator, skipblocks=1) + def chisquare(self, df, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "chisquare", df, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.Generator, skipblocks=1) + def choice( + self, + a, + size=None, + replace=True, + p=None, + axis=0, + shuffle=True, + chunks="auto", + ): + ( + a, + size, + replace, + p, + axis, + chunks, + meta, + dependencies, + ) = _choice_validate_params(self, a, size, replace, p, axis, chunks) + + sizes = list(product(*chunks)) + bitgens = _spawn_bitgens(self._bit_generator, len(sizes)) + + name = "da.random.choice-%s" % tokenize( + bitgens, size, chunks, a, replace, p, axis, shuffle + ) + keys = product([name], *(range(len(bd)) for bd in chunks)) + dsk = { + k: (_choice_rng, bitgen, a, size, replace, p, axis, shuffle) + for k, bitgen, size in zip(keys, bitgens, sizes) + } + + graph = HighLevelGraph.from_collections(name, dsk, dependencies=dependencies) + return Array(graph, name, chunks, meta=meta) + + @derived_from(np.random.Generator, skipblocks=1) + def exponential(self, scale=1.0, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "exponential", scale, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.Generator, skipblocks=1) + def f(self, dfnum, dfden, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "f", dfnum, dfden, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.Generator, skipblocks=1) + def gamma(self, shape, scale=1.0, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "gamma", shape, scale, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.Generator, skipblocks=1) + def geometric(self, p, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "geometric", p, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.Generator, skipblocks=1) + def gumbel(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "gumbel", loc, scale, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.Generator, skipblocks=1) + def 
hypergeometric(self, ngood, nbad, nsample, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, + "hypergeometric", + ngood, + nbad, + nsample, + size=size, + chunks=chunks, + **kwargs, + ) + + @derived_from(np.random.Generator, skipblocks=1) + def integers( + self, + low, + high=None, + size=None, + dtype=np.int64, + endpoint=False, + chunks="auto", + **kwargs, + ): + return _wrap_func( + self, + "integers", + low, + high=high, + size=size, + dtype=dtype, + endpoint=endpoint, + chunks=chunks, + **kwargs, + ) + + @derived_from(np.random.Generator, skipblocks=1) + def laplace(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "laplace", loc, scale, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.Generator, skipblocks=1) + def logistic(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "logistic", loc, scale, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.Generator, skipblocks=1) + def lognormal(self, mean=0.0, sigma=1.0, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "lognormal", mean, sigma, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.Generator, skipblocks=1) + def logseries(self, p, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "logseries", p, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.Generator, skipblocks=1) + def multinomial(self, n, pvals, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, + "multinomial", + n, + pvals, + size=size, + chunks=chunks, + extra_chunks=((len(pvals),),), + **kwargs, + ) + + @derived_from(np.random.Generator, skipblocks=1) + def multivariate_hypergeometric( + self, colors, nsample, size=None, method="marginals", chunks="auto", **kwargs + ): + return _wrap_func( + self, + "multivariate_hypergeometric", + colors, + nsample, + size=size, + method=method, + chunks=chunks, + **kwargs, + ) + + @derived_from(np.random.Generator, skipblocks=1) + def negative_binomial(self, n, p, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "negative_binomial", n, p, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.Generator, skipblocks=1) + def noncentral_chisquare(self, df, nonc, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "noncentral_chisquare", df, nonc, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.Generator, skipblocks=1) + def noncentral_f(self, dfnum, dfden, nonc, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "noncentral_f", dfnum, dfden, nonc, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.Generator, skipblocks=1) + def normal(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "normal", loc, scale, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.Generator, skipblocks=1) + def pareto(self, a, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "pareto", a, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.Generator, skipblocks=1) + def permutation(self, x): + from dask.array.slicing import shuffle_slice + + if self._backend_name == "cupy": + raise NotImplementedError( + "`Generator.permutation` not supported for cupy-backed " + "Generator objects. Use the 'numpy' array backend to " + "call `dask.array.random.default_rng`, or pass in " + " `numpy.random.PCG64()`." 
+ ) + + if isinstance(x, numbers.Number): + x = arange(x, chunks="auto") + + index = self._backend.arange(len(x)) + _shuffle(self._bit_generator, index) + return shuffle_slice(x, index) + + @derived_from(np.random.Generator, skipblocks=1) + def poisson(self, lam=1.0, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "poisson", lam, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.Generator, skipblocks=1) + def power(self, a, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "power", a, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.Generator, skipblocks=1) + def random(self, size=None, dtype=np.float64, out=None, chunks="auto", **kwargs): + return _wrap_func( + self, "random", size=size, dtype=dtype, out=out, chunks=chunks, **kwargs + ) + + @derived_from(np.random.Generator, skipblocks=1) + def rayleigh(self, scale=1.0, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "rayleigh", scale, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.Generator, skipblocks=1) + def standard_cauchy(self, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "standard_cauchy", size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.Generator, skipblocks=1) + def standard_exponential(self, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "standard_exponential", size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.Generator, skipblocks=1) + def standard_gamma(self, shape, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "standard_gamma", shape, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.Generator, skipblocks=1) + def standard_normal(self, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "standard_normal", size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.Generator, skipblocks=1) + def standard_t(self, df, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "standard_t", df, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.Generator, skipblocks=1) + def triangular(self, left, mode, right, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "triangular", left, mode, right, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.Generator, skipblocks=1) + def uniform(self, low=0.0, high=1.0, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "uniform", low, high, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.Generator, skipblocks=1) + def vonmises(self, mu, kappa, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "vonmises", mu, kappa, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.Generator, skipblocks=1) + def wald(self, mean, scale, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "wald", mean, scale, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.Generator, skipblocks=1) + def weibull(self, a, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "weibull", a, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.Generator, skipblocks=1) + def zipf(self, a, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "zipf", a, size=size, chunks=chunks, **kwargs) + + +def default_rng(seed=None): + """ + Construct a new Generator with the default BitGenerator (PCG64). + + Parameters + ---------- + seed : {None, int, array_like[ints], SeedSequence, BitGenerator, Generator}, optional + A seed to initialize the `BitGenerator`. 
If None, then fresh, + unpredictable entropy will be pulled from the OS. If an ``int`` or + ``array_like[ints]`` is passed, then it will be passed to + `SeedSequence` to derive the initial `BitGenerator` state. One may + also pass in a `SeedSequence` instance. + Additionally, when passed a `BitGenerator`, it will be wrapped by + `Generator`. If passed a `Generator`, it will be returned unaltered. + + Returns + ------- + Generator + The initialized generator object. + + Notes + ----- + If ``seed`` is not a `BitGenerator` or a `Generator`, a new + `BitGenerator` is instantiated. This function does not manage a default + global instance. + + Examples + -------- + ``default_rng`` is the recommended constructor for the random number + class ``Generator``. Here are several ways we can construct a random + number generator using ``default_rng`` and the ``Generator`` class. + + Here we use ``default_rng`` to generate a random float: + + >>> import dask.array as da + >>> rng = da.random.default_rng(12345) + >>> print(rng) + Generator(PCG64) + >>> rfloat = rng.random().compute() + >>> rfloat + array(0.86999885) + >>> type(rfloat) + + + Here we use ``default_rng`` to generate 3 random integers between 0 + (inclusive) and 10 (exclusive): + + >>> import dask.array as da + >>> rng = da.random.default_rng(12345) + >>> rints = rng.integers(low=0, high=10, size=3).compute() + >>> rints + array([2, 8, 7]) + >>> type(rints[0]) + + + Here we specify a seed so that we have reproducible results: + + >>> import dask.array as da + >>> rng = da.random.default_rng(seed=42) + >>> print(rng) + Generator(PCG64) + >>> arr1 = rng.random((3, 3)).compute() + >>> arr1 + array([[0.91674416, 0.91098667, 0.8765925 ], + [0.30931841, 0.95465607, 0.17509458], + [0.99662814, 0.75203348, 0.15038118]]) + + If we exit and restart our Python interpreter, we'll see that we + generate the same random numbers again: + + >>> import dask.array as da + >>> rng = da.random.default_rng(seed=42) + >>> arr2 = rng.random((3, 3)).compute() + >>> arr2 + array([[0.91674416, 0.91098667, 0.8765925 ], + [0.30931841, 0.95465607, 0.17509458], + [0.99662814, 0.75203348, 0.15038118]]) + + See Also + -------- + np.random.default_rng + """ + if hasattr(seed, "capsule"): + # We are passed a BitGenerator, so just wrap it + return Generator(seed) + elif isinstance(seed, Generator): + # Pass through a Generator + return seed + elif hasattr(seed, "bit_generator"): + # a Generator. Just not ours + return Generator(seed.bit_generator) + # Otherwise, use the backend-default BitGenerator + return Generator(array_creation_dispatch.default_bit_generator(seed)) + + +class RandomState: + """ + Mersenne Twister pseudo-random number generator + + This object contains state to deterministically generate pseudo-random + numbers from a variety of probability distributions. It is identical to + ``np.random.RandomState`` except that all functions also take a ``chunks=`` + keyword argument. + + Parameters + ---------- + seed: Number + Object to pass to RandomState to serve as deterministic seed + RandomState: Callable[seed] -> RandomState + A callable that, when provided with a ``seed`` keyword provides an + object that operates identically to ``np.random.RandomState`` (the + default). This might also be a function that returns a + ``mkl_random``, or ``cupy.random.RandomState`` object. 
+ + Examples + -------- + >>> import dask.array as da + >>> state = da.random.RandomState(1234) # a seed + >>> x = state.normal(10, 0.1, size=3, chunks=(2,)) + >>> x.compute() + array([10.01867852, 10.04812289, 9.89649746]) + + See Also + -------- + np.random.RandomState + """ + + def __init__(self, seed=None, RandomState=None): + self._numpy_state = np.random.RandomState(seed) + self._RandomState = ( + array_creation_dispatch.RandomState if RandomState is None else RandomState + ) + + @property + def _backend(self): + # Assumes typename(self._RandomState) starts with + # an importable array-library name (e.g. "numpy" or "cupy") + _backend_name = typename(self._RandomState).split(".")[0] + return importlib.import_module(_backend_name) + + def seed(self, seed=None): + self._numpy_state.seed(seed) + + @derived_from(np.random.RandomState, skipblocks=1) + def beta(self, a, b, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "beta", a, b, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.RandomState, skipblocks=1) + def binomial(self, n, p, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "binomial", n, p, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.RandomState, skipblocks=1) + def chisquare(self, df, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "chisquare", df, size=size, chunks=chunks, **kwargs) + + with contextlib.suppress(AttributeError): + + @derived_from(np.random.RandomState, skipblocks=1) + def choice(self, a, size=None, replace=True, p=None, chunks="auto"): + ( + a, + size, + replace, + p, + axis, # np.random.RandomState.choice does not use axis + chunks, + meta, + dependencies, + ) = _choice_validate_params(self, a, size, replace, p, 0, chunks) + + sizes = list(product(*chunks)) + state_data = random_state_data(len(sizes), self._numpy_state) + + name = "da.random.choice-%s" % tokenize( + state_data, size, chunks, a, replace, p + ) + keys = product([name], *(range(len(bd)) for bd in chunks)) + dsk = { + k: (_choice_rs, state, a, size, replace, p) + for k, state, size in zip(keys, state_data, sizes) + } + + graph = HighLevelGraph.from_collections( + name, dsk, dependencies=dependencies + ) + return Array(graph, name, chunks, meta=meta) + + @derived_from(np.random.RandomState, skipblocks=1) + def exponential(self, scale=1.0, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "exponential", scale, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.RandomState, skipblocks=1) + def f(self, dfnum, dfden, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "f", dfnum, dfden, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.RandomState, skipblocks=1) + def gamma(self, shape, scale=1.0, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "gamma", shape, scale, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.RandomState, skipblocks=1) + def geometric(self, p, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "geometric", p, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.RandomState, skipblocks=1) + def gumbel(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "gumbel", loc, scale, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.RandomState, skipblocks=1) + def hypergeometric(self, ngood, nbad, nsample, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, + "hypergeometric", + ngood, + nbad, + nsample, + size=size, + 
chunks=chunks, + **kwargs, + ) + + @derived_from(np.random.RandomState, skipblocks=1) + def laplace(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "laplace", loc, scale, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.RandomState, skipblocks=1) + def logistic(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "logistic", loc, scale, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.RandomState, skipblocks=1) + def lognormal(self, mean=0.0, sigma=1.0, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "lognormal", mean, sigma, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.RandomState, skipblocks=1) + def logseries(self, p, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "logseries", p, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.RandomState, skipblocks=1) + def multinomial(self, n, pvals, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, + "multinomial", + n, + pvals, + size=size, + chunks=chunks, + extra_chunks=((len(pvals),),), + **kwargs, + ) + + @derived_from(np.random.RandomState, skipblocks=1) + def negative_binomial(self, n, p, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "negative_binomial", n, p, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.RandomState, skipblocks=1) + def noncentral_chisquare(self, df, nonc, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "noncentral_chisquare", df, nonc, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.RandomState, skipblocks=1) + def noncentral_f(self, dfnum, dfden, nonc, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "noncentral_f", dfnum, dfden, nonc, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.RandomState, skipblocks=1) + def normal(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "normal", loc, scale, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.RandomState, skipblocks=1) + def pareto(self, a, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "pareto", a, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.RandomState, skipblocks=1) + def permutation(self, x): + from dask.array.slicing import shuffle_slice + + if isinstance(x, numbers.Number): + x = arange(x, chunks="auto") + + index = np.arange(len(x)) + self._numpy_state.shuffle(index) + return shuffle_slice(x, index) + + @derived_from(np.random.RandomState, skipblocks=1) + def poisson(self, lam=1.0, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "poisson", lam, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.RandomState, skipblocks=1) + def power(self, a, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "power", a, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.RandomState, skipblocks=1) + def randint(self, low, high=None, size=None, chunks="auto", dtype="l", **kwargs): + return _wrap_func( + self, "randint", low, high, size=size, chunks=chunks, dtype=dtype, **kwargs + ) + + @derived_from(np.random.RandomState, skipblocks=1) + def random_integers(self, low, high=None, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "random_integers", low, high, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.RandomState, skipblocks=1) + def random_sample(self, size=None, chunks="auto", 
**kwargs): + return _wrap_func(self, "random_sample", size=size, chunks=chunks, **kwargs) + + random = random_sample + + @derived_from(np.random.RandomState, skipblocks=1) + def rayleigh(self, scale=1.0, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "rayleigh", scale, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.RandomState, skipblocks=1) + def standard_cauchy(self, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "standard_cauchy", size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.RandomState, skipblocks=1) + def standard_exponential(self, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "standard_exponential", size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.RandomState, skipblocks=1) + def standard_gamma(self, shape, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "standard_gamma", shape, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.RandomState, skipblocks=1) + def standard_normal(self, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "standard_normal", size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.RandomState, skipblocks=1) + def standard_t(self, df, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "standard_t", df, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.RandomState, skipblocks=1) + def tomaxint(self, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "tomaxint", size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.RandomState, skipblocks=1) + def triangular(self, left, mode, right, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "triangular", left, mode, right, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.RandomState, skipblocks=1) + def uniform(self, low=0.0, high=1.0, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "uniform", low, high, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.RandomState, skipblocks=1) + def vonmises(self, mu, kappa, size=None, chunks="auto", **kwargs): + return _wrap_func( + self, "vonmises", mu, kappa, size=size, chunks=chunks, **kwargs + ) + + @derived_from(np.random.RandomState, skipblocks=1) + def wald(self, mean, scale, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "wald", mean, scale, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.RandomState, skipblocks=1) + def weibull(self, a, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "weibull", a, size=size, chunks=chunks, **kwargs) + + @derived_from(np.random.RandomState, skipblocks=1) + def zipf(self, a, size=None, chunks="auto", **kwargs): + return _wrap_func(self, "zipf", a, size=size, chunks=chunks, **kwargs) + + +def _rng_from_bitgen(bitgen): + # Assumes typename(bitgen) starts with importable + # library name (e.g. 
"numpy" or "cupy") + backend_name = typename(bitgen).split(".")[0] + backend_lib = importlib.import_module(backend_name) + return backend_lib.random.default_rng(bitgen) + + +def _shuffle(bit_generator, x, axis=0): + state_data = bit_generator.state + bit_generator = type(bit_generator)() + bit_generator.state = state_data + state = _rng_from_bitgen(bit_generator) + return state.shuffle(x, axis=axis) + + +def _spawn_bitgens(bitgen, n_bitgens): + seeds = bitgen._seed_seq.spawn(n_bitgens) + bitgens = [type(bitgen)(seed) for seed in seeds] + return bitgens + + +def _apply_random_func(rng, funcname, bitgen, size, args, kwargs): + """Apply random module method with seed""" + if isinstance(bitgen, np.random.SeedSequence): + bitgen = rng(bitgen) + rng = _rng_from_bitgen(bitgen) + func = getattr(rng, funcname) + return func(*args, size=size, **kwargs) + + +def _apply_random(RandomState, funcname, state_data, size, args, kwargs): + """Apply RandomState method with seed""" + if RandomState is None: + RandomState = array_creation_dispatch.RandomState + state = RandomState(state_data) + func = getattr(state, funcname) + return func(*args, size=size, **kwargs) + + +def _choice_rng(state_data, a, size, replace, p, axis, shuffle): + state = _rng_from_bitgen(state_data) + return state.choice(a, size=size, replace=replace, p=p, axis=axis, shuffle=shuffle) + + +def _choice_rs(state_data, a, size, replace, p): + state = array_creation_dispatch.RandomState(state_data) + return state.choice(a, size=size, replace=replace, p=p) + + +def _choice_validate_params(state, a, size, replace, p, axis, chunks): + dependencies = [] + # Normalize and validate `a` + if isinstance(a, Integral): + if isinstance(state, Generator): + if state._backend_name == "cupy": + raise NotImplementedError( + "`choice` not supported for cupy-backed `Generator`." + ) + meta = state._backend.random.default_rng().choice(1, size=(), p=None) + elif isinstance(state, RandomState): + # On windows the output dtype differs if p is provided or + # # absent, see https://github.com/numpy/numpy/issues/9867 + dummy_p = state._backend.array([1]) if p is not None else p + meta = state._backend.random.RandomState().choice(1, size=(), p=dummy_p) + else: + raise ValueError("Unknown generator class") + len_a = a + if a < 0: + raise ValueError("a must be greater than 0") + else: + a = asarray(a) + a = a.rechunk(a.shape) + meta = a._meta + if a.ndim != 1: + raise ValueError("a must be one dimensional") + len_a = len(a) + dependencies.append(a) + a = a.__dask_keys__()[0] + + # Normalize and validate `p` + if p is not None: + if not isinstance(p, Array): + # If p is not a dask array, first check the sum is close + # to 1 before converting. 
+ p = asarray_safe(p, like=p) + if not np.isclose(p.sum(), 1, rtol=1e-7, atol=0): + raise ValueError("probabilities do not sum to 1") + p = asarray(p) + else: + p = p.rechunk(p.shape) + + if p.ndim != 1: + raise ValueError("p must be one dimensional") + if len(p) != len_a: + raise ValueError("a and p must have the same size") + + dependencies.append(p) + p = p.__dask_keys__()[0] + + if size is None: + size = () + elif not isinstance(size, (tuple, list)): + size = (size,) + + if axis != 0: + raise ValueError("axis must be 0 since a is one dimensinal") + + chunks = normalize_chunks(chunks, size, dtype=np.float64) + if not replace and len(chunks[0]) > 1: + err_msg = ( + "replace=False is not currently supported for " + "dask.array.choice with multi-chunk output " + "arrays" + ) + raise NotImplementedError(err_msg) + + return a, size, replace, p, axis, chunks, meta, dependencies + + +def _wrap_func( + rng, funcname, *args, size=None, chunks="auto", extra_chunks=(), **kwargs +): + return Random(rng, funcname, size, chunks, extra_chunks, args, kwargs) + + +class Random(IO): + _parameters = [ + "rng", + "distribution", + "size", + "chunks", + "extra_chunks", + "args", + "kwargs", + ] + _defaults = {"extra_chunks": ()} + + @property + def chunks(self): + size = self.size + chunks = self.operand("chunks") + + if size is not None and not isinstance(size, (tuple, list)): + size = (size,) + + # shapes = list( + # { + # ar.shape + # for ar in chain(args, kwargs.values()) + # if isinstance(ar, (Array, np.ndarray)) + # } + # ) + # if size is not None: + # shapes.append(size) + shapes = [size] + # broadcast to the final size(shape) + size = broadcast_shapes(*shapes) + return normalize_chunks( + chunks, + size, # ideally would use dtype here + dtype=self.kwargs.get("dtype", np.float64), + ) + + @cached_property + def _info(self): + sizes = list(product(*self.chunks)) + if isinstance(self.rng, Generator): + bitgens = _spawn_bitgens(self.rng._bit_generator, len(sizes)) + bitgen_token = tokenize(bitgens) + bitgens = [_bitgen._seed_seq for _bitgen in bitgens] + func_applier = _apply_random_func + gen = type(self.rng._bit_generator) + elif isinstance(self.rng, RandomState): + bitgens = random_state_data(len(sizes), self.rng._numpy_state) + bitgen_token = tokenize(bitgens) + func_applier = _apply_random + gen = self.rng._RandomState + else: + raise TypeError( + "Unknown object type: Not a Generator and Not a RandomState" + ) + token = tokenize(bitgen_token, self.size, self.chunks, self.args, self.kwargs) + name = f"{self.distribution}-{token}" + + return bitgens, name, sizes, gen, func_applier + + @property + def _name(self): + return self._info[1] + + @property + def bitgens(self): + return self._info[0] + + def _layer(self): + bitgens, name, sizes, gen, func_applier = self._info + + keys = product( + [name], + *([range(len(bd)) for bd in self.chunks] + [[0]] * len(self.extra_chunks)), + ) + + vals = [] + # TODO: handle non-trivial args/kwargs (arrays, dask or otherwise) + for bitgen, size in zip(bitgens, sizes): + vals.append( + ( + func_applier, + gen, + self.distribution, + bitgen, + size, + self.args, + self.kwargs, + ) + ) + + return dict(zip(keys, vals)) + + @cached_property + def _meta(self): + bitgens, name, sizes, gen, func_applier = self._info + return func_applier( + gen, + self.distribution, + bitgens[0], # TODO: not sure about this + (0,) * len(self.size), + self.args, + self.kwargs, + # small_args, + # small_kwargs, + ) + + +""" +Lazy RNG-state machinery + +Many of the RandomState methods are 
exported as functions in da.random for +backward compatibility reasons. Their usage is discouraged. +Use da.random.default_rng() to get a Generator based rng and use its +methods instead. +""" + +_cached_states: dict[str, RandomState] = {} +_cached_states_lock = Lock() + + +def _make_api(attr): + def wrapper(*args, **kwargs): + key = array_creation_dispatch.backend + with _cached_states_lock: + try: + state = _cached_states[key] + except KeyError: + _cached_states[key] = state = RandomState() + return getattr(state, attr)(*args, **kwargs) + + wrapper.__name__ = getattr(RandomState, attr).__name__ + wrapper.__doc__ = getattr(RandomState, attr).__doc__ + return wrapper + + +""" +RandomState only +""" + +seed = _make_api("seed") + +beta = _make_api("beta") +binomial = _make_api("binomial") +chisquare = _make_api("chisquare") +choice = _make_api("choice") +exponential = _make_api("exponential") +f = _make_api("f") +gamma = _make_api("gamma") +geometric = _make_api("geometric") +gumbel = _make_api("gumbel") +hypergeometric = _make_api("hypergeometric") +laplace = _make_api("laplace") +logistic = _make_api("logistic") +lognormal = _make_api("lognormal") +logseries = _make_api("logseries") +multinomial = _make_api("multinomial") +negative_binomial = _make_api("negative_binomial") +noncentral_chisquare = _make_api("noncentral_chisquare") +noncentral_f = _make_api("noncentral_f") +normal = _make_api("normal") +pareto = _make_api("pareto") +permutation = _make_api("permutation") +poisson = _make_api("poisson") +power = _make_api("power") +random_sample = _make_api("random_sample") +random = _make_api("random_sample") +randint = _make_api("randint") +random_integers = _make_api("random_integers") +rayleigh = _make_api("rayleigh") +standard_cauchy = _make_api("standard_cauchy") +standard_exponential = _make_api("standard_exponential") +standard_gamma = _make_api("standard_gamma") +standard_normal = _make_api("standard_normal") +standard_t = _make_api("standard_t") +triangular = _make_api("triangular") +uniform = _make_api("uniform") +vonmises = _make_api("vonmises") +wald = _make_api("wald") +weibull = _make_api("weibull") +zipf = _make_api("zipf") diff --git a/dask_expr/array/tests/test_array.py b/dask_expr/array/tests/test_array.py index 022a85ab..b5d8db31 100644 --- a/dask_expr/array/tests/test_array.py +++ b/dask_expr/array/tests/test_array.py @@ -132,3 +132,8 @@ def test_xarray(): x = (xr.DataArray(b, dims=["x", "y"]) + 1).chunk(x=2) assert x.data.optimize()._name == (da.from_array(a, chunks={0: 2}) + 1)._name + + +def test_random(): + x = da.random.random((100, 100), chunks=(50, 50)) + assert_eq(x, x) From d2d334385827966b6e19a0781350dbaf74b584bc Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 9 Dec 2023 14:57:08 -0800 Subject: [PATCH 13/19] Build smarter Elemwise class, improve rechunk/slicing optimizations --- dask_expr/array/blockwise.py | 196 +++++++++++++++++----------- dask_expr/array/rechunk.py | 17 ++- dask_expr/array/slicing.py | 18 ++- dask_expr/array/tests/test_array.py | 6 + 4 files changed, 150 insertions(+), 87 deletions(-) diff --git a/dask_expr/array/blockwise.py b/dask_expr/array/blockwise.py index d3c65edd..42eb740b 100644 --- a/dask_expr/array/blockwise.py +++ b/dask_expr/array/blockwise.py @@ -14,7 +14,7 @@ from dask.base import is_dask_collection, tokenize from dask.blockwise import blockwise as core_blockwise from dask.delayed import unpack_collections -from dask.utils import funcname +from dask.utils import cached_property, funcname from dask_expr.array.core 
import Array @@ -355,9 +355,121 @@ def blockwise( class Elemwise(Blockwise): - # TODO: we should make this have simpler parameters. Optimizations like - # slicing will be more powerful I think. - pass + _parameters = ["op", "dtype", "name"] + _defaults = { + "dtype": None, + "name": None, + } + align_arrays = False + new_axes = {} + adjust_chunks = None + token = None + _meta_provided = None + concatenate = None + + @property + def elemwise_args(self): + return self.operands[len(self._parameters) :] + + @property + def out_ind(self): + shapes = [] + for arg in self.elemwise_args: + shape = getattr(arg, "shape", ()) + if any(is_dask_collection(x) for x in shape): + # Want to exclude Delayed shapes and dd.Scalar + shape = () + shapes.append(shape) + # if isinstance(where, Array): + # shapes.append(where.shape) + # if isinstance(out, Array): + # shapes.append(out.shape) + + shapes = [s if isinstance(s, Iterable) else () for s in shapes] + out_ndim = len( + broadcast_shapes(*shapes) + ) # Raises ValueError if dimensions mismatch + return tuple(range(out_ndim))[::-1] + + @cached_property + def _info(self): + if self.operand("dtype") is not None: + need_enforce_dtype = True + dtype = self.operand("dtype") + else: + # We follow NumPy's rules for dtype promotion, which special cases + # scalars and 0d ndarrays (which it considers equivalent) by using + # their values to compute the result dtype: + # https://github.com/numpy/numpy/issues/6240 + # We don't inspect the values of 0d dask arrays, because these could + # hold potentially very expensive calculations. Instead, we treat + # them just like other arrays, and if necessary cast the result of op + # to match. + vals = [ + np.empty((1,) * max(1, a.ndim), dtype=a.dtype) + if not is_scalar_for_elemwise(a) + else a + for a in self.elemwise_args + ] + try: + dtype = apply_infer_dtype( + self.op, vals, {}, "elemwise", suggest_dtype=False + ) + except Exception: + return NotImplemented + need_enforce_dtype = any( + not is_scalar_for_elemwise(a) and a.ndim == 0 + for a in self.elemwise_args + ) + + # TODO: add back + # if where is not True: + # blockwise_kwargs["elemwise_where_function"] = op + # op = _elemwise_handle_where + # args.extend([where, out]) + + if need_enforce_dtype: + blockwise_kwargs = { + "enforce_dtype": dtype, + "enforce_dtype_function": self.op, + } + op = _enforce_dtype + else: + blockwise_kwargs = {} + op = self.op + + return op, dtype, blockwise_kwargs + + @property + def func(self): + return self._info[0] + + @property + def dtype(self): + return self._info[1] + + @property + def kwargs(self): + return self._info[2] + + @property + def token(self): + return funcname(self.op).strip("_") + + @property + def args(self): + # for Blockwise rather than Elemwise + return tuple( + toolz.concat( + ( + a, + tuple(range(a.ndim)[::-1]) + if not is_scalar_for_elemwise(a) + else None, + ) + for a in self.elemwise_args + ) + ) def elemwise(op, *args, out=None, where=True, dtype=None, name=None, **kwargs): @@ -413,79 +525,7 @@ def elemwise(op, *args, out=None, where=True, dtype=None, name=None, **kwargs): args = [np.asarray(a) if isinstance(a, (list, tuple)) else a for a in args] - shapes = [] - for arg in args: - shape = getattr(arg, "shape", ()) - if any(is_dask_collection(x) for x in shape): - # Want to exclude Delayed shapes and dd.Scalar - shape = () - shapes.append(shape) - if isinstance(where, Array): - shapes.append(where.shape) - if isinstance(out, Array): - shapes.append(out.shape) - - shapes = [s if isinstance(s, Iterable) else () for 
s in shapes] - out_ndim = len( - broadcast_shapes(*shapes) - ) # Raises ValueError if dimensions mismatch - expr_inds = tuple(range(out_ndim))[::-1] - - if dtype is not None: - need_enforce_dtype = True - else: - # We follow NumPy's rules for dtype promotion, which special cases - # scalars and 0d ndarrays (which it considers equivalent) by using - # their values to compute the result dtype: - # https://github.com/numpy/numpy/issues/6240 - # We don't inspect the values of 0d dask arrays, because these could - # hold potentially very expensive calculations. Instead, we treat - # them just like other arrays, and if necessary cast the result of op - # to match. - vals = [ - np.empty((1,) * max(1, a.ndim), dtype=a.dtype) - if not is_scalar_for_elemwise(a) - else a - for a in args - ] - try: - dtype = apply_infer_dtype(op, vals, {}, "elemwise", suggest_dtype=False) - except Exception: - return NotImplemented - need_enforce_dtype = any( - not is_scalar_for_elemwise(a) and a.ndim == 0 for a in args - ) - - # if not name: - # name = f"{funcname(op)}-{tokenize(op, dtype, *args, where)}" - - blockwise_kwargs = dict(dtype=dtype, token=funcname(op).strip("_")) - - # TODO: add back - # if where is not True: - # blockwise_kwargs["elemwise_where_function"] = op - # op = _elemwise_handle_where - # args.extend([where, out]) - - if need_enforce_dtype: - blockwise_kwargs["enforce_dtype"] = dtype - blockwise_kwargs["enforce_dtype_function"] = op - op = _enforce_dtype - - result = blockwise( - op, - expr_inds, - *toolz.concat( - (a, tuple(range(a.ndim)[::-1]) if not is_scalar_for_elemwise(a) else None) - for a in args - ), - cls=Elemwise, - **blockwise_kwargs, - ) - - # TODO: handle out - # return handle_out(out, result) - return result + return Elemwise(op, dtype, name, *args) def broadcast_shapes(*shapes): @@ -558,7 +598,7 @@ def is_scalar_for_elemwise(arg): ) -class Transpose(Elemwise): +class Transpose(Blockwise): _parameters = ["array", "axes"] func = staticmethod(np.transpose) align_arrays = False diff --git a/dask_expr/array/rechunk.py b/dask_expr/array/rechunk.py index 48c91bf9..87d9dd9a 100644 --- a/dask_expr/array/rechunk.py +++ b/dask_expr/array/rechunk.py @@ -130,7 +130,7 @@ def _simplify_down(self): args = [] for arg, inds in toolz.partition_all(2, self.array.args): if inds is None: - args.extend((arg, inds)) + args.append(arg) else: assert isinstance(arg, Array) if isinstance(self._chunks, tuple): @@ -143,10 +143,21 @@ def _simplify_down(self): if j in self._chunks } arg = arg.rechunk(chunks) - args.extend((arg, inds)) + args.append(arg) return Elemwise(*self.array.operands[: -len(args)], *args) + if isinstance(self.array, Transpose): + if isinstance(self._chunks, tuple): + new = tuple(self._chunks[i] for i in self.array.axes) + elif isinstance(self._chunks, dict): + new = {self.array.axes.index[k]: v for k, v in self._chunks.items()} + else: + return None + return self.array.substitute( + self.array.array, self.array.array.rechunk(new) + ) + if isinstance(self.array, IO) and "chunks" in self.array._parameters: chunks = self._chunks if isinstance(chunks, tuple): @@ -225,4 +236,4 @@ def _compute_rechunk(old_name, old_chunks, chunks, level, name): return name, chunks, {**x2, **intermediates} -from dask_expr.array.blockwise import Elemwise +from dask_expr.array.blockwise import Elemwise, Transpose diff --git a/dask_expr/array/slicing.py b/dask_expr/array/slicing.py index 1f54e718..7dfad6d1 100644 --- a/dask_expr/array/slicing.py +++ b/dask_expr/array/slicing.py @@ -42,18 +42,24 @@ def 
_simplify_down(self): ), ) - if isinstance(self.array, Elemwise) and all( - isinstance(i, (slice, list)) for i in self.index - ): # TODO: deal with change in dimensionality + if isinstance(self.array, Elemwise): index = self.index + (slice(None),) * (self.ndim - len(self.index)) args = [] for arg, ind in toolz.partition(2, self.array.args): if ind is None: - args.extend((arg, ind)) + args.append(arg) else: idx = tuple(index[self.array.out_ind.index(i)] for i in ind) - args.extend((arg[idx], ind)) + args.append(arg[idx]) return Elemwise(*self.array.operands[: -len(args)], *args) + if isinstance(self.array, Transpose): + if any(isinstance(idx, (int)) or idx is None for idx in self.index): + return None # can't handle changes in dimension + else: + index = self.index + (slice(None),) * (self.ndim - len(self.index)) + new = tuple(index[i] for i in self.array.axes) + return self.array.substitute(self.array.array, self.array.array[new]) -from dask_expr.array.blockwise import Elemwise + +from dask_expr.array.blockwise import Elemwise, Transpose diff --git a/dask_expr/array/tests/test_array.py b/dask_expr/array/tests/test_array.py index b5d8db31..6616a2fc 100644 --- a/dask_expr/array/tests/test_array.py +++ b/dask_expr/array/tests/test_array.py @@ -36,6 +36,10 @@ def test_rechunk_optimize(): assert c.optimize()._name == d.optimize()._name + assert ( + b.T.rechunk((5, 2)).optimize()._name == da.from_array(a, chunks=(2, 5)).T._name + ) + def test_rechunk_blockwise_optimize(): a = np.random.random((10, 10)) @@ -114,6 +118,8 @@ def test_slicing_optimization(): assert b[5:, 4][::2].optimize()._name == b[5::2, 4].optimize()._name assert (b + 1)[:5].optimize()._name == (b[:5] + 1)._name + assert (b + 1)[5].optimize()._name == (b[5] + 1)._name + assert b.T[5:].optimize()._name == b[:, 5:].T._name @pytest.mark.xfail(reason="Blockwise specifies too much about dimension") From 76d6849037535f4be42ba3b387d5a63b1bd43498 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 9 Dec 2023 15:49:33 -0800 Subject: [PATCH 14/19] Add basic reduction implementation --- dask_expr/array/core.py | 21 ++ dask_expr/array/reductions.py | 415 ++++++++++++++++++++++++++++ dask_expr/array/tests/test_array.py | 8 + 3 files changed, 444 insertions(+) create mode 100644 dask_expr/array/reductions.py diff --git a/dask_expr/array/core.py b/dask_expr/array/core.py index 7cda84c1..788f7375 100644 --- a/dask_expr/array/core.py +++ b/dask_expr/array/core.py @@ -140,6 +140,27 @@ def __mul__(self, other): def __rmul__(self, other): return elemwise(operator.mul, other, self) + def sum(self, axis=None, dtype=None, keepdims=False, split_every=None, out=None): + """ + Return the sum of the array elements over the given axis. + + Refer to :func:`dask.array.sum` for full documentation. 
+ + See Also + -------- + dask.array.sum : equivalent function + """ + from dask_expr.array.reductions import sum + + return sum( + self, + axis=axis, + dtype=dtype, + keepdims=keepdims, + split_every=split_every, + out=out, + ) + class IO(Array): pass diff --git a/dask_expr/array/reductions.py b/dask_expr/array/reductions.py new file mode 100644 index 00000000..bbbe6982 --- /dev/null +++ b/dask_expr/array/reductions.py @@ -0,0 +1,415 @@ +from __future__ import annotations + +import builtins +import math +from functools import partial +from itertools import product +from numbers import Integral + +import numpy as np +from dask import config +from dask.array import chunk +from dask.array.core import _concatenate2, asanyarray, broadcast_to +from dask.array.dispatch import divide_lookup, nannumel_lookup, numel_lookup +from dask.array.utils import compute_meta, is_arraylike, validate_axis +from dask.base import tokenize +from dask.blockwise import lol_tuples +from dask.utils import ( + cached_property, + derived_from, + funcname, + getargspec, + is_series_like, +) +from tlz import compose, get, partition_all + +from dask_expr.array.core import Array + + +# TODO: it would be good to have a higher level reduction operation that +# lowered down into the partial reduces below. +# It might also make sense to wrap all of the partial reduces into a single +# TreeReduce object. They pollute the expression a bit. +def reduction( + x, + chunk, + aggregate, + axis=None, + keepdims=False, + dtype=None, + split_every=None, + combine=None, + name=None, + out=None, + concatenate=True, + output_size=1, + meta=None, + weights=None, +): + """General version of reductions + + Parameters + ---------- + x: Array + Data being reduced along one or more axes + chunk: callable(x_chunk, [weights_chunk=None], axis, keepdims) + First function to be executed when resolving the dask graph. + This function is applied in parallel to all original chunks of x. + See below for function parameters. + combine: callable(x_chunk, axis, keepdims), optional + Function used for intermediate recursive aggregation (see + split_every below). If omitted, it defaults to aggregate. + If the reduction can be performed in less than 3 steps, it will not + be invoked at all. + aggregate: callable(x_chunk, axis, keepdims) + Last function to be executed when resolving the dask graph, + producing the final output. It is always invoked, even when the reduced + Array counts a single chunk along the reduced axes. + axis: int or sequence of ints, optional + Axis or axes to aggregate upon. If omitted, aggregate along all axes. + keepdims: boolean, optional + Whether the reduction function should preserve the reduced axes, + leaving them at size ``output_size``, or remove them. + dtype: np.dtype + data type of output. This argument was previously optional, but + leaving as ``None`` will now raise an exception. + split_every: int >= 2 or dict(axis: int), optional + Determines the depth of the recursive aggregation. If set to or more + than the number of input chunks, the aggregation will be performed in + two steps, one ``chunk`` function per input chunk and a single + ``aggregate`` function at the end. If set to less than that, an + intermediate ``combine`` function will be used, so that any one + ``combine`` or ``aggregate`` function has no more than ``split_every`` + inputs. The depth of the aggregation graph will be + :math:`log_{split_every}(input chunks along reduced axes)`. 
Setting to + a low value can reduce cache size and network transfers, at the cost of + more CPU and a larger dask graph. + + Omit to let dask heuristically decide a good default. A default can + also be set globally with the ``split_every`` key in + :mod:`dask.config`. + name: str, optional + Prefix of the keys of the intermediate and output nodes. If omitted it + defaults to the function names. + out: Array, optional + Another dask array whose contents will be replaced. Omit to create a + new one. Note that, unlike in numpy, this setting gives no performance + benefits whatsoever, but can still be useful if one needs to preserve + the references to a previously existing Array. + concatenate: bool, optional + If True (the default), the outputs of the ``chunk``/``combine`` + functions are concatenated into a single np.array before being passed + to the ``combine``/``aggregate`` functions. If False, the input of + ``combine`` and ``aggregate`` will be either a list of the raw outputs + of the previous step or a single output, and the function will have to + concatenate it itself. It can be useful to set this to False if the + chunk and/or combine steps do not produce np.arrays. + output_size: int >= 1, optional + Size of the output of the ``aggregate`` function along the reduced + axes. Ignored if keepdims is False. + weights : array_like, optional + Weights to be used in the reduction of `x`. Will be + automatically broadcast to the shape of `x`, and so must have + a compatible shape. For instance, if `x` has shape ``(3, 4)`` + then acceptable shapes for `weights` are ``(3, 4)``, ``(4,)``, + ``(3, 1)``, ``(1, 1)``, ``(1)``, and ``()``. + + Returns + ------- + dask array + + **Function Parameters** + + x_chunk: numpy.ndarray + Individual input chunk. For ``chunk`` functions, it is one of the + original chunks of x. For ``combine`` and ``aggregate`` functions, it's + the concatenation of the outputs produced by the previous ``chunk`` or + ``combine`` functions. If concatenate=False, it's a list of the raw + outputs from the previous functions. + weights_chunk: numpy.ndarray, optional + Only applicable to the ``chunk`` function. Weights, with the + same shape as `x_chunk`, to be applied during the reduction of + the individual input chunk. If ``weights`` have not been + provided then the function may omit this parameter. When + `weights_chunk` is included then it must occur immediately + after the `x_chunk` parameter, and must also have a default + value for cases when ``weights`` are not provided. + axis: tuple + Normalized list of axes to reduce upon, e.g. ``(0, )`` + Scalar, negative, and None axes have been normalized away. + Note that some numpy reduction functions cannot reduce along multiple + axes at once and strictly require an int in input. Such functions have + to be wrapped to cope. + keepdims: bool + Whether the reduction function should preserve the reduced axes or + remove them. 
+ + """ + if axis is None: + axis = tuple(range(x.ndim)) + if isinstance(axis, Integral): + axis = (axis,) + axis = validate_axis(axis, x.ndim) + + if dtype is None: + raise ValueError("Must specify dtype") + if "dtype" in getargspec(chunk).args: + chunk = partial(chunk, dtype=dtype) + if "dtype" in getargspec(aggregate).args: + aggregate = partial(aggregate, dtype=dtype) + if is_series_like(x): + x = x.values + + # Map chunk across all blocks + inds = tuple(range(x.ndim)) + + args = (x, inds) + + # TODO: I'm ignoring this for now + if weights is not None: + # Broadcast weights to x and add to args + wgt = asanyarray(weights) + try: + wgt = broadcast_to(wgt, x.shape) + except ValueError: + raise ValueError( + f"Weights with shape {wgt.shape} are not broadcastable " + f"to x with shape {x.shape}" + ) + + args += (wgt, inds) + + # The dtype of `tmp` doesn't actually matter, and may be incorrect. + tmp = blockwise( + chunk, inds, *args, axis=axis, keepdims=True, token=name, dtype=dtype or float + ) + # TODO: this is going to be strange + tmp._chunks = tuple( + (output_size,) * len(c) if i in axis else c for i, c in enumerate(tmp.chunks) + ) + + if meta is None and hasattr(x, "_meta"): + try: + reduced_meta = compute_meta( + chunk, x.dtype, x._meta, axis=axis, keepdims=True, computing_meta=True + ) + except TypeError: + reduced_meta = compute_meta( + chunk, x.dtype, x._meta, axis=axis, keepdims=True + ) + except ValueError: + pass + else: + reduced_meta = None + + result = _tree_reduce( + tmp, + aggregate, + axis, + keepdims, + dtype, + split_every, + combine, + name=name, + concatenate=concatenate, + reduced_meta=reduced_meta, + ) + # TODO: forced chunks + if keepdims and output_size != 1: + result._chunks = tuple( + (output_size,) if i in axis else c for i, c in enumerate(tmp.chunks) + ) + # TODO: forced meta + if meta is not None: + result._meta = meta + return result + # return handle_out(out, result) + + +def _tree_reduce( + x, + aggregate, + axis, + keepdims, + dtype, + split_every=None, + combine=None, + name=None, + concatenate=True, + reduced_meta=None, +): + """Perform the tree reduction step of a reduction. + + Lower level, users should use ``reduction`` or ``arg_reduction`` directly. 
+ """ + # Normalize split_every + split_every = split_every or config.get("split_every", 16) + if isinstance(split_every, dict): + split_every = {k: split_every.get(k, 2) for k in axis} + elif isinstance(split_every, Integral): + n = builtins.max(int(split_every ** (1 / (len(axis) or 1))), 2) + split_every = dict.fromkeys(axis, n) + else: + raise ValueError("split_every must be a int or a dict") + + # Reduce across intermediates + depth = 1 + for i, n in enumerate(x.numblocks): + if i in split_every and split_every[i] != 1: + depth = int(builtins.max(depth, math.ceil(math.log(n, split_every[i])))) + func = partial(combine or aggregate, axis=axis, keepdims=True) + if concatenate: + func = compose(func, partial(_concatenate2, axes=sorted(axis))) + for _ in range(depth - 1): + x = PartialReduce( + x, + func, + split_every, + True, + dtype=dtype, + name=(name or funcname(combine or aggregate)) + "-partial", + reduced_meta=reduced_meta, + ) + func = partial(aggregate, axis=axis, keepdims=keepdims) + if concatenate: + func = compose(func, partial(_concatenate2, axes=sorted(axis))) + return PartialReduce( + x, + func, + split_every, + keepdims=keepdims, + dtype=dtype, + name=(name or funcname(aggregate)) + "-aggregate", + reduced_meta=reduced_meta, + ) + + +class PartialReduce(Array): + _parameters = [ + "array", + "func", + "split_every", + "keepdims", + "dtype", + "name", + "reduced_meta", + ] + _defaults = { + "keepdims": False, + "dtype": None, + "name": None, + "reduced_meta": None, + } + + @cached_property + def _name(self): + return ( + (self.operand("name") or funcname(self.func)) + + "-" + + tokenize( + self.func, self.array, self.split_every, self.keepdims, self.dtype + ) + ) + + @cached_property + def chunks(self): + chunks = [ + tuple(1 for p in partition_all(self.split_every[i], c)) + if i in self.split_every + else c + for (i, c) in enumerate(self.array.chunks) + ] + + if not self.keepdims: + out_axis = [i for i in range(self.array.ndim) if i not in self.split_every] + getter = lambda k: get(out_axis, k) + chunks = list(getter(chunks)) + + return tuple(chunks) + + def _layer(self): + x = self.array + parts = [ + list(partition_all(self.split_every.get(i, 1), range(n))) + for (i, n) in enumerate(x.numblocks) + ] + keys = product(*map(range, map(len, parts))) + if not self.keepdims: + out_axis = [i for i in range(x.ndim) if i not in self.split_every] + getter = lambda k: get(out_axis, k) + keys = map(getter, keys) + dsk = {} + for k, p in zip(keys, product(*parts)): + free = { + i: j[0] + for (i, j) in enumerate(p) + if len(j) == 1 and i not in self.split_every + } + dummy = dict(i for i in enumerate(p) if i[0] in self.split_every) + g = lol_tuples((x.name,), range(x.ndim), free, dummy) + dsk[(self._name,) + k] = (self.func, g) + + return dsk + + @property + def _meta(self): + meta = self.array._meta + if self.reduced_meta is not None: + try: + meta = self.func(self.reduced_meta, computing_meta=True) + # no meta keyword argument exists for func, and it isn't required + except TypeError: + try: + meta = self.func(self.reduced_meta) + except ValueError as e: + # min/max functions have no identity, don't apply function to meta + if "zero-size array to reduction operation" in str(e): + meta = self.reduced_meta + # when no work can be computed on the empty array (e.g., func is a ufunc) + except ValueError: + pass + + # some functions can't compute empty arrays (those for which reduced_meta + # fall into the ValueError exception) and we have to rely on reshaping + # the array according to 
len(out_chunks) + if is_arraylike(meta) and meta.ndim != len(self.chunks): + if len(self.chunks) == 0: + meta = meta.sum() + else: + meta = meta.reshape((0,) * len(self.chunks)) + + return meta + + +@derived_from(np) +def sum(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None): + if dtype is None: + dtype = getattr(np.zeros(1, dtype=a.dtype).sum(), "dtype", object) + result = reduction( + a, + chunk.sum, + chunk.sum, + axis=axis, + keepdims=keepdims, + dtype=dtype, + split_every=split_every, + out=out, + ) + return result + + +def divide(a, b, dtype=None): + key = lambda x: getattr(x, "__array_priority__", float("-inf")) + f = divide_lookup.dispatch(type(builtins.max(a, b, key=key))) + return f(a, b, dtype=dtype) + + +def numel(x, **kwargs): + return numel_lookup(x, **kwargs) + + +def nannumel(x, **kwargs): + return nannumel_lookup(x, **kwargs) + + +from dask_expr.array.blockwise import blockwise diff --git a/dask_expr/array/tests/test_array.py b/dask_expr/array/tests/test_array.py index 6616a2fc..acf23ccd 100644 --- a/dask_expr/array/tests/test_array.py +++ b/dask_expr/array/tests/test_array.py @@ -143,3 +143,11 @@ def test_xarray(): def test_random(): x = da.random.random((100, 100), chunks=(50, 50)) assert_eq(x, x) + + +def test_reductions(): + a = np.random.random((10, 20)) + b = da.from_array(a, chunks=(2, 5)) + + assert_eq(b.sum(), a.sum()) + assert_eq(b.sum(axis=1), a.sum(axis=1)) From 2bb801fa08f28d1c954659cc8efee99d91100de3 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 17 Dec 2023 18:47:20 -0800 Subject: [PATCH 15/19] Support compute and persist methods --- dask_expr/array/core.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/dask_expr/array/core.py b/dask_expr/array/core.py index 788f7375..9a3c4e88 100644 --- a/dask_expr/array/core.py +++ b/dask_expr/array/core.py @@ -25,12 +25,21 @@ def __dask_postcompute__(self): def __dask_postpersist__(self): return FromGraph, (self._meta, self.chunks, self._name) + def compute(self, **kwargs): + return DaskMethodsMixin.compute(self.simplify(), **kwargs) + + def persist(self, **kwargs): + return DaskMethodsMixin.persist(self.simplify(), **kwargs) + def __array_ufunc__(self, numpy_ufunc, method, *inputs, **kwargs): raise NotImplementedError() def __array_function__(self, *args, **kwargs): raise NotImplementedError() + def __array__(self): + return self.compute() + def __getitem__(self, index): from dask.array.slicing import normalize_index From 94d46f96ec5964072750277a57a33cefbb344eb0 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Mon, 18 Dec 2023 14:14:32 -0800 Subject: [PATCH 16/19] Add ufuncs, operators, and more reductions --- dask_expr/array/__init__.py | 12 + dask_expr/array/core.py | 249 +++++++++++++ dask_expr/array/random.py | 11 +- dask_expr/array/reductions.py | 538 +++++++++++++++++++++++++++- dask_expr/array/tests/test_array.py | 58 ++- 5 files changed, 856 insertions(+), 12 deletions(-) diff --git a/dask_expr/array/__init__.py b/dask_expr/array/__init__.py index a8f55d83..8a8e6716 100644 --- a/dask_expr/array/__init__.py +++ b/dask_expr/array/__init__.py @@ -1,2 +1,14 @@ from dask_expr.array import random from dask_expr.array.core import Array, from_array +from dask_expr.array.reductions import ( + mean, + moment, + nanmean, + nanstd, + nansum, + nanvar, + prod, + std, + sum, + var, +) diff --git a/dask_expr/array/core.py b/dask_expr/array/core.py index 9a3c4e88..e6885338 100644 --- a/dask_expr/array/core.py +++ b/dask_expr/array/core.py @@ -2,6 +2,7 @@ from typing import Union 
import dask.array as da +import numpy as np from dask.base import DaskMethodsMixin, named_schedulers from dask.utils import cached_cumsum, cached_property from toolz import reduce @@ -149,6 +150,137 @@ def __mul__(self, other): def __rmul__(self, other): return elemwise(operator.mul, other, self) + def __sub__(self, other): + return elemwise(operator.sub, self, other) + + def __rsub__(self, other): + return elemwise(operator.sub, other, self) + + def __pow__(self, other): + return elemwise(operator.pow, self, other) + + def __rpow__(self, other): + return elemwise(operator.pow, other, self) + + def __truediv__(self, other): + return elemwise(operator.truediv, self, other) + + def __rtruediv__(self, other): + return elemwise(operator.truediv, other, self) + + def __floordiv__(self, other): + return elemwise(operator.floordiv, self, other) + + def __rfloordiv__(self, other): + return elemwise(operator.floordiv, other, self) + + def __array_ufunc__(self, numpy_ufunc, method, *inputs, **kwargs): + out = kwargs.get("out", ()) + for x in inputs + out: + if da.core._should_delegate(self, x): + return NotImplemented + + if method == "__call__": + if numpy_ufunc is np.matmul: + return NotImplemented + if numpy_ufunc.signature is not None: + return NotImplemented + if numpy_ufunc.nout > 1: + return NotImplemented + else: + return elemwise(numpy_ufunc, *inputs, **kwargs) + elif method == "outer": + return NotImplemented + else: + return NotImplemented + + @cached_property + def size(self): + """Number of elements in array""" + return reduce(operator.mul, self.shape, 1) + + def any(self, axis=None, keepdims=False, split_every=None, out=None): + """Returns True if any of the elements evaluate to True. + + Refer to :func:`dask.array.any` for full documentation. + + See Also + -------- + dask.array.any : equivalent function + """ + from dask_expr.array.reductions import any + + return any(self, axis=axis, keepdims=keepdims, split_every=split_every, out=out) + + def all(self, axis=None, keepdims=False, split_every=None, out=None): + """Returns True if all elements evaluate to True. + + Refer to :func:`dask.array.all` for full documentation. + + See Also + -------- + dask.array.all : equivalent function + """ + from dask_expr.array.reductions import all + + return all(self, axis=axis, keepdims=keepdims, split_every=split_every, out=out) + + def min(self, axis=None, keepdims=False, split_every=None, out=None): + """Return the minimum along a given axis. + + Refer to :func:`dask.array.min` for full documentation. + + See Also + -------- + dask.array.min : equivalent function + """ + from dask_expr.array.reductions import min + + return min(self, axis=axis, keepdims=keepdims, split_every=split_every, out=out) + + def max(self, axis=None, keepdims=False, split_every=None, out=None): + """Return the maximum along a given axis. + + Refer to :func:`dask.array.max` for full documentation. + + See Also + -------- + dask.array.max : equivalent function + """ + from dask_expr.array.reductions import max + + return max(self, axis=axis, keepdims=keepdims, split_every=split_every, out=out) + + def argmin(self, axis=None, *, keepdims=False, split_every=None, out=None): + """Return indices of the minimum values along the given axis. + + Refer to :func:`dask.array.argmin` for full documentation. 
+ + See Also + -------- + dask.array.argmin : equivalent function + """ + from dask_expr.array.reductions import argmin + + return argmin( + self, axis=axis, keepdims=keepdims, split_every=split_every, out=out + ) + + def argmax(self, axis=None, *, keepdims=False, split_every=None, out=None): + """Return indices of the maximum values along the given axis. + + Refer to :func:`dask.array.argmax` for full documentation. + + See Also + -------- + dask.array.argmax : equivalent function + """ + from dask_expr.array.reductions import argmax + + return argmax( + self, axis=axis, keepdims=keepdims, split_every=split_every, out=out + ) + def sum(self, axis=None, dtype=None, keepdims=False, split_every=None, out=None): """ Return the sum of the array elements over the given axis. @@ -170,6 +302,123 @@ def sum(self, axis=None, dtype=None, keepdims=False, split_every=None, out=None) out=out, ) + def mean(self, axis=None, dtype=None, keepdims=False, split_every=None, out=None): + """Returns the average of the array elements along given axis. + + Refer to :func:`dask.array.mean` for full documentation. + + See Also + -------- + dask.array.mean : equivalent function + """ + from dask_expr.array.reductions import mean + + return mean( + self, + axis=axis, + dtype=dtype, + keepdims=keepdims, + split_every=split_every, + out=out, + ) + + def std( + self, axis=None, dtype=None, keepdims=False, ddof=0, split_every=None, out=None + ): + """Returns the standard deviation of the array elements along given axis. + + Refer to :func:`dask.array.std` for full documentation. + + See Also + -------- + dask.array.std : equivalent function + """ + from dask_expr.array.reductions import std + + return std( + self, + axis=axis, + dtype=dtype, + keepdims=keepdims, + ddof=ddof, + split_every=split_every, + out=out, + ) + + def var( + self, axis=None, dtype=None, keepdims=False, ddof=0, split_every=None, out=None + ): + """Returns the variance of the array elements, along given axis. + + Refer to :func:`dask.array.var` for full documentation. + + See Also + -------- + dask.array.var : equivalent function + """ + from dask_expr.array.reductions import var + + return var( + self, + axis=axis, + dtype=dtype, + keepdims=keepdims, + ddof=ddof, + split_every=split_every, + out=out, + ) + + def moment( + self, + order, + axis=None, + dtype=None, + keepdims=False, + ddof=0, + split_every=None, + out=None, + ): + """Calculate the nth centralized moment. + + Refer to :func:`dask.array.moment` for the full documentation. + + See Also + -------- + dask.array.moment : equivalent function + """ + from dask_expr.array.reductions import moment + + return moment( + self, + order, + axis=axis, + dtype=dtype, + keepdims=keepdims, + ddof=ddof, + split_every=split_every, + out=out, + ) + + def prod(self, axis=None, dtype=None, keepdims=False, split_every=None, out=None): + """Return the product of the array elements over the given axis + + Refer to :func:`dask.array.prod` for full documentation. 
+ + See Also + -------- + dask.array.prod : equivalent function + """ + from dask_expr.array.reductions import prod + + return prod( + self, + axis=axis, + dtype=dtype, + keepdims=keepdims, + split_every=split_every, + out=out, + ) + class IO(Array): pass diff --git a/dask_expr/array/random.py b/dask_expr/array/random.py index 6eff3d0a..de307a6d 100644 --- a/dask_expr/array/random.py +++ b/dask_expr/array/random.py @@ -876,8 +876,6 @@ def _choice_validate_params(state, a, size, replace, p, axis, chunks): if size is None: size = () - elif not isinstance(size, (tuple, list)): - size = (size,) if axis != 0: raise ValueError("axis must be 0 since a is one dimensinal") @@ -897,6 +895,8 @@ def _choice_validate_params(state, a, size, replace, p, axis, chunks): def _wrap_func( rng, funcname, *args, size=None, chunks="auto", extra_chunks=(), **kwargs ): + if size is not None and not isinstance(size, (tuple, list)): + size = (size,) return Random(rng, funcname, size, chunks, extra_chunks, args, kwargs) @@ -914,12 +914,9 @@ class Random(IO): @property def chunks(self): - size = self.size + size = self.operand("size") chunks = self.operand("chunks") - if size is not None and not isinstance(size, (tuple, list)): - size = (size,) - # shapes = list( # { # ar.shape @@ -1001,7 +998,7 @@ def _meta(self): gen, self.distribution, bitgens[0], # TODO: not sure about this - (0,) * len(self.size), + (0,) * len(self.operand("size")), self.args, self.kwargs, # small_args, diff --git a/dask_expr/array/reductions.py b/dask_expr/array/reductions.py index bbbe6982..2ccbeb31 100644 --- a/dask_expr/array/reductions.py +++ b/dask_expr/array/reductions.py @@ -4,18 +4,20 @@ import math from functools import partial from itertools import product -from numbers import Integral +from numbers import Integral, Number import numpy as np from dask import config from dask.array import chunk -from dask.array.core import _concatenate2, asanyarray, broadcast_to +from dask.array.core import _concatenate2, asanyarray, broadcast_to, implements from dask.array.dispatch import divide_lookup, nannumel_lookup, numel_lookup +from dask.array.reductions import array_safe from dask.array.utils import compute_meta, is_arraylike, validate_axis from dask.base import tokenize from dask.blockwise import lol_tuples from dask.utils import ( cached_property, + deepmap, derived_from, funcname, getargspec, @@ -398,6 +400,118 @@ def sum(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None): return result +@derived_from(np) +def prod(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None): + if dtype is not None: + dt = dtype + else: + dt = getattr(np.ones((1,), dtype=a.dtype).prod(), "dtype", object) + return reduction( + a, + chunk.prod, + chunk.prod, + axis=axis, + keepdims=keepdims, + dtype=dt, + split_every=split_every, + out=out, + ) + + +@implements(np.min, np.amin) +@derived_from(np) +def min(a, axis=None, keepdims=False, split_every=None, out=None): + return reduction( + a, + chunk_min, + chunk.min, + combine=chunk_min, + axis=axis, + keepdims=keepdims, + dtype=a.dtype, + split_every=split_every, + out=out, + ) + + +def chunk_min(x, axis=None, keepdims=None): + """Version of np.min which ignores size 0 arrays""" + if x.size == 0: + return array_safe([], x, ndmin=x.ndim, dtype=x.dtype) + else: + return np.min(x, axis=axis, keepdims=keepdims) + + +@implements(np.max, np.amax) +@derived_from(np) +def max(a, axis=None, keepdims=False, split_every=None, out=None): + return reduction( + a, + chunk_max, + chunk.max, + 
combine=chunk_max, + axis=axis, + keepdims=keepdims, + dtype=a.dtype, + split_every=split_every, + out=out, + ) + + +def chunk_max(x, axis=None, keepdims=None): + """Version of np.max which ignores size 0 arrays""" + if x.size == 0: + return array_safe([], x, ndmin=x.ndim, dtype=x.dtype) + else: + return np.max(x, axis=axis, keepdims=keepdims) + + +@derived_from(np) +def any(a, axis=None, keepdims=False, split_every=None, out=None): + return reduction( + a, + chunk.any, + chunk.any, + axis=axis, + keepdims=keepdims, + dtype="bool", + split_every=split_every, + out=out, + ) + + +@derived_from(np) +def all(a, axis=None, keepdims=False, split_every=None, out=None): + return reduction( + a, + chunk.all, + chunk.all, + axis=axis, + keepdims=keepdims, + dtype="bool", + split_every=split_every, + out=out, + ) + + +@derived_from(np) +def nansum(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None): + if dtype is not None: + dt = dtype + else: + dt = getattr(chunk.nansum(np.ones((1,), dtype=a.dtype)), "dtype", object) + return reduction( + a, + chunk.nansum, + chunk.sum, + axis=axis, + keepdims=keepdims, + dtype=dt, + split_every=split_every, + out=out, + ) + + def divide(a, b, dtype=None): key = lambda x: getattr(x, "__array_priority__", float("-inf")) f = divide_lookup.dispatch(type(builtins.max(a, b, key=key))) @@ -412,4 +526,424 @@ def nannumel(x, **kwargs): return nannumel_lookup(x, **kwargs) +def mean_chunk( + x, sum=chunk.sum, numel=numel, dtype="f8", computing_meta=False, **kwargs +): + if computing_meta: + return x + n = numel(x, dtype=dtype, **kwargs) + + total = sum(x, dtype=dtype, **kwargs) + + return {"n": n, "total": total} + + +def mean_combine( + pairs, + sum=chunk.sum, + numel=numel, + dtype="f8", + axis=None, + computing_meta=False, + **kwargs, +): + if not isinstance(pairs, list): + pairs = [pairs] + + ns = deepmap(lambda pair: pair["n"], pairs) if not computing_meta else pairs + n = _concatenate2(ns, axes=axis).sum(axis=axis, **kwargs) + + if computing_meta: + return n + + totals = deepmap(lambda pair: pair["total"], pairs) + total = _concatenate2(totals, axes=axis).sum(axis=axis, **kwargs) + + return {"n": n, "total": total} + + +def mean_agg(pairs, dtype="f8", axis=None, computing_meta=False, **kwargs): + ns = deepmap(lambda pair: pair["n"], pairs) if not computing_meta else pairs + n = _concatenate2(ns, axes=axis) + n = np.sum(n, axis=axis, dtype=dtype, **kwargs) + + if computing_meta: + return n + + totals = deepmap(lambda pair: pair["total"], pairs) + total = _concatenate2(totals, axes=axis).sum(axis=axis, dtype=dtype, **kwargs) + + with np.errstate(divide="ignore", invalid="ignore"): + return divide(total, n, dtype=dtype) + + +@derived_from(np) +def mean(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None): + if dtype is not None: + dt = dtype + elif a.dtype == object: + dt = object + else: + dt = getattr(np.mean(np.zeros(shape=(1,), dtype=a.dtype)), "dtype", object) + return reduction( + a, + mean_chunk, + mean_agg, + axis=axis, + keepdims=keepdims, + dtype=dt, + split_every=split_every, + combine=mean_combine, + out=out, + concatenate=False, + ) + + +@derived_from(np) +def nanmean(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None): + if dtype is not None: + dt = dtype + else: + dt = getattr(np.mean(np.ones(shape=(1,), dtype=a.dtype)), "dtype", object) + return reduction( + a, + partial(mean_chunk, sum=chunk.nansum, numel=nannumel), + mean_agg, + axis=axis, + keepdims=keepdims, + dtype=dt, + split_every=split_every, + 
out=out, + concatenate=False, + combine=partial(mean_combine, sum=chunk.nansum, numel=nannumel), + ) + + +def moment_chunk( + A, + order=2, + sum=chunk.sum, + numel=numel, + dtype="f8", + computing_meta=False, + implicit_complex_dtype=False, + **kwargs, +): + if computing_meta: + return A + n = numel(A, **kwargs) + + n = n.astype(np.int64) + if implicit_complex_dtype: + total = sum(A, **kwargs) + else: + total = sum(A, dtype=dtype, **kwargs) + + with np.errstate(divide="ignore", invalid="ignore"): + u = total / n + d = A - u + if np.issubdtype(A.dtype, np.complexfloating): + d = np.abs(d) + xs = [sum(d**i, dtype=dtype, **kwargs) for i in range(2, order + 1)] + M = np.stack(xs, axis=-1) + return {"total": total, "n": n, "M": M} + + +def _moment_helper(Ms, ns, inner_term, order, sum, axis, kwargs): + M = Ms[..., order - 2].sum(axis=axis, **kwargs) + sum( + ns * inner_term**order, axis=axis, **kwargs + ) + for k in range(1, order - 1): + coeff = math.factorial(order) / (math.factorial(k) * math.factorial(order - k)) + M += coeff * sum(Ms[..., order - k - 2] * inner_term**k, axis=axis, **kwargs) + return M + + +def moment_combine( + pairs, + order=2, + ddof=0, + dtype="f8", + sum=np.sum, + axis=None, + computing_meta=False, + **kwargs, +): + if not isinstance(pairs, list): + pairs = [pairs] + + kwargs["dtype"] = None + kwargs["keepdims"] = True + + ns = deepmap(lambda pair: pair["n"], pairs) if not computing_meta else pairs + ns = _concatenate2(ns, axes=axis) + n = ns.sum(axis=axis, **kwargs) + + if computing_meta: + return n + + totals = _concatenate2(deepmap(lambda pair: pair["total"], pairs), axes=axis) + Ms = _concatenate2(deepmap(lambda pair: pair["M"], pairs), axes=axis) + + total = totals.sum(axis=axis, **kwargs) + + with np.errstate(divide="ignore", invalid="ignore"): + if np.issubdtype(total.dtype, np.complexfloating): + mu = divide(total, n) + inner_term = np.abs(divide(totals, ns) - mu) + else: + mu = divide(total, n, dtype=dtype) + inner_term = divide(totals, ns, dtype=dtype) - mu + + xs = [ + _moment_helper(Ms, ns, inner_term, o, sum, axis, kwargs) + for o in range(2, order + 1) + ] + M = np.stack(xs, axis=-1) + return {"total": total, "n": n, "M": M} + + +def moment_agg( + pairs, + order=2, + ddof=0, + dtype="f8", + sum=np.sum, + axis=None, + computing_meta=False, + **kwargs, +): + if not isinstance(pairs, list): + pairs = [pairs] + + kwargs["dtype"] = dtype + # To properly handle ndarrays, the original dimensions need to be kept for + # part of the calculation. 
+ keepdim_kw = kwargs.copy() + keepdim_kw["keepdims"] = True + keepdim_kw["dtype"] = None + + ns = deepmap(lambda pair: pair["n"], pairs) if not computing_meta else pairs + ns = _concatenate2(ns, axes=axis) + n = ns.sum(axis=axis, **keepdim_kw) + + if computing_meta: + return n + + totals = _concatenate2(deepmap(lambda pair: pair["total"], pairs), axes=axis) + Ms = _concatenate2(deepmap(lambda pair: pair["M"], pairs), axes=axis) + + mu = divide(totals.sum(axis=axis, **keepdim_kw), n) + + with np.errstate(divide="ignore", invalid="ignore"): + if np.issubdtype(totals.dtype, np.complexfloating): + inner_term = np.abs(divide(totals, ns) - mu) + else: + inner_term = divide(totals, ns, dtype=dtype) - mu + + M = _moment_helper(Ms, ns, inner_term, order, sum, axis, kwargs) + + denominator = n.sum(axis=axis, **kwargs) - ddof + + # taking care of the edge case with empty or all-nans array with ddof > 0 + if isinstance(denominator, Number): + if denominator < 0: + denominator = np.nan + elif denominator is not np.ma.masked: + denominator[denominator < 0] = np.nan + + return divide(M, denominator, dtype=dtype) + + +def moment( + a, order, axis=None, dtype=None, keepdims=False, ddof=0, split_every=None, out=None +): + """Calculate the nth centralized moment. + + Parameters + ---------- + a : Array + Data over which to compute moment + order : int + Order of the moment that is returned, must be >= 2. + axis : int, optional + Axis along which the central moment is computed. The default is to + compute the moment of the flattened array. + dtype : data-type, optional + Type to use in computing the moment. For arrays of integer type the + default is float64; for arrays of float types it is the same as the + array type. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left in the + result as dimensions with size one. With this option, the result + will broadcast correctly against the original array. + ddof : int, optional + "Delta Degrees of Freedom": the divisor used in the calculation is + N - ddof, where N represents the number of elements. By default + ddof is zero. + + Returns + ------- + moment : Array + + References + ---------- + .. [1] Pebay, Philippe (2008), "Formulas for Robust, One-Pass Parallel + Computation of Covariances and Arbitrary-Order Statistical Moments", + Technical Report SAND2008-6212, Sandia National Laboratories. + + """ + if not isinstance(order, Integral) or order < 0: + raise ValueError("Order must be an integer >= 0") + + if order < 2: + # reduced = a.sum(axis=axis) # get reduced shape and chunks + if order == 0: + raise NotImplementedError("need to implement ones") # TODO + # When order equals 0, the result is 1, by definition. + # return ones( + # reduced.shape, chunks=reduced.chunks, dtype="f8", meta=reduced._meta + # ) + # By definition the first order about the mean is 0. 
+ raise NotImplementedError("need to implement zeros") # TODO + # return zeros( + # reduced.shape, chunks=reduced.chunks, dtype="f8", meta=reduced._meta + # ) + + if dtype is not None: + dt = dtype + else: + dt = getattr(np.var(np.ones(shape=(1,), dtype=a.dtype)), "dtype", object) + + implicit_complex_dtype = dtype is None and np.iscomplexobj(a) + + return reduction( + a, + partial( + moment_chunk, order=order, implicit_complex_dtype=implicit_complex_dtype + ), + partial(moment_agg, order=order, ddof=ddof), + axis=axis, + keepdims=keepdims, + dtype=dt, + split_every=split_every, + out=out, + concatenate=False, + combine=partial(moment_combine, order=order), + ) + + +@derived_from(np) +def var(a, axis=None, dtype=None, keepdims=False, ddof=0, split_every=None, out=None): + if dtype is not None: + dt = dtype + else: + dt = getattr(np.var(np.ones(shape=(1,), dtype=a.dtype)), "dtype", object) + + implicit_complex_dtype = dtype is None and np.iscomplexobj(a._meta) + + return reduction( + a, + partial(moment_chunk, implicit_complex_dtype=implicit_complex_dtype), + partial(moment_agg, ddof=ddof), + axis=axis, + keepdims=keepdims, + dtype=dt, + split_every=split_every, + combine=moment_combine, + name="var", + out=out, + concatenate=False, + ) + + +@derived_from(np) +def nanvar( + a, axis=None, dtype=None, keepdims=False, ddof=0, split_every=None, out=None +): + if dtype is not None: + dt = dtype + else: + dt = getattr(np.var(np.ones(shape=(1,), dtype=a.dtype)), "dtype", object) + + implicit_complex_dtype = dtype is None and np.iscomplexobj(a) + + return reduction( + a, + partial( + moment_chunk, + sum=chunk.nansum, + numel=nannumel, + implicit_complex_dtype=implicit_complex_dtype, + ), + partial(moment_agg, sum=np.nansum, ddof=ddof), + axis=axis, + keepdims=keepdims, + dtype=dt, + split_every=split_every, + combine=partial(moment_combine, sum=np.nansum), + out=out, + concatenate=False, + ) + + +def _sqrt(a): + o = np.sqrt(a) + if isinstance(o, np.ma.masked_array) and not o.shape and o.mask.all(): + return np.ma.masked + return o + + +def safe_sqrt(a): + """A version of sqrt that properly handles scalar masked arrays. + + To mimic ``np.ma`` reductions, we need to convert scalar masked arrays that + have an active mask to the ``np.ma.masked`` singleton. This is properly + handled automatically for reduction code, but not for ufuncs. We implement + a simple version here, since calling `np.ma.sqrt` everywhere is + significantly more expensive. 
+ """ + if hasattr(a, "_elemwise"): + return a._elemwise(_sqrt, a) + return _sqrt(a) + + +@derived_from(np) +def std(a, axis=None, dtype=None, keepdims=False, ddof=0, split_every=None, out=None): + result = safe_sqrt( + var( + a, + axis=axis, + dtype=dtype, + keepdims=keepdims, + ddof=ddof, + split_every=split_every, + out=out, + ) + ) + if dtype and dtype != result.dtype: + result = result.astype(dtype) + return result + + +@derived_from(np) +def nanstd( + a, axis=None, dtype=None, keepdims=False, ddof=0, split_every=None, out=None +): + result = safe_sqrt( + nanvar( + a, + axis=axis, + dtype=dtype, + keepdims=keepdims, + ddof=ddof, + split_every=split_every, + out=out, + ) + ) + if dtype and dtype != result.dtype: + result = result.astype(dtype) + return result + + from dask_expr.array.blockwise import blockwise diff --git a/dask_expr/array/tests/test_array.py b/dask_expr/array/tests/test_array.py index acf23ccd..39bbe378 100644 --- a/dask_expr/array/tests/test_array.py +++ b/dask_expr/array/tests/test_array.py @@ -1,3 +1,5 @@ +import operator + import numpy as np import pytest from dask.array.utils import assert_eq @@ -145,9 +147,59 @@ def test_random(): assert_eq(x, x) -def test_reductions(): +@pytest.mark.parametrize( + "reduction", + ["sum", "mean", "var", "std", "any", "all", "prod", "min", "max"], +) +def test_reductions(reduction): a = np.random.random((10, 20)) b = da.from_array(a, chunks=(2, 5)) - assert_eq(b.sum(), a.sum()) - assert_eq(b.sum(axis=1), a.sum(axis=1)) + def func(x, **kwargs): + return getattr(x, reduction)(**kwargs) + + assert_eq(func(a), func(b)) + assert_eq(func(a, axis=1), func(b, axis=1)) + + +@pytest.mark.parametrize( + "reduction", + ["nanmean", "nansum"], +) +def test_reduction_functions(reduction): + a = np.random.random((10, 20)) + b = da.from_array(a, chunks=(2, 5)) + + def func(x, **kwargs): + if isinstance(x, np.ndarray): + return getattr(np, reduction)(x, **kwargs) + else: + return getattr(da, reduction)(x, **kwargs) + + func(b).chunks + + assert_eq(func(a), func(b)) + assert_eq(func(a, axis=1), func(b, axis=1)) + + +@pytest.mark.parametrize( + "ufunc", + [np.sqrt, np.sin, np.exp], +) +def test_ufunc(ufunc): + a = np.random.random((10, 20)) + b = da.from_array(a, chunks=(2, 5)) + assert_eq(ufunc(a), ufunc(b)) + + +@pytest.mark.parametrize( + "op", + [operator.add, operator.sub, operator.pow, operator.floordiv, operator.truediv], +) +def test_binop(op): + a = np.random.random((10, 20)) + b = np.random.random(20) + aa = da.from_array(a, chunks=(2, 5)) + bb = da.from_array(b, chunks=5) + + assert_eq(op(a, b), op(aa, bb)) From c386a2df4c47dd5f09e1771060b275e42ae0bd7b Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 18 Jan 2024 08:04:58 -0600 Subject: [PATCH 17/19] Avoid objects in pprint --- dask_expr/_core.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index adeb82cd..e4cbe18f 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -94,10 +94,11 @@ def _tree_repr_lines(self, indent=0, recursive=True): op = "" if repr(op) != repr(default): - if param: - header += f" {param}={repr(op)}" - else: - header += repr(op) + if "object at 0x" not in repr(op): + if param: + header += f" {param}={repr(op)}" + else: + header += repr(op) lines = [header] + lines lines = [" " * indent + line for line in lines] From 88b9d6c77f910fb939e0a6cea9a5da014e97444f Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 21 Jun 2024 15:40:06 +0200 Subject: [PATCH 18/19] rebase stuff --- 
dask_expr/array/core.py | 73 ++++++++++++++++++++++++++++++++--------- 1 file changed, 57 insertions(+), 16 deletions(-) diff --git a/dask_expr/array/core.py b/dask_expr/array/core.py index e6885338..de989c8d 100644 --- a/dask_expr/array/core.py +++ b/dask_expr/array/core.py @@ -1,13 +1,19 @@ +import functools import operator +from itertools import product from typing import Union import dask.array as da import numpy as np +from dask import istask +from dask.array.core import slices_from_chunks from dask.base import DaskMethodsMixin, named_schedulers -from dask.utils import cached_cumsum, cached_property +from dask.core import flatten +from dask.utils import SerializableLock, cached_cumsum, cached_property, key_split from toolz import reduce from dask_expr import _core as core +from dask_expr._util import _tokenize_deterministic T_IntOrNaN = Union[int, float] # Should be Union[int, Literal[np.nan]] @@ -24,7 +30,13 @@ def __dask_postcompute__(self): return da.core.finalize, () def __dask_postpersist__(self): - return FromGraph, (self._meta, self.chunks, self._name) + state = self.lower_completely() + return FromGraph, ( + state._meta, + state.chunks, + list(flatten(state.__dask_keys__())), + key_split(state._name), + ) def compute(self, **kwargs): return DaskMethodsMixin.compute(self.simplify(), **kwargs) @@ -425,7 +437,7 @@ class IO(Array): class FromArray(IO): - _parameters = ["array", "chunks"] + _parameters = ["array", "chunks", "lock"] @property def chunks(self): @@ -438,9 +450,28 @@ def _meta(self): return self.array[tuple(slice(0, 0) for _ in range(self.array.ndim))] def _layer(self): - dsk = da.core.graph_from_arraylike( - self.array, chunks=self.chunks, shape=self.array.shape, name=self._name - ) + lock = self.operand("lock") + if lock is True: + lock = SerializableLock() + + is_ndarray = type(self.array) in (np.ndarray, np.ma.core.MaskedArray) + is_single_block = all(len(c) == 1 for c in self.chunks) + # Always use the getter for h5py etc. Not using isinstance(x, np.ndarray) + # because np.matrix is a subclass of np.ndarray. 
+ if is_ndarray and not is_single_block and not lock: + # eagerly slice numpy arrays to prevent memory blowup + # GH5367, GH5601 + slices = slices_from_chunks(self.chunks) + keys = product([self._name], *(range(len(bds)) for bds in self.chunks)) + values = [self.array[slc] for slc in slices] + dsk = dict(zip(keys, values)) + elif is_ndarray and is_single_block: + # No slicing needed + dsk = {(self._name,) + (0,) * self.array.ndim: self.array} + else: + dsk = da.core.graph_from_arraylike( + self.array, chunks=self.chunks, shape=self.array.shape, name=self._name + ) return dict(dsk) # this comes as a legacy HLG for now def __str__(self): @@ -448,26 +479,36 @@ def __str__(self): class FromGraph(Array): - _parameters = ["layer", "_meta", "chunks", "_name"] + _parameters = ["layer", "_meta", "chunks", "keys", "name_prefix"] @property def _meta(self): return self.operand("_meta") - @property - def chunks(self): - return self.operand("chunks") - - @property + @functools.cached_property def _name(self): - return self.operand("_name") + return ( + self.operand("name_prefix") + "-" + _tokenize_deterministic(*self.operands) + ) def _layer(self): - return dict(self.operand("layer")) + dsk = dict(self.operand("layer")) + # The name may not actually match the layers name therefore rewrite this + # using an alias + for k in self.operand("keys"): + if not isinstance(k, tuple): + raise TypeError(f"Expected tuple, got {type(k)}") + orig = dsk[k] + if not istask(orig): + del dsk[k] + dsk[(self._name, *k[1:])] = orig + else: + dsk[(self._name, *k[1:])] = k + return dsk -def from_array(x, chunks="auto"): - return FromArray(x, chunks) +def from_array(x, chunks="auto", lock=None): + return FromArray(x, chunks, lock=lock) from dask_expr.array.blockwise import Transpose, elemwise From 42d1863475fb4810a57dc2e359e823ee34d8a2ad Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Mon, 24 Jun 2024 22:43:59 +0300 Subject: [PATCH 19/19] Fix remaining tests --- ci/environment.yml | 1 + dask_expr/array/blockwise.py | 17 ++++++++++++----- dask_expr/array/tests/__init__.py | 0 dask_expr/array/tests/test_array.py | 1 - 4 files changed, 13 insertions(+), 6 deletions(-) create mode 100644 dask_expr/array/tests/__init__.py diff --git a/ci/environment.yml b/ci/environment.yml index c4923e84..5223ea98 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -9,6 +9,7 @@ dependencies: - pyarrow>=7 - pandas>=2 - pre-commit + - xarray - pip: - git+https://github.com/dask/distributed - git+https://github.com/dask/dask diff --git a/dask_expr/array/blockwise.py b/dask_expr/array/blockwise.py index 42eb740b..f704a792 100644 --- a/dask_expr/array/blockwise.py +++ b/dask_expr/array/blockwise.py @@ -1,3 +1,4 @@ +import functools import itertools import numbers from collections.abc import Iterable @@ -45,18 +46,24 @@ class Blockwise(Array): "kwargs": None, } - @property + @functools.cached_property def args(self): return self.operands[len(self._parameters) :] - @property + @functools.cached_property + def _meta_provided(self): + # We catch recursion errors if key starts with _meta, so define + # explicitly here + return self.operand("_meta_provided") + + @functools.cached_property def _meta(self): if self._meta_provided is not None: return self._meta_provided else: return compute_meta(self.func, self.dtype, *self.args[::2], **self.kwargs) - @property + @functools.cached_property def chunks(self): if self.align_arrays: chunkss, arrays = unify_chunks(*self.args) @@ -101,11 +108,11 @@ def 
chunks(self): chunks = tuple(chunks) return chunks - @property + @functools.cached_property def dtype(self): return self.operand("dtype") - @property + @functools.cached_property def _name(self): if "name" in self._parameters and self.operand("name"): return self.operand("name") diff --git a/dask_expr/array/tests/__init__.py b/dask_expr/array/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dask_expr/array/tests/test_array.py b/dask_expr/array/tests/test_array.py index 39bbe378..46a28400 100644 --- a/dask_expr/array/tests/test_array.py +++ b/dask_expr/array/tests/test_array.py @@ -124,7 +124,6 @@ def test_slicing_optimization(): assert b.T[5:].optimize()._name == b[:, 5:].T._name -@pytest.mark.xfail(reason="Blockwise specifies too much about dimension") def test_slicing_optimization_change_dimensionality(): a = np.random.random((10, 20)) b = da.from_array(a, chunks=(2, 5))
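
Taken together, patches 12 through 19 give dask_expr.array enough surface to exercise end to end: ``from_array``, operators and NumPy ufuncs via ``__array_ufunc__``, slicing, transpose and rechunk optimizations, reductions, random generation, and ``compute``. The sketch below is illustrative only — it is not part of any patch above — and it assumes ``dask_expr.array`` is importable as ``da``, mirroring the test suite:

    import numpy as np
    from dask.array.utils import assert_eq

    from dask_expr import array as da  # assumed import path, as in the tests

    a = np.random.random((10, 20))
    x = da.from_array(a, chunks=(2, 5))

    # Operators, ufuncs, slicing, and reductions all build a lazy expression tree.
    y = (np.sqrt(x) + 1)[:5].sum(axis=1)

    # optimize() applies the simplifications added above (for example, pushing
    # the slice through the elemwise operations) before the graph is built.
    y = y.optimize()

    assert_eq(y, (np.sqrt(a) + 1)[:5].sum(axis=1))

    # Random generation is also expression-backed, and compute() (patch 15)
    # simplifies before handing the graph to the scheduler.
    r = da.random.random((100, 100), chunks=(50, 50))
    assert r.std().compute() < 1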