diff --git a/desc/plotting.py b/desc/plotting.py index 7644614d8c..a4be8cee7a 100644 --- a/desc/plotting.py +++ b/desc/plotting.py @@ -1774,7 +1774,9 @@ def plot_boundary(eq, phi=None, plot_axis=True, ax=None, return_data=False, **kw return fig, ax -def plot_boundaries(eqs, labels=None, phi=None, ax=None, return_data=False, **kwargs): +def plot_boundaries( + eqs, labels=None, phi=None, plot_axis=True, ax=None, return_data=False, **kwargs +): """Plot stellarator boundaries at multiple toroidal coordinates. Parameters @@ -1787,6 +1789,8 @@ def plot_boundaries(eqs, labels=None, phi=None, ax=None, return_data=False, **kw Values of phi to plot boundary surface at. If an integer, plot that many contours linearly spaced in [0,2pi). Default is 1 contour for axisymmetric equilibria or 4 for non-axisymmetry. + plot_axis : bool + Whether to plot the magnetic axis locations. Default is True. ax : matplotlib AxesSubplot, optional Axis to plot on. return_data : bool @@ -1808,6 +1812,8 @@ def plot_boundaries(eqs, labels=None, phi=None, ax=None, return_data=False, **kw * ``color``: list of colors to use for each Equilibrium * ``ls``: list of str, line styles to use for each Equilibrium * ``lw``: list of floats, line widths to use for each Equilibrium + * ``marker``: str, marker style to use for the axis plotted points + * ``size``: float, marker size to use for the axis plotted points Returns ------- @@ -1837,6 +1843,8 @@ def plot_boundaries(eqs, labels=None, phi=None, ax=None, return_data=False, **kw lw = kwargs.pop("lw", None) xlabel_fontsize = kwargs.pop("xlabel_fontsize", None) ylabel_fontsize = kwargs.pop("ylabel_fontsize", None) + marker = kwargs.pop("marker", "x") + size = kwargs.pop("size", 36) phi = (1 if eqs[-1].N == 0 else 4) if phi is None else phi if isinstance(phi, numbers.Integral): @@ -1866,7 +1874,11 @@ def plot_boundaries(eqs, labels=None, phi=None, ax=None, return_data=False, **kw plot_data["Z"] = [] for i in range(neq): - grid_kwargs = {"NFP": eqs[i].NFP, "theta": 100, "zeta": phi} + # don't plot axis for FourierRZToroidalSurface, since it's not defined. + plot_axis_i = plot_axis and eqs[i].L > 0 + rho = np.array([0.0, 1.0]) if plot_axis_i else np.array([1.0]) + + grid_kwargs = {"NFP": eqs[i].NFP, "theta": 100, "zeta": phi, "rho": rho} grid = _get_grid(**grid_kwargs) nr, nt, nz = grid.num_rho, grid.num_theta, grid.num_zeta grid = Grid( @@ -1893,6 +1905,11 @@ def plot_boundaries(eqs, labels=None, phi=None, ax=None, return_data=False, **kw (line,) = ax.plot( R[:, -1, j], Z[:, -1, j], color=colors[i], linestyle=ls[i], lw=lw[i] ) + if rho[0] == 0: + ax.scatter( + R[0, 0, j], Z[0, 0, j], color=colors[i], marker=marker, s=size + ) + if j == 0: line.set_label(labels[i]) diff --git a/tests/baseline/test_plot_boundaries.png b/tests/baseline/test_plot_boundaries.png index b05fa366fe..44d4792c1d 100644 Binary files a/tests/baseline/test_plot_boundaries.png and b/tests/baseline/test_plot_boundaries.png differ