From 1af452302afdc8558773b503f621c08859565539 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Sun, 24 Nov 2024 15:34:11 -0600 Subject: [PATCH] Hack dataclass'd function_interface to avoid breaking Firedrake --- loopy/kernel/function_interface.py | 56 +++++++++++++++++++++++++++--- loopy/library/random123.py | 3 +- 2 files changed, 53 insertions(+), 6 deletions(-) diff --git a/loopy/kernel/function_interface.py b/loopy/kernel/function_interface.py index 33dfd73f2..40d9969bf 100644 --- a/loopy/kernel/function_interface.py +++ b/loopy/kernel/function_interface.py @@ -23,8 +23,8 @@ THE SOFTWARE. """ from abc import ABC, abstractmethod -from collections.abc import Mapping, Sequence -from dataclasses import dataclass, replace +from collections.abc import Collection, Mapping, Sequence +from dataclasses import dataclass, fields, replace from typing import TYPE_CHECKING, Any, Callable, FrozenSet, TypeVar from warnings import warn @@ -304,7 +304,9 @@ def get_kw_pos_association(kernel): # {{{ template class -@dataclass(frozen=True, init=False) +# not frozen for Firedrake compatibility +# not eq to avoid having __hash__ set to None in subclasses +@dataclass(init=False, eq=False) class InKernelCallable(ABC): """ An abstract interface to define a callable encountered in a kernel. @@ -368,9 +370,51 @@ def __init__(self, def name(self) -> str: raise NotImplementedError() + # {{{ hackery to avoid breaking Firedrake + + def _all_attrs(self) -> Collection[str]: + dc_attrs = { + fld.name for fld in fields(self) + } + legacy_fields: Collection[str] = getattr(self, "fields", []) + return dc_attrs | set(legacy_fields) + def copy(self, **kwargs: Any) -> Self: + present_kwargs = { + name: getattr(self, name) + for name in self._all_attrs() + } + kwargs = { + **present_kwargs, + **kwargs, + } + return replace(self, **kwargs) + def update_persistent_hash(self, key_hash, key_builder) -> None: + for field_name in self._all_attrs(): + key_builder.rec(key_hash, getattr(self, field_name)) + + def __eq__(self, other: object): + if type(self) is not type(other): + return False + + for f in self._all_attrs(): + if getattr(self, f) != getattr(other, f): + return False + + return True + + def __hash__(self): + import hashlib + + from loopy.tools import LoopyKeyBuilder + key_hash = hashlib.sha256() + self.update_persistent_hash(key_hash, LoopyKeyBuilder()) + return hash(key_hash.digest()) + + # }}} + def with_types(self, arg_id_to_dtype, clbl_inf_ctx): """ :arg arg_id_to_type: a mapping from argument identifiers (integers for @@ -521,7 +565,8 @@ def is_type_specialized(self): # {{{ scalar callable -@dataclass(frozen=True, init=False) +# not frozen, not eq for Firedrake compatibility +@dataclass(init=False, eq=False) class ScalarCallable(InKernelCallable): """ An abstract interface to a scalar callable encountered in a kernel. @@ -699,7 +744,8 @@ def is_type_specialized(self): # {{{ callable kernel -@dataclass(frozen=True, init=False) +# not frozen, not eq for Firedrake compatibility +@dataclass(init=False, eq=False) class CallableKernel(InKernelCallable): """ Records information about a callee kernel. Also provides interface through diff --git a/loopy/library/random123.py b/loopy/library/random123.py index 329770e05..cde0b093a 100644 --- a/loopy/library/random123.py +++ b/loopy/library/random123.py @@ -176,7 +176,8 @@ def full_name(self) -> str: # }}} -@dataclass(frozen=True, init=False) +# not frozen, not eq for Firedrake compatibility +@dataclass(init=False, eq=False) class Random123Callable(ScalarCallable): """ Records information about for the random123 functions.