Skip to content

Commit

Permalink
remove ds from compute index in favor of using grad spacing directly
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici committed Jul 23, 2024
1 parent 4b21abc commit f75e4a2
Showing 1 changed file with 8 additions and 26 deletions.
34 changes: 8 additions & 26 deletions desc/compute/_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,6 @@ def _s(params, transforms, profiles, data, **kwargs):
return data


@register_compute_fun(
name="ds",
label="ds",
units="~",
units_long="None",
description="Spacing of curve parameter",
dim=1,
params=[],
transforms={"grid": []},
profiles=[],
coordinates="s",
data=[],
parameterization="desc.geometry.core.Curve",
)
def _ds(params, transforms, profiles, data, **kwargs):
data["ds"] = transforms["grid"].spacing[:, 2]
return data


@register_compute_fun(
name="X",
label="X",
Expand Down Expand Up @@ -980,17 +961,17 @@ def _torsion(params, transforms, profiles, data, **kwargs):
description="Length of the curve",
dim=0,
params=[],
transforms={},
transforms={"grid": []},
profiles=[],
coordinates="",
data=["ds", "x_s"],
data=["x_s"],
parameterization=["desc.geometry.core.Curve"],
)
def _length(params, transforms, profiles, data, **kwargs):
T = jnp.linalg.norm(data["x_s"], axis=-1)
# this is equivalent to jnp.trapz(T, s) for a closed curve,
# but also works if grid.endpoint is False
data["length"] = jnp.sum(T * data["ds"])
data["length"] = jnp.sum(T * transforms["grid"].spacing[:, 2])
return data


Expand All @@ -1002,10 +983,10 @@ def _length(params, transforms, profiles, data, **kwargs):
description="Length of the curve",
dim=0,
params=[],
transforms={},
transforms={"grid": []},
profiles=[],
coordinates="",
data=["ds", "x", "x_s"],
data=["x", "x_s"],
parameterization="desc.geometry.curve.SplineXYZCurve",
method="Interpolation type, Default 'cubic'. See SplineXYZCurve docs for options.",
)
Expand All @@ -1015,7 +996,8 @@ def _length_SplineXYZCurve(params, transforms, profiles, data, **kwargs):
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = rpz2xyz(coords)
# ensure curve is closed
# if it's already closed this doesn't add any length since ds will be zero
# if it's already closed this doesn't add any length since
# grid spacing will be zero at the duplicate point
coords = jnp.concatenate([coords, coords[:1]])
X = coords[:, 0]
Y = coords[:, 1]
Expand All @@ -1026,5 +1008,5 @@ def _length_SplineXYZCurve(params, transforms, profiles, data, **kwargs):
T = jnp.linalg.norm(data["x_s"], axis=-1)
# this is equivalent to jnp.trapz(T, s) for a closed curve
# but also works if grid.endpoint is False
data["length"] = jnp.sum(T * data["ds"])
data["length"] = jnp.sum(T * transforms["grid"].spacing[:, 2])

Check warning on line 1011 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L1011

Added line #L1011 was not covered by tests
return data

0 comments on commit f75e4a2

Please sign in to comment.