Skip to content

Commit

Permalink
Ensure derivs are ints
Browse files Browse the repository at this point in the history
  • Loading branch information
f0uriest committed Mar 2, 2024
1 parent 53015de commit 88bc24e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
10 changes: 5 additions & 5 deletions desc/compute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,14 +332,14 @@ def get_transforms(keys, obj, grid, jitable=False, **kwargs):
derivs = get_derivs(keys, obj, has_axis=grid.axis.size)
transforms = {"grid": grid}
for c in derivs.keys():
if hasattr(obj, c + "_basis"):
if hasattr(obj, c + "_basis"): # regular stuff like R, Z, lambda etc.
basis = getattr(obj, c + "_basis")
# first check if we already have a transform with a compatible basis
for transform in transforms.values():
if basis.eq(getattr(transform, "basis", None)):
ders = np.unique(
np.vstack([derivs[c], transform.derivatives]), axis=0
)
).astype(int)
# don't build until we know all the derivs we need
transform.change_derivatives(ders, build=False)
c_transform = transform
Expand All @@ -353,7 +353,7 @@ def get_transforms(keys, obj, grid, jitable=False, **kwargs):
method=method,
)
transforms[c] = c_transform
elif c == "B":
elif c == "B": # for fitting Boozer harmonics
transforms["B"] = Transform(
grid,
DoubleFourierSeries(
Expand All @@ -367,7 +367,7 @@ def get_transforms(keys, obj, grid, jitable=False, **kwargs):
build_pinv=True,
method=method,
)
elif c == "w":
elif c == "w": # for fitting Boozer toroidal stream function
transforms["w"] = Transform(
grid,
DoubleFourierSeries(
Expand All @@ -381,7 +381,7 @@ def get_transforms(keys, obj, grid, jitable=False, **kwargs):
build_pinv=True,
method=method,
)
elif c not in transforms:
elif c not in transforms: # possible other stuff lumped in with transforms
transforms[c] = getattr(obj, c)

# now build them
Expand Down
7 changes: 4 additions & 3 deletions desc/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def __init__(
self._method = method
# assign according to logic in setter function
self.method = method
self._matrices = self._get_matrices()
if build:
self.build()
if build_pinv:
Expand Down Expand Up @@ -374,6 +373,8 @@ def build(self):
self._built = True
return

self._matrices = self._get_matrices()

if self.method == "direct1":
for d in self.derivatives:
self.matrices["direct1"][d[0]][d[1]][d[2]] = self.basis.evaluate(
Expand All @@ -389,7 +390,7 @@ def build(self):
if self.method in ["fft", "direct2"]:
temp_d = np.hstack(
[self.derivatives[:, :2], np.zeros((len(self.derivatives), 1))]
)
).astype(int)
temp_modes = np.hstack([self.lm_modes, np.zeros((self.num_lm_modes, 1))])
for d in temp_d:
self.matrices["fft"][d[0]][d[1]] = self.basis.evaluate(
Expand All @@ -398,7 +399,7 @@ def build(self):
if self.method == "direct2":
temp_d = np.hstack(
[np.zeros((len(self.derivatives), 2)), self.derivatives[:, 2:]]
)
).astype(int)
temp_modes = np.hstack(
[np.zeros((self.num_n_modes, 2)), self.n_modes[:, np.newaxis]]
)
Expand Down

0 comments on commit 88bc24e

Please sign in to comment.