Skip to content

Commit

Permalink
changing the umbiliccurve API to directly take in phi and A
Browse files Browse the repository at this point in the history
  • Loading branch information
rahulgaur104 committed Sep 5, 2024
1 parent 6fdff0b commit a459f4e
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 7 deletions.
4 changes: 3 additions & 1 deletion desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@
)

if use_jax: # noqa: C901 - FIXME: simplify this, define globally and then assign?
from jax import custom_jvp, jit, vmap
#from jax import custom_jvp, jit, vmap
from jax import custom_jvp, vmap
jit = lambda func, *args, **kwargs: func

imap = jax.lax.map
from jax.experimental.ode import odeint
Expand Down
1 change: 1 addition & 0 deletions desc/compute/_umbiliccurve.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python3

import pdb
from .data_index import register_compute_fun


Expand Down
12 changes: 7 additions & 5 deletions desc/geometry/umbiliccurve.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Classes for parameterized 3D umbilic space curves."""

import os
import pdb
import numpy as np

from desc.backend import jnp, put
Expand Down Expand Up @@ -183,15 +185,15 @@ def from_values(cls, coords, N=10, NFP=1, NFP_umbilic_factor=1, name="", sym=Fal
New representation of the curve parameterized by Fourier series for Z.
"""
theta = coords[:, 0]
phi = coords[:, 1]
A = NFP_umbilic_factor * theta - phi
phi = coords[:, 0]
A = coords[:, 1]

grid = LinearGrid(zeta=phi, NFP=1, sym=sym)
basis = FourierSeries(N=N, NFP=NFP, sym=sym)
grid = LinearGrid(zeta=phi, NFP=1, NFP_umbilic_factor=1, sym=sym)
basis = FourierSeries(N=N, NFP=1, sym=sym)
transform = Transform(grid, basis, build_pinv=True)
A_n = transform.fit(A)

#pdb.set_trace()
return FourierUmbilicCurve(
A_n=A_n,
NFP=NFP,
Expand Down
3 changes: 2 additions & 1 deletion desc/objectives/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1700,9 +1700,10 @@ def compute(self, params_1, params_2=None, constants=None):
curve_phi = curve_grid.nodes[:, 2]

theta_points = (
self._curve.NFP * curve_phi + curve_A
-self._curve.NFP * curve_phi + curve_A
) / self._curve.NFP_umbilic_factor


umbilic_edge_grid = Grid(
jnp.array([jnp.ones_like(theta_points), theta_points, curve_phi]).T,
jitable=True,
Expand Down

0 comments on commit a459f4e

Please sign in to comment.