From 88bc24eae3e2795b3a0f0ec4b1b74fe8e2be1dc2 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Fri, 1 Mar 2024 20:08:19 -0500 Subject: [PATCH] Ensure derivs are ints --- desc/compute/utils.py | 10 +++++----- desc/transform.py | 7 ++++--- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/desc/compute/utils.py b/desc/compute/utils.py index 89ad8f5c19..ec1fb823eb 100644 --- a/desc/compute/utils.py +++ b/desc/compute/utils.py @@ -332,14 +332,14 @@ def get_transforms(keys, obj, grid, jitable=False, **kwargs): derivs = get_derivs(keys, obj, has_axis=grid.axis.size) transforms = {"grid": grid} for c in derivs.keys(): - if hasattr(obj, c + "_basis"): + if hasattr(obj, c + "_basis"): # regular stuff like R, Z, lambda etc. basis = getattr(obj, c + "_basis") # first check if we already have a transform with a compatible basis for transform in transforms.values(): if basis.eq(getattr(transform, "basis", None)): ders = np.unique( np.vstack([derivs[c], transform.derivatives]), axis=0 - ) + ).astype(int) # don't build until we know all the derivs we need transform.change_derivatives(ders, build=False) c_transform = transform @@ -353,7 +353,7 @@ def get_transforms(keys, obj, grid, jitable=False, **kwargs): method=method, ) transforms[c] = c_transform - elif c == "B": + elif c == "B": # for fitting Boozer harmonics transforms["B"] = Transform( grid, DoubleFourierSeries( @@ -367,7 +367,7 @@ def get_transforms(keys, obj, grid, jitable=False, **kwargs): build_pinv=True, method=method, ) - elif c == "w": + elif c == "w": # for fitting Boozer toroidal stream function transforms["w"] = Transform( grid, DoubleFourierSeries( @@ -381,7 +381,7 @@ def get_transforms(keys, obj, grid, jitable=False, **kwargs): build_pinv=True, method=method, ) - elif c not in transforms: + elif c not in transforms: # possible other stuff lumped in with transforms transforms[c] = getattr(obj, c) # now build them diff --git a/desc/transform.py b/desc/transform.py index 7ff2e2bfb2..5ee180d0fb 100644 --- a/desc/transform.py +++ b/desc/transform.py @@ -82,7 +82,6 @@ def __init__( self._method = method # assign according to logic in setter function self.method = method - self._matrices = self._get_matrices() if build: self.build() if build_pinv: @@ -374,6 +373,8 @@ def build(self): self._built = True return + self._matrices = self._get_matrices() + if self.method == "direct1": for d in self.derivatives: self.matrices["direct1"][d[0]][d[1]][d[2]] = self.basis.evaluate( @@ -389,7 +390,7 @@ def build(self): if self.method in ["fft", "direct2"]: temp_d = np.hstack( [self.derivatives[:, :2], np.zeros((len(self.derivatives), 1))] - ) + ).astype(int) temp_modes = np.hstack([self.lm_modes, np.zeros((self.num_lm_modes, 1))]) for d in temp_d: self.matrices["fft"][d[0]][d[1]] = self.basis.evaluate( @@ -398,7 +399,7 @@ def build(self): if self.method == "direct2": temp_d = np.hstack( [np.zeros((len(self.derivatives), 2)), self.derivatives[:, 2:]] - ) + ).astype(int) temp_modes = np.hstack( [np.zeros((self.num_n_modes, 2)), self.n_modes[:, np.newaxis]] )