Skip to content

Commit

Permalink
Allow shared transforms between R, Z, lambda
Browse files Browse the repository at this point in the history
  • Loading branch information
f0uriest committed Mar 1, 2024
1 parent 4698f41 commit 53015de
Showing 1 changed file with 27 additions and 9 deletions.
36 changes: 27 additions & 9 deletions desc/compute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,13 +333,26 @@ def get_transforms(keys, obj, grid, jitable=False, **kwargs):
transforms = {"grid": grid}
for c in derivs.keys():
if hasattr(obj, c + "_basis"):
transforms[c] = Transform(
grid,
getattr(obj, c + "_basis"),
derivs=derivs[c],
build=True,
method=method,
)
basis = getattr(obj, c + "_basis")
# first check if we already have a transform with a compatible basis
for transform in transforms.values():
if basis.eq(getattr(transform, "basis", None)):
ders = np.unique(
np.vstack([derivs[c], transform.derivatives]), axis=0
)
# don't build until we know all the derivs we need
transform.change_derivatives(ders, build=False)
c_transform = transform
break
else: # if we didn't exit the loop early
c_transform = Transform(
grid,
basis,
derivs=derivs[c],
build=False,
method=method,
)
transforms[c] = c_transform
elif c == "B":
transforms["B"] = Transform(
grid,
Expand All @@ -350,7 +363,7 @@ def get_transforms(keys, obj, grid, jitable=False, **kwargs):
sym=obj.R_basis.sym,
),
derivs=derivs["B"],
build=True,
build=False,
build_pinv=True,
method=method,
)
Expand All @@ -364,13 +377,18 @@ def get_transforms(keys, obj, grid, jitable=False, **kwargs):
sym=obj.Z_basis.sym,
),
derivs=derivs["w"],
build=True,
build=False,
build_pinv=True,
method=method,
)
elif c not in transforms:
transforms[c] = getattr(obj, c)

# now build them
for t in transforms.values():
if hasattr(t, "build"):
t.build()

return transforms


Expand Down

0 comments on commit 53015de

Please sign in to comment.