Skip to content

Commit

Permalink
Remove jit method of objective, directly compile methods (#1043)
Browse files Browse the repository at this point in the history
Basically, right now we close over `self` when compiling the methods of
`ObjectiveFunction`. This means that JAX may bake all the attributes of
the objective (ie, transforms, profiles, fields, equilibrium etc) into
the compiled function which likely both slows down compilation and may
lead to extra memory usage.

This changes things so that instead we JIT the method directly, treating
`self` as just another argument. Doing this requires refactoring how the
derivatives get handled a bit (they are now only local to their
respective functions rather than being created separately, shouldn't be
any performance hit since creating the `Derivative` objects is basically
free).

Resolves #957 
Resolves #1191
  • Loading branch information
f0uriest authored Aug 22, 2024
2 parents 1c076fc + a719f8a commit bdc5de4
Show file tree
Hide file tree
Showing 6 changed files with 329 additions and 176 deletions.
5 changes: 4 additions & 1 deletion desc/io/optimizable_io.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Functions and methods for saving and loading equilibria and other objects."""

import copy
import functools
import os
import pickle
import pydoc
Expand Down Expand Up @@ -86,7 +87,9 @@ def _unjittable(x):
return any([_unjittable(y) for y in x.values()])
if hasattr(x, "dtype") and np.ndim(x) == 0:
return np.issubdtype(x.dtype, np.bool_) or np.issubdtype(x.dtype, np.int_)
return isinstance(x, (str, types.FunctionType, bool, int, np.int_))
return isinstance(
x, (str, types.FunctionType, functools.partial, bool, int, np.int_)
)


def _make_hashable(x):
Expand Down
11 changes: 7 additions & 4 deletions desc/objectives/linear_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def update_target(self, thing):
assert len(new_target) == len(self.target)
self.target = new_target
self._target_from_user = self.target # in case the Objective is re-built
if self._use_jit:
self.jit()
if not self._use_jit:
self._unjit()

def _parse_target_from_user(
self, target_from_user, default_target, default_bounds, idx
Expand Down Expand Up @@ -232,8 +232,8 @@ def update_target(self, thing):
"""
self.target = self.compute(thing.params_dict)
if self._use_jit:
self.jit()
if not self._use_jit:
self._unjit()


class BoundaryRSelfConsistency(_Objective):
Expand Down Expand Up @@ -3184,6 +3184,7 @@ class FixNearAxisR(_FixedObjective):
"""

_static_attrs = ["_nae_eq"]
_target_arg = "R_lmn"
_fixed = False # not "diagonal", since its fixing a sum
_units = "(m)"
Expand Down Expand Up @@ -3320,6 +3321,7 @@ class FixNearAxisZ(_FixedObjective):
"""

_static_attrs = ["_nae_eq"]
_target_arg = "Z_lmn"
_fixed = False # not "diagonal", since its fixing a sum
_units = "(m)"
Expand Down Expand Up @@ -3462,6 +3464,7 @@ class FixNearAxisLambda(_FixedObjective):
"""

_static_attrs = ["_nae_eq"]
_target_arg = "L_lmn"
_fixed = False # not "diagonal", since its fixing a sum
_units = "(dimensionless)"
Expand Down
Loading

0 comments on commit bdc5de4

Please sign in to comment.