Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

single_version_only for actx.compile #228

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion arraycontext/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,8 @@ def clone(self: SelfType) -> SelfType:
"setup-only" array context "leaks" into the application.
"""

def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
def compile(self, f: Callable[..., Any],
single_version_only: bool = False) -> Callable[..., Any]:
inducer marked this conversation as resolved.
Show resolved Hide resolved
"""Compiles *f* for repeated use on this array context. *f* is expected
to be a `pure function <https://en.wikipedia.org/wiki/Pure_function>`__
performing an array computation.
Expand All @@ -508,6 +509,8 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
it may be called only once (or a few times).

:arg f: the function executing the computation.
:arg single_version_only: If *True*, raise an error if *f* is compiled
more than once (due to different input argument types).
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • This needs to explain quite a bit more context, such as by explaining what triggers a recompile.
  • I'm not sure I like "version" as a noun here.

:return: a function with the same signature as *f*.
"""
return f
Expand Down
11 changes: 7 additions & 4 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,9 +632,11 @@ def call_loopy(self, program, **kwargs):

return call_loopy(program, processed_kwargs, entrypoint)

def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
def compile(self, f: Callable[..., Any],
single_version_only: bool = False) -> Callable[..., Any]:
from .compile import LazilyPyOpenCLCompilingFunctionCaller
return LazilyPyOpenCLCompilingFunctionCaller(self, f)
return LazilyPyOpenCLCompilingFunctionCaller(self,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe also handle the other sub-classes of BaseLazilyCompilingFunctionCaller?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in 5ed25a5

f, single_version_only)

def transform_dag(self, dag: pytato.DictOfNamedArrays
) -> pytato.DictOfNamedArrays:
Expand Down Expand Up @@ -846,9 +848,10 @@ def _thaw(ary):
self._rec_map_container(_thaw, array, self._frozen_array_types),
actx=self)

def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
def compile(self, f: Callable[..., Any],
single_version_only: bool = False) -> Callable[..., Any]:
from .compile import LazilyJAXCompilingFunctionCaller
return LazilyJAXCompilingFunctionCaller(self, f)
return LazilyJAXCompilingFunctionCaller(self, f, single_version_only)

def tag(self, tags: ToTagSetConvertible, array):
def _tag(ary):
Expand Down
6 changes: 5 additions & 1 deletion arraycontext/impl/pytato/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ class BaseLazilyCompilingFunctionCaller:

actx: _BasePytatoArrayContext
f: Callable[..., Any]
single_version_only: bool
program_cache: dict[Mapping[tuple[Hashable, ...], AbstractInputDescriptor],
"CompiledFunction"] = field(default_factory=lambda: {})

Expand Down Expand Up @@ -325,7 +326,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
try:
compiled_f = self.program_cache[arg_id_to_descr]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we save ourselves having to find arg_id_to_descr?

except KeyError:
pass
if self.single_version_only and self.program_cache:
raise ValueError(
f"Function '{self.f.__name__}' to be compiled "
"was already compiled previously with different arguments.")
else:
return compiled_f(arg_id_to_arg)

Expand Down
Loading