diff --git a/arraycontext/context.py b/arraycontext/context.py index 398f8aa3..66965ee0 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -492,7 +492,8 @@ def clone(self: SelfType) -> SelfType: "setup-only" array context "leaks" into the application. """ - def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: + def compile(self, f: Callable[..., Any], + single_version_only: bool = False) -> Callable[..., Any]: """Compiles *f* for repeated use on this array context. *f* is expected to be a `pure function `__ performing an array computation. @@ -508,6 +509,8 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: it may be called only once (or a few times). :arg f: the function executing the computation. + :arg single_version_only: If *True*, raise an error if *f* is compiled + more than once (due to different input argument types). :return: a function with the same signature as *f*. """ return f diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index e3ce52a7..98eef8c2 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -632,9 +632,11 @@ def call_loopy(self, program, **kwargs): return call_loopy(program, processed_kwargs, entrypoint) - def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: + def compile(self, f: Callable[..., Any], + single_version_only: bool = False) -> Callable[..., Any]: from .compile import LazilyPyOpenCLCompilingFunctionCaller - return LazilyPyOpenCLCompilingFunctionCaller(self, f) + return LazilyPyOpenCLCompilingFunctionCaller(self, + f, single_version_only) def transform_dag(self, dag: pytato.DictOfNamedArrays ) -> pytato.DictOfNamedArrays: @@ -846,9 +848,10 @@ def _thaw(ary): self._rec_map_container(_thaw, array, self._frozen_array_types), actx=self) - def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: + def compile(self, f: Callable[..., Any], + single_version_only: bool = False) -> Callable[..., Any]: from .compile import LazilyJAXCompilingFunctionCaller - return LazilyJAXCompilingFunctionCaller(self, f) + return LazilyJAXCompilingFunctionCaller(self, f, single_version_only) def tag(self, tags: ToTagSetConvertible, array): def _tag(ary): diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 952761bf..69d8bbe8 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -260,6 +260,7 @@ class BaseLazilyCompilingFunctionCaller: actx: _BasePytatoArrayContext f: Callable[..., Any] + single_version_only: bool program_cache: dict[Mapping[tuple[Hashable, ...], AbstractInputDescriptor], "CompiledFunction"] = field(default_factory=lambda: {}) @@ -325,7 +326,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: try: compiled_f = self.program_cache[arg_id_to_descr] except KeyError: - pass + if self.single_version_only and self.program_cache: + raise ValueError( + f"Function '{self.f.__name__}' to be compiled " + "was already compiled previously with different arguments.") else: return compiled_f(arg_id_to_arg)