From 5078e656abec2198f600b83aabfa26b0461bd897 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Fri, 14 Nov 2025 23:52:56 -0500 Subject: [PATCH 01/20] make a differentiable version of _constant_offset_surface --- desc/geometry/surface.py | 115 +++++++++++++++++- desc/objectives/_geometry.py | 187 +++++++++++++++++++++++++++++- desc/objectives/objective_funs.py | 2 + 3 files changed, 301 insertions(+), 3 deletions(-) diff --git a/desc/geometry/surface.py b/desc/geometry/surface.py index 204d99382c..846cdd7571 100644 --- a/desc/geometry/surface.py +++ b/desc/geometry/surface.py @@ -771,8 +771,8 @@ def fun_jax(zeta_hat, theta, zeta): else: zetas = vecroot(grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2]) - zetas = np.asarray(zetas) - nodes = np.vstack((np.ones_like(grid.nodes[:, 1]), grid.nodes[:, 1], zetas)).T + zetas = jnp.asarray(zetas) + nodes = jnp.vstack((jnp.ones_like(grid.nodes[:, 1]), grid.nodes[:, 1], zetas)).T n, x, x_offsets = n_and_r_jax(nodes) data = {} @@ -1146,3 +1146,114 @@ def get_axis(self): data = self.compute(["R", "Z"], grid=grid) axis = FourierRZCurve(R_n=data["R"][0], Z_n=data["Z"][0], sym=self.sym) return axis + + +def _constant_offset_surface( + base_surface, offset, grid, M, N, full_output=False, params=None +): + """Create a FourierRZSurface with constant offset from the base surface (self). + + Implementation of algorithm described in Appendix B of + "An improved current potential method for fast computation of + stellarator coil shapes", Landreman (2017) + https://iopscience.iop.org/article/10.1088/1741-4326/aa57d4 + + NOTE: Must have the toroidal angle as the cylindrical toroidal angle + in order for this algorithm to work properly + + NOTE: this function lacks the checks of the constant_offset_surface + so that it is jittable/differentiable + + Parameters + ---------- + base_surface : FourierRZToroidalSurface + Surface from which the constant offset surface will be found. + offset : float + constant offset (in m) of the desired surface from the input surface + offset will be in the normal direction to the surface. + grid : Grid, optional + Grid object of the points on the given surface to evaluate the + offset points at, from which the offset surface will be created by fitting + offset points with the basis defined by the given M and N. + If None, defaults to a LinearGrid with M and N and NFP equal to the + base_surface.M and base_surface.N and base_surface.NFP + M : int, optional + Poloidal resolution of the basis used to fit the offset points + to create the resulting constant offset surface, by default equal + to base_surface.M + N : int, optional + Toroidal resolution of the basis used to fit the offset points + to create the resulting constant offset surface, by default equal + to base_surface.N + full_output : bool, optional + If True, also return a dict of useful data about the surfaces and a + tuple where the first element is the residual from + the root finding and the second is the number of iterations. + + Returns + ------- + offset_surface : FourierRZToroidalSurface + FourierRZToroidalSurface, created from fitting points offset from the input + surface by the given constant offset. + data : dict + dictionary containing the following data, in the cylindrical basis: + ``n`` : (``grid.num_nodes`` x 3) array of the unit surface normal on + the base_surface evaluated at the input ``grid`` + ``x`` : (``grid.num_nodes`` x 3) array of coordinates on + the base_surface evaluated at the input ``grid`` + ``x_offset_surface`` : (``grid.num_nodes`` x 3) array of the + coordinates on the offset surface, corresponding to the + ``x`` points on the base_surface (i.e. the points to which the + offset surface was fit) + info : tuple + 2 element tuple containing residuals and number of iterations + for each point. Only returned if ``full_output`` is True + + """ + if params is None: + params = base_surface.params_dict + + def n_and_r_jax(nodes): + data = base_surface.compute( + ["X", "Y", "Z", "n_rho"], + grid=Grid(nodes, jitable=True, sort=False), + method="jitable", + params=params, + ) + + phi = nodes[:, 2] + re = jnp.vstack([data["X"], data["Y"], data["Z"]]).T + n = data["n_rho"] + n = rpz2xyz_vec(n, phi=phi) + r_offset = re + offset * n + return n, re, r_offset + + def fun_jax(zeta_hat, theta, zeta): + nodes = jnp.vstack((jnp.ones_like(theta), theta, zeta_hat)).T + n, r, r_offset = n_and_r_jax(nodes) + return jnp.arctan(r_offset[0, 1] / r_offset[0, 0]) - zeta + + vecroot = jit( + vmap( + lambda x0, *p: root_scalar( + fun_jax, x0, jac=None, args=p, full_output=full_output + ) + ) + ) + if full_output: + zetas, (res, niter) = vecroot( + grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2] + ) + else: + zetas = vecroot(grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2]) + + zetas = jnp.asarray(zetas) + nodes = jnp.vstack((jnp.ones_like(grid.nodes[:, 1]), grid.nodes[:, 1], zetas)).T + n, x, x_offsets = n_and_r_jax(nodes) + + data = {} + data["n"] = xyz2rpz_vec(n, phi=nodes[:, 1]) + data["x"] = xyz2rpz(x) + data["x_offset_surface"] = xyz2rpz(x_offsets) + + return data["x_offset_surface"] diff --git a/desc/objectives/_geometry.py b/desc/objectives/_geometry.py index 52cdb3d632..c3a93c7f14 100644 --- a/desc/objectives/_geometry.py +++ b/desc/objectives/_geometry.py @@ -5,7 +5,8 @@ from desc.backend import jnp, vmap from desc.compute import get_profiles, get_transforms from desc.compute.utils import _compute as compute_fun -from desc.grid import LinearGrid, QuadratureGrid +from desc.geometry.surface import _constant_offset_surface +from desc.grid import Grid, LinearGrid, QuadratureGrid from desc.utils import ( Timer, copy_rpz_periods, @@ -452,6 +453,190 @@ def compute(self, params, constants=None): return data["V"] +class VolumeOffset(_Objective): + """Compute offset surface volume. + + Parameters + ---------- + eq : Equilibrium or FourierRZToroidalSurface + Equilibrium or FourierRZToroidalSurface that + will be optimized to satisfy the Objective. + grid : Grid, optional + Collocation grid containing the nodes to evaluate at. Defaults to + ``QuadratureGrid(eq.L_grid, eq.M_grid, eq.N_grid)`` for ``Equilibrium`` + or ``LinearGrid(M=2*eq.M, N=2*eq.N)`` for ``FourierRZToroidalSurface``. + offset : float, optional + surface offset in meters. + + """ + + __doc__ = __doc__.rstrip() + collect_docs( + target_default="``target=1``.", + bounds_default="``target=1``.", + loss_detail=" Note: Has no effect for this objective.", + ) + + _scalar = True + _units = "(m^3)" + _print_value_fmt = "Offset surface volume: " + + def __init__( + self, + eq, + target=None, + bounds=None, + weight=1, + normalize=True, + normalize_target=True, + loss_function=None, + deriv_mode="auto", + grid=None, + name="volume", + jac_chunk_size=None, + ): + if target is None and bounds is None: + target = 1 + self._grid = grid + super().__init__( + things=eq, + target=target, + bounds=bounds, + weight=weight, + normalize=normalize, + normalize_target=normalize_target, + loss_function=loss_function, + deriv_mode=deriv_mode, + name=name, + jac_chunk_size=jac_chunk_size, + ) + + def build(self, use_jit=True, verbose=1): + """Build constant arrays. + + Parameters + ---------- + use_jit : bool, optional + Whether to just-in-time compile the objective and derivatives. + verbose : int, optional + Level of output. + + """ + eq = self.things[0] + if self._grid is None: + # if not an Equilibrium, is a Surface, + # has no radial resolution so just need + # the surface points + grid = LinearGrid( + rho=1.0, + M=eq.M * 2, + N=eq.N * 2, + NFP=eq.NFP, + ) + else: + grid = self._grid + + self._dim_f = 1 + self._data_keys = ["V"] + + timer = Timer() + if verbose > 0: + print("Precomputing transforms") + timer.start("Precomputing transforms") + + profiles = get_profiles(self._data_keys, obj=eq.surface, grid=grid) + transforms = get_transforms(self._data_keys, obj=eq.surface, grid=grid) + self._constants = { + "transforms": transforms, + "profiles": profiles, + } + + timer.stop("Precomputing transforms") + if verbose > 1: + timer.disp("Precomputing transforms") + + if self._normalize: + scales = compute_scaling_factors(eq) + self._normalization = scales["V"] + from desc.objectives import BoundaryRSelfConsistency, BoundaryZSelfConsistency + + obj_bdryR = BoundaryRSelfConsistency(eq=self.things[0]) + obj_bdryZ = BoundaryZSelfConsistency(eq=self.things[0]) + obj_bdryR.build() + obj_bdryZ.build() + self.constants["A_Rlmn_to_Rb"] = obj_bdryR._A + self.constants["A_Zlmn_to_Zb"] = obj_bdryZ._A + super().build(use_jit=use_jit, verbose=verbose) + + def compute(self, params, constants=None): + """Compute offset surface volume. + + Parameters + ---------- + params : dict + Dictionary of equilibrium or surface degrees of freedom, + eg Equilibrium.params_dict + constants : dict + Dictionary of constant data, eg transforms, profiles etc. Defaults to + self.constants + + Returns + ------- + V : float + Plasma volume (m^3). + + """ + if constants is None: + constants = self.constants + surf = self.things[0].surface + # assign eq Rb and Zb to surf dict so JAX knows + # to trace the derivs wrt Rb etc when surf_params is passed + # find eq Rb Zb by computing R_lmn at rho=1 so that AD knows how + # to connect deriv back to R_lmn + surf_params = surf.params_dict + surf_params["R_lmn"] = jnp.dot(self.constants["A_Rlmn_to_Rb"], params["R_lmn"]) + surf_params["Z_lmn"] = jnp.dot(self.constants["A_Zlmn_to_Zb"], params["Z_lmn"]) + + x_offset_surf = _constant_offset_surface( + surf, + offset=0.2, + M=surf.M, + N=surf.N, + grid=constants["transforms"]["grid"], + params=surf_params, + ) + offset_zetas = x_offset_surf[:, 1] + offset_rtz_nodes = jnp.vstack( + [ + jnp.ones_like(offset_zetas), + constants["transforms"]["grid"].nodes[:, 1], + offset_zetas, + ] + ).T + # make transform to fit + t = get_transforms( + obj=surf, + keys=["R", "Z"], + grid=Grid(offset_rtz_nodes, jitable=True), + jitable=True, + build_pinv=True, + ) + t["R"].build_pinv() + t["Z"].build_pinv() + + surf_params2 = surf_params.copy() + surf_params2["R_lmn"] = t["R"].fit(x_offset_surf[:, 0]) + surf_params2["Z_lmn"] = t["Z"].fit(x_offset_surf[:, 2]) + + data = compute_fun( + surf, + self._data_keys, + params=surf_params2, + transforms=constants["transforms"], + profiles=constants["profiles"], + ) + return data["V"] + + class PlasmaVesselDistance(_Objective): """Target the distance between the plasma and a surrounding surface. diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index b4f1a4d2f2..ae4ec84a49 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -247,6 +247,7 @@ class ObjectiveFunction(IOAble): "_name", "_things_per_objective_idx", "_use_jit", + "_static_attrs", ] def __init__( @@ -1102,6 +1103,7 @@ class _Objective(IOAble, ABC): "_print_value_fmt", "_scalar", "_units", + "_static_attrs", ] def __init__( From 9c0c65d1469b3a155e52c8c8eff209132dd940e4 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Fri, 14 Nov 2025 23:54:43 -0500 Subject: [PATCH 02/20] add offset kwarg --- desc/objectives/_geometry.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/desc/objectives/_geometry.py b/desc/objectives/_geometry.py index c3a93c7f14..02e1b7e4c3 100644 --- a/desc/objectives/_geometry.py +++ b/desc/objectives/_geometry.py @@ -493,10 +493,12 @@ def __init__( grid=None, name="volume", jac_chunk_size=None, + offset=0.1, ): if target is None and bounds is None: target = 1 self._grid = grid + self._offset = offset super().__init__( things=eq, target=target, @@ -598,7 +600,7 @@ def compute(self, params, constants=None): x_offset_surf = _constant_offset_surface( surf, - offset=0.2, + offset=self._offset, M=surf.M, N=surf.N, grid=constants["transforms"]["grid"], From 3919bb3b2cdefd813f2fc2686e523cc56e437f90 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Mon, 2 Feb 2026 15:59:16 -0500 Subject: [PATCH 03/20] change constant offset surface to reuse the jitable version's code, and to make jitable version also perform the fit --- desc/geometry/surface.py | 148 +++++++++++++++++++---------------- desc/objectives/_geometry.py | 56 +++++++------ 2 files changed, 109 insertions(+), 95 deletions(-) diff --git a/desc/geometry/surface.py b/desc/geometry/surface.py index 6ccb2d393f..365bf03cc1 100644 --- a/desc/geometry/surface.py +++ b/desc/geometry/surface.py @@ -666,7 +666,7 @@ def from_shape_parameters( def constant_offset_surface( self, offset, grid=None, M=None, N=None, full_output=False ): - """Create a FourierRZSurface with constant offset from the base surface (self). + """Create a new FourierRZToroidalSurface with constant offset from self. Implementation of algorithm described in Appendix B of "An improved current potential method for fast computation of @@ -676,6 +676,10 @@ def constant_offset_surface( NOTE: Must have the toroidal angle as the cylindrical toroidal angle in order for this algorithm to work properly + NOTE: if one wants to use this inside of an optimization, one should + use the private method _constant_offset_surface directly, and refer to + the documentation in PR #2016 for more details. + Parameters ---------- base_surface : FourierRZToroidalSurface @@ -717,6 +721,8 @@ def constant_offset_surface( coordinates on the offset surface, corresponding to the ``x`` points on the base_surface (i.e. the points to which the offset surface was fit) + as well as the DoubleFourierSeries bases used to fit R and Z. + Only returned if ``full_output`` is True info : tuple 2 element tuple containing residuals and number of iterations for each point. Only returned if ``full_output`` is True @@ -739,56 +745,23 @@ def constant_offset_surface( M = base_surface.M if M is None else int(M) N = base_surface.N if N is None else int(N) - def n_and_r_jax(nodes): - data = base_surface.compute( - ["X", "Y", "Z", "n_rho"], - grid=Grid(nodes, jitable=True, sort=False), - method="jitable", - ) - - phi = nodes[:, 2] - re = jnp.vstack([data["X"], data["Y"], data["Z"]]).T - n = data["n_rho"] - n = rpz2xyz_vec(n, phi=phi) - r_offset = re + offset * n - return n, re, r_offset - - def fun_jax(zeta_hat, theta, zeta): - nodes = jnp.vstack((jnp.ones_like(theta), theta, zeta_hat)).T - n, r, r_offset = n_and_r_jax(nodes) - return jnp.arctan(r_offset[0, 1] / r_offset[0, 0]) - zeta - - vecroot = jit( - vmap( - lambda x0, *p: root_scalar( - fun_jax, x0, jac=None, args=p, full_output=full_output - ) - ) - ) - if full_output: - zetas, (res, niter) = vecroot( - grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2] - ) - else: - zetas = vecroot(grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2]) - - zetas = jnp.asarray(zetas) - nodes = jnp.vstack((jnp.ones_like(grid.nodes[:, 1]), grid.nodes[:, 1], zetas)).T - n, x, x_offsets = n_and_r_jax(nodes) - - data = {} - data["n"] = xyz2rpz_vec(n, phi=nodes[:, 1]) - data["x"] = xyz2rpz(x) - data["x_offset_surface"] = xyz2rpz(x_offsets) - - offset_surface = FourierRZToroidalSurface.from_values( - data["x_offset_surface"], - theta=nodes[:, 1], + R_lmn, Z_lmn, data, (res, niter) = _constant_offset_surface( + base_surface, + offset, + grid=grid, M=M, N=N, - NFP=base_surface.NFP, - sym=base_surface.sym, ) + + offset_surface = FourierRZToroidalSurface( + R_lmn, + Z_lmn, + data["R_basis"].modes[:, 1:], + data["Z_basis"].modes[:, 1:], + base_surface.NFP, + base_surface.sym, + ) + if full_output: return offset_surface, data, (res, niter) else: @@ -1208,9 +1181,16 @@ def _get_ess_scale(self, alpha=1.2, order=np.inf, min_value=1e-7): def _constant_offset_surface( - base_surface, offset, grid, M, N, full_output=False, params=None + base_surface, + offset, + grid, + M=None, + N=None, + R_basis=None, + Z_basis=None, + params=None, ): - """Create a FourierRZSurface with constant offset from the base surface (self). + """Create a FourierRZToroidalSurface with constant offset from the base surface. Implementation of algorithm described in Appendix B of "An improved current potential method for fast computation of @@ -1239,21 +1219,25 @@ def _constant_offset_surface( M : int, optional Poloidal resolution of the basis used to fit the offset points to create the resulting constant offset surface, by default equal - to base_surface.M + to base_surface.M. If basis is given, this is ignored. N : int, optional Toroidal resolution of the basis used to fit the offset points to create the resulting constant offset surface, by default equal - to base_surface.N - full_output : bool, optional - If True, also return a dict of useful data about the surfaces and a - tuple where the first element is the residual from - the root finding and the second is the number of iterations. + to base_surface.N. If basis is given, this is ignored. + R_basis, Z_basis: DoubleFourierSeries, optional + Basis to use to fit the offset surface's R and Z, respectivelu. If None, + new bases will be created using the given M and N. + params : dict, optional + dictionary of parameters to use when computing data from the base_surface. + If None, uses base_surface.params_dict, however the resulting computation + will not be differentiable with respect to the base_surface parameters + (since the JAX AD inside of an objective traces the params dictionaries + that are passedto their compute methods) Returns ------- - offset_surface : FourierRZToroidalSurface - FourierRZToroidalSurface, created from fitting points offset from the input - surface by the given constant offset. + R_lmn, Z_lmn : array-like + coefficients describing the offset surface geometry data : dict dictionary containing the following data, in the cylindrical basis: ``n`` : (``grid.num_nodes`` x 3) array of the unit surface normal on @@ -1264,9 +1248,10 @@ def _constant_offset_surface( coordinates on the offset surface, corresponding to the ``x`` points on the base_surface (i.e. the points to which the offset surface was fit) + as well as the DoubleFourierSeries bases used to fit R and Z. info : tuple 2 element tuple containing residuals and number of iterations - for each point. Only returned if ``full_output`` is True + for each point. """ if params is None: @@ -1294,17 +1279,23 @@ def fun_jax(zeta_hat, theta, zeta): vecroot = jit( vmap( - lambda x0, *p: root_scalar( - fun_jax, x0, jac=None, args=p, full_output=full_output - ) + lambda x0, *p: root_scalar(fun_jax, x0, jac=None, args=p, full_output=True) ) ) - if full_output: - zetas, (res, niter) = vecroot( - grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2] + zetas, (res, niter) = vecroot(grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2]) + + if R_basis is None: + M = base_surface.M if M is None else int(M) + N = base_surface.N if N is None else int(N) + R_basis = DoubleFourierSeries( + M=M, N=N, NFP=base_surface.NFP, sym=base_surface.R_basis.sym + ) + if Z_basis is None: + M = base_surface.M if M is None else int(M) + N = base_surface.N if N is None else int(N) + Z_basis = DoubleFourierSeries( + M=M, N=N, NFP=base_surface.NFP, sym=base_surface.Z_basis.sym ) - else: - zetas = vecroot(grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2]) zetas = jnp.asarray(zetas) nodes = jnp.vstack((jnp.ones_like(grid.nodes[:, 1]), grid.nodes[:, 1], zetas)).T @@ -1315,4 +1306,23 @@ def fun_jax(zeta_hat, theta, zeta): data["x"] = xyz2rpz(x) data["x_offset_surface"] = xyz2rpz(x_offsets) - return data["x_offset_surface"] + offset_zetas = data["x_offset_surface"][:, 1] + offset_rtz_nodes = jnp.vstack( + [ + jnp.ones_like(offset_zetas), + grid.nodes[:, 1], + offset_zetas, + ] + ).T + # make transform to fit + grid_offset = Grid(offset_rtz_nodes, jitable=True) + t_R = Transform(grid=grid_offset, basis=R_basis, method="jitable", build_pinv=True) + t_Z = Transform(grid=grid_offset, basis=Z_basis, method="jitable", build_pinv=True) + + R_lmn = t_R.fit(data["x_offset_surface"][:, 0]) + Z_lmn = t_Z.fit(data["x_offset_surface"][:, 2]) + + data["R_basis"] = R_basis + data["Z_basis"] = Z_basis + + return R_lmn, Z_lmn, data, (res, niter) diff --git a/desc/objectives/_geometry.py b/desc/objectives/_geometry.py index 02e1b7e4c3..9e6ea57610 100644 --- a/desc/objectives/_geometry.py +++ b/desc/objectives/_geometry.py @@ -3,10 +3,11 @@ import numpy as np from desc.backend import jnp, vmap +from desc.basis import DoubleFourierSeries from desc.compute import get_profiles, get_transforms from desc.compute.utils import _compute as compute_fun from desc.geometry.surface import _constant_offset_surface -from desc.grid import Grid, LinearGrid, QuadratureGrid +from desc.grid import LinearGrid, QuadratureGrid from desc.utils import ( Timer, copy_rpz_periods, @@ -494,11 +495,18 @@ def __init__( name="volume", jac_chunk_size=None, offset=0.1, + offset_surface_M=None, + offset_surface_N=None, ): if target is None and bounds is None: target = 1 self._grid = grid self._offset = offset + if offset_surface_M is None: + self._offset_surface_M = eq.surface.M + if offset_surface_N is None: + self._offset_surface_N = eq.surface.N + super().__init__( things=eq, target=target, @@ -550,7 +558,18 @@ def build(self, use_jit=True, verbose=1): self._constants = { "transforms": transforms, "profiles": profiles, + "offset_surface_basis": DoubleFourierSeries( + M=self._offset_surface_M, + N=self._offset_surface_N, + NFP=eq.surface.NFP, + sym=eq.surface.sym, + ), } + self._constants["offset_transforms"] = get_transforms( + self._data_keys, + obj=self._constants["offset_surface_basis"], + grid=grid, + ) timer.stop("Precomputing transforms") if verbose > 1: @@ -590,7 +609,7 @@ def compute(self, params, constants=None): if constants is None: constants = self.constants surf = self.things[0].surface - # assign eq Rb and Zb to surf dict so JAX knows + # assign eq Rb and Zb to the surface's params dict so JAX knows # to trace the derivs wrt Rb etc when surf_params is passed # find eq Rb Zb by computing R_lmn at rho=1 so that AD knows how # to connect deriv back to R_lmn @@ -598,7 +617,8 @@ def compute(self, params, constants=None): surf_params["R_lmn"] = jnp.dot(self.constants["A_Rlmn_to_Rb"], params["R_lmn"]) surf_params["Z_lmn"] = jnp.dot(self.constants["A_Zlmn_to_Zb"], params["Z_lmn"]) - x_offset_surf = _constant_offset_surface( + # find offset surface using the surf_params + R_lmn_offset, Z_lmn_offset = _constant_offset_surface( surf, offset=self._offset, M=surf.M, @@ -606,34 +626,18 @@ def compute(self, params, constants=None): grid=constants["transforms"]["grid"], params=surf_params, ) - offset_zetas = x_offset_surf[:, 1] - offset_rtz_nodes = jnp.vstack( - [ - jnp.ones_like(offset_zetas), - constants["transforms"]["grid"].nodes[:, 1], - offset_zetas, - ] - ).T - # make transform to fit - t = get_transforms( - obj=surf, - keys=["R", "Z"], - grid=Grid(offset_rtz_nodes, jitable=True), - jitable=True, - build_pinv=True, - ) - t["R"].build_pinv() - t["Z"].build_pinv() - + # make a new params dict for the offset surface + # and assign the offset R_lmn and Z_lmn we just computed, + # so that AD knows to trace the derivs wrt the original eq R_lmn Z_lmn surf_params2 = surf_params.copy() - surf_params2["R_lmn"] = t["R"].fit(x_offset_surf[:, 0]) - surf_params2["Z_lmn"] = t["Z"].fit(x_offset_surf[:, 2]) - + surf_params2["R_lmn"] = R_lmn_offset + surf_params2["Z_lmn"] = Z_lmn_offset + # finally, compute Volume of the offset surface data = compute_fun( surf, self._data_keys, params=surf_params2, - transforms=constants["transforms"], + transforms=constants["offset_transforms"], profiles=constants["profiles"], ) return data["V"] From 9896d54ef3fbe8a575a7e21cfb7dd3e817a52239 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Mon, 2 Feb 2026 16:24:15 -0500 Subject: [PATCH 04/20] remove VolumeOffset --- desc/objectives/_geometry.py | 191 ----------------------------------- 1 file changed, 191 deletions(-) diff --git a/desc/objectives/_geometry.py b/desc/objectives/_geometry.py index 9e6ea57610..52cdb3d632 100644 --- a/desc/objectives/_geometry.py +++ b/desc/objectives/_geometry.py @@ -3,10 +3,8 @@ import numpy as np from desc.backend import jnp, vmap -from desc.basis import DoubleFourierSeries from desc.compute import get_profiles, get_transforms from desc.compute.utils import _compute as compute_fun -from desc.geometry.surface import _constant_offset_surface from desc.grid import LinearGrid, QuadratureGrid from desc.utils import ( Timer, @@ -454,195 +452,6 @@ def compute(self, params, constants=None): return data["V"] -class VolumeOffset(_Objective): - """Compute offset surface volume. - - Parameters - ---------- - eq : Equilibrium or FourierRZToroidalSurface - Equilibrium or FourierRZToroidalSurface that - will be optimized to satisfy the Objective. - grid : Grid, optional - Collocation grid containing the nodes to evaluate at. Defaults to - ``QuadratureGrid(eq.L_grid, eq.M_grid, eq.N_grid)`` for ``Equilibrium`` - or ``LinearGrid(M=2*eq.M, N=2*eq.N)`` for ``FourierRZToroidalSurface``. - offset : float, optional - surface offset in meters. - - """ - - __doc__ = __doc__.rstrip() + collect_docs( - target_default="``target=1``.", - bounds_default="``target=1``.", - loss_detail=" Note: Has no effect for this objective.", - ) - - _scalar = True - _units = "(m^3)" - _print_value_fmt = "Offset surface volume: " - - def __init__( - self, - eq, - target=None, - bounds=None, - weight=1, - normalize=True, - normalize_target=True, - loss_function=None, - deriv_mode="auto", - grid=None, - name="volume", - jac_chunk_size=None, - offset=0.1, - offset_surface_M=None, - offset_surface_N=None, - ): - if target is None and bounds is None: - target = 1 - self._grid = grid - self._offset = offset - if offset_surface_M is None: - self._offset_surface_M = eq.surface.M - if offset_surface_N is None: - self._offset_surface_N = eq.surface.N - - super().__init__( - things=eq, - target=target, - bounds=bounds, - weight=weight, - normalize=normalize, - normalize_target=normalize_target, - loss_function=loss_function, - deriv_mode=deriv_mode, - name=name, - jac_chunk_size=jac_chunk_size, - ) - - def build(self, use_jit=True, verbose=1): - """Build constant arrays. - - Parameters - ---------- - use_jit : bool, optional - Whether to just-in-time compile the objective and derivatives. - verbose : int, optional - Level of output. - - """ - eq = self.things[0] - if self._grid is None: - # if not an Equilibrium, is a Surface, - # has no radial resolution so just need - # the surface points - grid = LinearGrid( - rho=1.0, - M=eq.M * 2, - N=eq.N * 2, - NFP=eq.NFP, - ) - else: - grid = self._grid - - self._dim_f = 1 - self._data_keys = ["V"] - - timer = Timer() - if verbose > 0: - print("Precomputing transforms") - timer.start("Precomputing transforms") - - profiles = get_profiles(self._data_keys, obj=eq.surface, grid=grid) - transforms = get_transforms(self._data_keys, obj=eq.surface, grid=grid) - self._constants = { - "transforms": transforms, - "profiles": profiles, - "offset_surface_basis": DoubleFourierSeries( - M=self._offset_surface_M, - N=self._offset_surface_N, - NFP=eq.surface.NFP, - sym=eq.surface.sym, - ), - } - self._constants["offset_transforms"] = get_transforms( - self._data_keys, - obj=self._constants["offset_surface_basis"], - grid=grid, - ) - - timer.stop("Precomputing transforms") - if verbose > 1: - timer.disp("Precomputing transforms") - - if self._normalize: - scales = compute_scaling_factors(eq) - self._normalization = scales["V"] - from desc.objectives import BoundaryRSelfConsistency, BoundaryZSelfConsistency - - obj_bdryR = BoundaryRSelfConsistency(eq=self.things[0]) - obj_bdryZ = BoundaryZSelfConsistency(eq=self.things[0]) - obj_bdryR.build() - obj_bdryZ.build() - self.constants["A_Rlmn_to_Rb"] = obj_bdryR._A - self.constants["A_Zlmn_to_Zb"] = obj_bdryZ._A - super().build(use_jit=use_jit, verbose=verbose) - - def compute(self, params, constants=None): - """Compute offset surface volume. - - Parameters - ---------- - params : dict - Dictionary of equilibrium or surface degrees of freedom, - eg Equilibrium.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - V : float - Plasma volume (m^3). - - """ - if constants is None: - constants = self.constants - surf = self.things[0].surface - # assign eq Rb and Zb to the surface's params dict so JAX knows - # to trace the derivs wrt Rb etc when surf_params is passed - # find eq Rb Zb by computing R_lmn at rho=1 so that AD knows how - # to connect deriv back to R_lmn - surf_params = surf.params_dict - surf_params["R_lmn"] = jnp.dot(self.constants["A_Rlmn_to_Rb"], params["R_lmn"]) - surf_params["Z_lmn"] = jnp.dot(self.constants["A_Zlmn_to_Zb"], params["Z_lmn"]) - - # find offset surface using the surf_params - R_lmn_offset, Z_lmn_offset = _constant_offset_surface( - surf, - offset=self._offset, - M=surf.M, - N=surf.N, - grid=constants["transforms"]["grid"], - params=surf_params, - ) - # make a new params dict for the offset surface - # and assign the offset R_lmn and Z_lmn we just computed, - # so that AD knows to trace the derivs wrt the original eq R_lmn Z_lmn - surf_params2 = surf_params.copy() - surf_params2["R_lmn"] = R_lmn_offset - surf_params2["Z_lmn"] = Z_lmn_offset - # finally, compute Volume of the offset surface - data = compute_fun( - surf, - self._data_keys, - params=surf_params2, - transforms=constants["offset_transforms"], - profiles=constants["profiles"], - ) - return data["V"] - - class PlasmaVesselDistance(_Objective): """Target the distance between the plasma and a surrounding surface. From af3c12a05c969b126b859de9f0275717502be81c Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Tue, 3 Feb 2026 11:57:03 -0500 Subject: [PATCH 05/20] allow transform to be built and passed to constant offset surface --- desc/geometry/surface.py | 89 +++++++++++++++------------------------- 1 file changed, 33 insertions(+), 56 deletions(-) diff --git a/desc/geometry/surface.py b/desc/geometry/surface.py index 365bf03cc1..00130dfcc9 100644 --- a/desc/geometry/surface.py +++ b/desc/geometry/surface.py @@ -16,6 +16,7 @@ vmap, ) from desc.basis import DoubleFourierSeries, ZernikePolynomial +from desc.compute import get_transforms from desc.grid import Grid, LinearGrid from desc.io import InputReader from desc.optimizable import optimizable_parameter @@ -688,7 +689,7 @@ def constant_offset_surface( constant offset (in m) of the desired surface from the input surface offset will be in the normal direction to the surface. grid : Grid, optional - Grid object of the points on the given surface to evaluate the + Grid object of the points on the offset surface to evaluate the offset points at, from which the offset surface will be created by fitting offset points with the basis defined by the given M and N. If None, defaults to a LinearGrid with M and N and NFP equal to the @@ -721,7 +722,7 @@ def constant_offset_surface( coordinates on the offset surface, corresponding to the ``x`` points on the base_surface (i.e. the points to which the offset surface was fit) - as well as the DoubleFourierSeries bases used to fit R and Z. + as well as the transforms bases used to fit R and Z. Only returned if ``full_output`` is True info : tuple 2 element tuple containing residuals and number of iterations @@ -731,7 +732,7 @@ def constant_offset_surface( M = check_nonnegint(M, "M") N = check_nonnegint(N, "N") - base_surface = self + base_surface = self.copy() if grid is None: grid = LinearGrid( M=base_surface.M * 2, @@ -744,20 +745,19 @@ def constant_offset_surface( ), "base_surface must be a FourierRZToroidalSurface!" M = base_surface.M if M is None else int(M) N = base_surface.N if N is None else int(N) + base_surface.change_resolution(M=M, N=N) R_lmn, Z_lmn, data, (res, niter) = _constant_offset_surface( base_surface, offset, grid=grid, - M=M, - N=N, ) offset_surface = FourierRZToroidalSurface( R_lmn, Z_lmn, - data["R_basis"].modes[:, 1:], - data["Z_basis"].modes[:, 1:], + data["transforms"]["R"].basis.modes[:, 1:], + data["transforms"]["Z"].basis.modes[:, 1:], base_surface.NFP, base_surface.sym, ) @@ -1184,10 +1184,7 @@ def _constant_offset_surface( base_surface, offset, grid, - M=None, - N=None, - R_basis=None, - Z_basis=None, + transforms=None, params=None, ): """Create a FourierRZToroidalSurface with constant offset from the base surface. @@ -1211,22 +1208,16 @@ def _constant_offset_surface( constant offset (in m) of the desired surface from the input surface offset will be in the normal direction to the surface. grid : Grid, optional - Grid object of the points on the given surface to evaluate the + Grid object of the points on the offset surface to evaluate the offset points at, from which the offset surface will be created by fitting offset points with the basis defined by the given M and N. If None, defaults to a LinearGrid with M and N and NFP equal to the base_surface.M and base_surface.N and base_surface.NFP - M : int, optional - Poloidal resolution of the basis used to fit the offset points - to create the resulting constant offset surface, by default equal - to base_surface.M. If basis is given, this is ignored. - N : int, optional - Toroidal resolution of the basis used to fit the offset points - to create the resulting constant offset surface, by default equal - to base_surface.N. If basis is given, this is ignored. - R_basis, Z_basis: DoubleFourierSeries, optional - Basis to use to fit the offset surface's R and Z, respectivelu. If None, - new bases will be created using the given M and N. + transforms: dict, optional + Transforms to use to fit the offset surface's R and Z, respectively. If None, + new transforms will be created using the given surface's M and N. + If given, should contain the keys ["R"] and ["Z"], with the pinv matrices + already built, and the corresponding grid should match the input grid. params : dict, optional dictionary of parameters to use when computing data from the base_surface. If None, uses base_surface.params_dict, however the resulting computation @@ -1248,7 +1239,7 @@ def _constant_offset_surface( coordinates on the offset surface, corresponding to the ``x`` points on the base_surface (i.e. the points to which the offset surface was fit) - as well as the DoubleFourierSeries bases used to fit R and Z. + as well as the transforms used to fit R and Z. info : tuple 2 element tuple containing residuals and number of iterations for each point. @@ -1279,24 +1270,13 @@ def fun_jax(zeta_hat, theta, zeta): vecroot = jit( vmap( - lambda x0, *p: root_scalar(fun_jax, x0, jac=None, args=p, full_output=True) + lambda x0, *p: root_scalar( + fun_jax, x0, jac=None, args=p, full_output=True, tol=1e-16 + ) ) ) zetas, (res, niter) = vecroot(grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2]) - if R_basis is None: - M = base_surface.M if M is None else int(M) - N = base_surface.N if N is None else int(N) - R_basis = DoubleFourierSeries( - M=M, N=N, NFP=base_surface.NFP, sym=base_surface.R_basis.sym - ) - if Z_basis is None: - M = base_surface.M if M is None else int(M) - N = base_surface.N if N is None else int(N) - Z_basis = DoubleFourierSeries( - M=M, N=N, NFP=base_surface.NFP, sym=base_surface.Z_basis.sym - ) - zetas = jnp.asarray(zetas) nodes = jnp.vstack((jnp.ones_like(grid.nodes[:, 1]), grid.nodes[:, 1], zetas)).T n, x, x_offsets = n_and_r_jax(nodes) @@ -1306,23 +1286,20 @@ def fun_jax(zeta_hat, theta, zeta): data["x"] = xyz2rpz(x) data["x_offset_surface"] = xyz2rpz(x_offsets) - offset_zetas = data["x_offset_surface"][:, 1] - offset_rtz_nodes = jnp.vstack( - [ - jnp.ones_like(offset_zetas), - grid.nodes[:, 1], - offset_zetas, - ] - ).T - # make transform to fit - grid_offset = Grid(offset_rtz_nodes, jitable=True) - t_R = Transform(grid=grid_offset, basis=R_basis, method="jitable", build_pinv=True) - t_Z = Transform(grid=grid_offset, basis=Z_basis, method="jitable", build_pinv=True) - - R_lmn = t_R.fit(data["x_offset_surface"][:, 0]) - Z_lmn = t_Z.fit(data["x_offset_surface"][:, 2]) - - data["R_basis"] = R_basis - data["Z_basis"] = Z_basis + if transforms is None: + # NOTE: we are assuming here that the rootfind was successful for every point, + # so that the zeta=arctan(y/x) of the offset surface point are the same as + # the grid nodes' zeta values. If this is not the case, the fitting + # will be incorrect. + transforms = get_transforms( + obj=base_surface, keys=["R", "Z"], grid=grid, jitable=True + ) + transforms["R"].build_pinv() + transforms["Z"].build_pinv() + + R_lmn = transforms["R"].fit(data["x_offset_surface"][:, 0]) + Z_lmn = transforms["Z"].fit(data["x_offset_surface"][:, 2]) + + data["transforms"] = transforms return R_lmn, Z_lmn, data, (res, niter) From 747eae73f20a945bea9f5d0108dc98541f2000be Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Tue, 3 Feb 2026 12:00:06 -0500 Subject: [PATCH 06/20] reduce rootfind tol --- desc/geometry/surface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/desc/geometry/surface.py b/desc/geometry/surface.py index 00130dfcc9..2adb8010a1 100644 --- a/desc/geometry/surface.py +++ b/desc/geometry/surface.py @@ -1271,7 +1271,7 @@ def fun_jax(zeta_hat, theta, zeta): vecroot = jit( vmap( lambda x0, *p: root_scalar( - fun_jax, x0, jac=None, args=p, full_output=True, tol=1e-16 + fun_jax, x0, jac=None, args=p, full_output=True, tol=1e-12 ) ) ) From f70b5446bfbee1d569c069352b5407e9ba7d9271 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Tue, 3 Feb 2026 12:25:08 -0500 Subject: [PATCH 07/20] add test --- tests/test_surfaces.py | 113 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) diff --git a/tests/test_surfaces.py b/tests/test_surfaces.py index 4adf058c81..9e91e4ec49 100644 --- a/tests/test_surfaces.py +++ b/tests/test_surfaces.py @@ -4,9 +4,12 @@ import pytest import desc.examples +from desc.backend import jax, jit +from desc.compute import get_transforms from desc.equilibrium import Equilibrium from desc.examples import get from desc.geometry import FourierRZToroidalSurface, ZernikeRZToroidalSection +from desc.geometry.surface import _constant_offset_surface from desc.grid import LinearGrid from desc.utils import rpz2xyz @@ -212,6 +215,116 @@ def test_constant_offset_surface_circle(self): err_msg=f"Failed test at comparison of {key}", ) + @pytest.mark.unit + def test_constant_offset_surface_circle_jax_transformable(self, capsys): + """Test constant offset algorithm is jax transformable.""" + s = FourierRZToroidalSurface() + grid = LinearGrid(M=3, N=2) + transforms = get_transforms(["R", "Z"], s, grid) + transforms["R"].build_pinv() + transforms["Z"].build_pinv() + offset = 1 + + def fun(params): + return _constant_offset_surface(s, offset, grid, transforms, params) + + # ensure is jitable + R_lmn, Z_lmn, data, _ = jit(fun)(s.params_dict) + + s_offset = FourierRZToroidalSurface( + R_lmn=R_lmn, + Z_lmn=Z_lmn, + M=s.M, + N=s.N, + NFP=s.NFP, + sym=s.sym, + modes_R=data["transforms"]["R"].basis.modes[:, 1:], + modes_Z=data["transforms"]["Z"].basis.modes[:, 1:], + ) + + r_offset_surf = data["x_offset_surface"] + r_surf = data["x"] + dists = np.linalg.norm(r_surf - r_offset_surf, axis=1) + np.testing.assert_allclose(dists, 1, atol=1e-16) + R00_offset_ind = s_offset.R_basis.get_idx(M=0, N=0) + R00_offset = s_offset.R_lmn[R00_offset_ind] + R10_offset_ind = s_offset.R_basis.get_idx(M=1, N=0) + R10_offset = s_offset.R_lmn[R10_offset_ind] + Zneg10_offset_ind = s_offset.Z_basis.get_idx(M=-1, N=0) + Zneg10_offset = s_offset.Z_lmn[Zneg10_offset_ind] + + np.testing.assert_allclose(R00_offset, 10) + np.testing.assert_allclose(R10_offset, 2) + np.testing.assert_allclose(Zneg10_offset, -2) + np.testing.assert_allclose( + np.delete( + s_offset.R_lmn, + np.array([R00_offset_ind, R10_offset_ind]), + ), + 0, + atol=9e-15, + ) + np.testing.assert_allclose( + np.delete( + s_offset.Z_lmn, + Zneg10_offset_ind, + ), + 0, + atol=9e-15, + ) + grid_compute = LinearGrid(M=10, N=10) + data = s.compute(["x", "e_theta", "e_zeta"], basis="rpz", grid=grid_compute) + data_offset = s_offset.compute( + ["x", "e_theta", "e_zeta"], basis="rpz", grid=grid_compute + ) + dists = np.linalg.norm(data["x"] - data_offset["x"], axis=1) + np.testing.assert_allclose(dists, 1, atol=1e-16) + correct_data_offset = { + "e_theta": np.vstack( + ( + -2 * np.sin(grid_compute.nodes[:, 1]), + np.zeros_like(grid_compute.nodes[:, 1]), + -2 * np.cos(grid_compute.nodes[:, 1]), + ) + ).T, + "e_zeta": np.vstack( + ( + np.zeros_like(grid_compute.nodes[:, 1]), + data_offset["x"][:, 0], + np.zeros_like(grid_compute.nodes[:, 1]), + ) + ).T, + } + for key in ["e_theta", "e_zeta"]: + np.testing.assert_allclose( + correct_data_offset[key], + data_offset[key], + atol=1e-4, + err_msg=f"Failed test at comparison of {key}", + ) + # make sure that the function is not recompiled + with jax.log_compiles(): + R_lmn, Z_lmn, data, _ = jit(fun)(s.params_dict) + + out = capsys.readouterr() + assert out.out == "" + + # check gradient is correct + # R00 of offset should change with R00 of original surface + grad_R00 = jax.grad(lambda params: fun(params)[0][s.R_basis.get_idx(M=0, N=0)])( + s.params_dict + ) + # check that the gradient is nonzero for the R00 component + assert np.any(np.abs(grad_R00["R_lmn"][s.R_basis.get_idx(M=0, N=0)]) > 1e-10) + # check that the gradient is zero otherwise + non_R00_indices = np.where(s.R_basis.modes.sum(axis=1) != 0)[0] + assert np.all(np.abs(grad_R00["R_lmn"][non_R00_indices]) < 1e-10) + + # check gradient is correct + np.testing.assert_allclose( + grad_R00["R_lmn"][s.R_basis.get_idx(M=0, N=0)], 1.0, atol=1e-10 + ) + @pytest.mark.slow @pytest.mark.unit def test_constant_offset_surface_rot_ellipse(self): From 573ad0f10aabf2e869540bb0d6e1fc58957c78e0 Mon Sep 17 00:00:00 2001 From: Dario Panici <37969854+dpanici@users.noreply.github.com> Date: Tue, 3 Feb 2026 12:25:50 -0500 Subject: [PATCH 08/20] Update desc/geometry/surface.py Co-authored-by: Yigit Gunsur Elmacioglu <102380275+YigitElma@users.noreply.github.com> --- desc/geometry/surface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/desc/geometry/surface.py b/desc/geometry/surface.py index 2adb8010a1..48f57dae1f 100644 --- a/desc/geometry/surface.py +++ b/desc/geometry/surface.py @@ -1223,7 +1223,7 @@ def _constant_offset_surface( If None, uses base_surface.params_dict, however the resulting computation will not be differentiable with respect to the base_surface parameters (since the JAX AD inside of an objective traces the params dictionaries - that are passedto their compute methods) + that are passed to their compute methods) Returns ------- From 7abd21d7ff8a332f9a6d64ea1a4673eb8299b84d Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Tue, 3 Feb 2026 12:28:39 -0500 Subject: [PATCH 09/20] fix incorrect n calc, though did not affect any subsequent calculations --- desc/geometry/surface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/desc/geometry/surface.py b/desc/geometry/surface.py index 2adb8010a1..2602193496 100644 --- a/desc/geometry/surface.py +++ b/desc/geometry/surface.py @@ -1282,7 +1282,7 @@ def fun_jax(zeta_hat, theta, zeta): n, x, x_offsets = n_and_r_jax(nodes) data = {} - data["n"] = xyz2rpz_vec(n, phi=nodes[:, 1]) + data["n"] = xyz2rpz_vec(n, phi=nodes[:, 2]) data["x"] = xyz2rpz(x) data["x_offset_surface"] = xyz2rpz(x_offsets) From 156c89cc33a177ced70219990e054a82d7672373 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Tue, 3 Feb 2026 14:03:33 -0500 Subject: [PATCH 10/20] use arctan2 and correct angle span inside of rootfind, this should make it more robust --- desc/geometry/surface.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/desc/geometry/surface.py b/desc/geometry/surface.py index 08c918cb60..1a61a68728 100644 --- a/desc/geometry/surface.py +++ b/desc/geometry/surface.py @@ -1266,7 +1266,10 @@ def n_and_r_jax(nodes): def fun_jax(zeta_hat, theta, zeta): nodes = jnp.vstack((jnp.ones_like(theta), theta, zeta_hat)).T n, r, r_offset = n_and_r_jax(nodes) - return jnp.arctan(r_offset[0, 1] / r_offset[0, 0]) - zeta + # add 2pi to the arctan2<0 so it matches our convention of + # zeta being btwn 0 and 2pi + zeta_offset = jnp.arctan2(r_offset[0, 1], r_offset[0, 0]) + return jnp.where(zeta_offset < 0, zeta_offset + 2 * np.pi, zeta_offset) - zeta vecroot = jit( vmap( @@ -1292,7 +1295,10 @@ def fun_jax(zeta_hat, theta, zeta): # the grid nodes' zeta values. If this is not the case, the fitting # will be incorrect. transforms = get_transforms( - obj=base_surface, keys=["R", "Z"], grid=grid, jitable=True + obj=base_surface, + keys=["R", "Z"], + grid=grid, + jitable=True, ) transforms["R"].build_pinv() transforms["Z"].build_pinv() From 7ce498c29a3f18fed560e387e5a0ea81dcd70a3a Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Tue, 3 Feb 2026 14:07:21 -0500 Subject: [PATCH 11/20] update docs --- desc/geometry/surface.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/desc/geometry/surface.py b/desc/geometry/surface.py index 1a61a68728..34e84cb0e6 100644 --- a/desc/geometry/surface.py +++ b/desc/geometry/surface.py @@ -692,8 +692,8 @@ def constant_offset_surface( Grid object of the points on the offset surface to evaluate the offset points at, from which the offset surface will be created by fitting offset points with the basis defined by the given M and N. - If None, defaults to a LinearGrid with M and N and NFP equal to the - base_surface.M and base_surface.N and base_surface.NFP + If None, defaults to a LinearGrid with M and N and NFP equal to twice the + base_surface.M and base_surface.N and NFP equal to base_surface.NFP M : int, optional Poloidal resolution of the basis used to fit the offset points to create the resulting constant offset surface, by default equal @@ -1211,8 +1211,8 @@ def _constant_offset_surface( Grid object of the points on the offset surface to evaluate the offset points at, from which the offset surface will be created by fitting offset points with the basis defined by the given M and N. - If None, defaults to a LinearGrid with M and N and NFP equal to the - base_surface.M and base_surface.N and base_surface.NFP + If None, defaults to a LinearGrid with M and N and NFP equal to twice the + base_surface.M and base_surface.N and NFP equal to base_surface.NFP transforms: dict, optional Transforms to use to fit the offset surface's R and Z, respectively. If None, new transforms will be created using the given surface's M and N. From 32b2879551b1b33af62658fee3f6b3007d7c6211 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Tue, 3 Feb 2026 14:08:28 -0500 Subject: [PATCH 12/20] make grid sym in tests with constant offset surface be true --- tests/conftest.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 64974de460..3de44d4119 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -353,7 +353,7 @@ def regcoil_helical_coils_scan(): offset=0.2, # desired offset M=16, # Poloidal resolution of desired offset surface N=12, # Toroidal resolution of desired offset surface - grid=LinearGrid(M=32, N=16, NFP=eq.NFP), + grid=LinearGrid(M=32, N=16, NFP=eq.NFP, sym=eq.sym), ) surface_current_field = FourierCurrentPotentialField.from_surface( surf_winding, M_Phi=8, N_Phi=8 @@ -381,7 +381,7 @@ def regcoil_modular_coils(): offset=0.2, # desired offset M=16, # Poloidal resolution of desired offset surface N=12, # Toroidal resolution of desired offset surface - grid=LinearGrid(M=32, N=16, NFP=eq.NFP), + grid=LinearGrid(M=32, N=16, NFP=eq.NFP, sym=eq.sym), ) M_Phi = 10 N_Phi = 10 @@ -417,7 +417,7 @@ def regcoil_windowpane_coils(): offset=0.2, # desired offset M=16, # Poloidal resolution of desired offset surface N=12, # Toroidal resolution of desired offset surface - grid=LinearGrid(M=32, N=16, NFP=eq.NFP), + grid=LinearGrid(M=32, N=16, NFP=eq.NFP, sym=eq.sym), ) M_Phi = 10 N_Phi = 10 @@ -457,7 +457,7 @@ def regcoil_PF_coils(): offset=0.2, # desired offset M=16, # Poloidal resolution of desired offset surface N=12, # Toroidal resolution of desired offset surface - grid=LinearGrid(M=32, N=16, NFP=eq.NFP), + grid=LinearGrid(M=32, N=16, NFP=eq.NFP, sym=eq.sym), ) M_Phi = 10 N_Phi = 10 From 16bd93edfeb6aa9ce9bf8c653e0220310a9ae88a Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Tue, 3 Feb 2026 14:18:33 -0500 Subject: [PATCH 13/20] fix test --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 3de44d4119..d23c9fbde5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -425,7 +425,7 @@ def regcoil_windowpane_coils(): N_egrid = 20 M_sgrid = 20 N_sgrid = 20 - lambda_regularization = 1e-18 + lambda_regularization = 1e-20 surface_current_field = FourierCurrentPotentialField.from_surface( surf_winding, M_Phi=M_Phi, N_Phi=N_Phi, sym_Phi="sin" From 3ca222cf168759b53aa0433e23f1ac420866b5b7 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Tue, 3 Feb 2026 14:23:58 -0500 Subject: [PATCH 14/20] update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e59e02665e..21e17cf91b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,11 +22,14 @@ alternative to fourier continuation methods. - Adds ``"scipy-l-bfgs-b"`` optimizer option as a wrapper to scipy's ``"l-bfgs-b"`` method. - Adds ``check_intersection`` flag to ``desc.magnetic_fields.FourierCurrentPotentialField.to_Coilset``, to allow the choice of checking the resulting coilset for intersections or not. - Changes the import paths for ``desc.external`` to require reference to the sub-modules. +- Adds a differentiable utility for finding constant offset toroidal surfaces inside of optimizations. See [PR](https://github.com/PlasmaControl/DESC/pull/2016) for more details. + Bug Fixes - No longer uses the full Hessian to compute the scale when ``x_scale="auto"`` and using a scipy optimizer that approximates the hessian (e.g. if using ``"scipy-bfgs"``, no longer attempts the Hessian computation to get the x_scale). - ``SplineMagneticField.from_field()`` correctly uses the ``NFP`` input when given. Also adds this as a similar input option to ``MagneticField.save_mgrid()``. +- Fixes some bugs that hampered robustness of ``desc.geometry.FourierRZToroidalSurface.constant_offset_surface``, particularly when the given grid had stellarator symmetry or when NFP=1. Performance Improvements From 2e06b387a3072d2f26c0024f759253f8643321e2 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Tue, 3 Feb 2026 14:37:57 -0500 Subject: [PATCH 15/20] update tests again --- desc/geometry/surface.py | 2 ++ tests/conftest.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/desc/geometry/surface.py b/desc/geometry/surface.py index 34e84cb0e6..b4571c66e5 100644 --- a/desc/geometry/surface.py +++ b/desc/geometry/surface.py @@ -1298,6 +1298,8 @@ def fun_jax(zeta_hat, theta, zeta): obj=base_surface, keys=["R", "Z"], grid=grid, + # this is more robust than letting method become fft if + # jitable=False, and will also work within jitted functions jitable=True, ) transforms["R"].build_pinv() diff --git a/tests/conftest.py b/tests/conftest.py index d23c9fbde5..88c6eb4764 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -389,7 +389,7 @@ def regcoil_modular_coils(): N_egrid = 20 M_sgrid = 40 N_sgrid = 40 - lambda_regularization = 1e-18 + lambda_regularization = 1e-20 surface_current_field = FourierCurrentPotentialField.from_surface( surf_winding, M_Phi=M_Phi, N_Phi=N_Phi From d1f04a8d7848547c67ede75644f8ea6bb86ac2d9 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Tue, 3 Feb 2026 15:36:09 -0500 Subject: [PATCH 16/20] remove redundant part of test, and adjust tols --- tests/test_surfaces.py | 77 ++---------------------------------------- 1 file changed, 3 insertions(+), 74 deletions(-) diff --git a/tests/test_surfaces.py b/tests/test_surfaces.py index 9e91e4ec49..3252c01d24 100644 --- a/tests/test_surfaces.py +++ b/tests/test_surfaces.py @@ -157,7 +157,7 @@ def test_constant_offset_surface_circle(self): r_offset_surf = data["x_offset_surface"] r_surf = data["x"] dists = np.linalg.norm(r_surf - r_offset_surf, axis=1) - np.testing.assert_allclose(dists, 1, atol=1e-16) + np.testing.assert_allclose(dists, 1, atol=1e-14) R00_offset_ind = s_offset.R_basis.get_idx(M=0, N=0) R00_offset = s_offset.R_lmn[R00_offset_ind] R10_offset_ind = s_offset.R_basis.get_idx(M=1, N=0) @@ -174,7 +174,7 @@ def test_constant_offset_surface_circle(self): np.array([R00_offset_ind, R10_offset_ind]), ), 0, - atol=9e-15, + atol=1e-14, ) np.testing.assert_allclose( np.delete( @@ -182,7 +182,7 @@ def test_constant_offset_surface_circle(self): Zneg10_offset_ind, ), 0, - atol=9e-15, + atol=1e-14, ) grid_compute = LinearGrid(M=10, N=10) data = s.compute(["x", "e_theta", "e_zeta"], basis="rpz", grid=grid_compute) @@ -231,77 +231,6 @@ def fun(params): # ensure is jitable R_lmn, Z_lmn, data, _ = jit(fun)(s.params_dict) - s_offset = FourierRZToroidalSurface( - R_lmn=R_lmn, - Z_lmn=Z_lmn, - M=s.M, - N=s.N, - NFP=s.NFP, - sym=s.sym, - modes_R=data["transforms"]["R"].basis.modes[:, 1:], - modes_Z=data["transforms"]["Z"].basis.modes[:, 1:], - ) - - r_offset_surf = data["x_offset_surface"] - r_surf = data["x"] - dists = np.linalg.norm(r_surf - r_offset_surf, axis=1) - np.testing.assert_allclose(dists, 1, atol=1e-16) - R00_offset_ind = s_offset.R_basis.get_idx(M=0, N=0) - R00_offset = s_offset.R_lmn[R00_offset_ind] - R10_offset_ind = s_offset.R_basis.get_idx(M=1, N=0) - R10_offset = s_offset.R_lmn[R10_offset_ind] - Zneg10_offset_ind = s_offset.Z_basis.get_idx(M=-1, N=0) - Zneg10_offset = s_offset.Z_lmn[Zneg10_offset_ind] - - np.testing.assert_allclose(R00_offset, 10) - np.testing.assert_allclose(R10_offset, 2) - np.testing.assert_allclose(Zneg10_offset, -2) - np.testing.assert_allclose( - np.delete( - s_offset.R_lmn, - np.array([R00_offset_ind, R10_offset_ind]), - ), - 0, - atol=9e-15, - ) - np.testing.assert_allclose( - np.delete( - s_offset.Z_lmn, - Zneg10_offset_ind, - ), - 0, - atol=9e-15, - ) - grid_compute = LinearGrid(M=10, N=10) - data = s.compute(["x", "e_theta", "e_zeta"], basis="rpz", grid=grid_compute) - data_offset = s_offset.compute( - ["x", "e_theta", "e_zeta"], basis="rpz", grid=grid_compute - ) - dists = np.linalg.norm(data["x"] - data_offset["x"], axis=1) - np.testing.assert_allclose(dists, 1, atol=1e-16) - correct_data_offset = { - "e_theta": np.vstack( - ( - -2 * np.sin(grid_compute.nodes[:, 1]), - np.zeros_like(grid_compute.nodes[:, 1]), - -2 * np.cos(grid_compute.nodes[:, 1]), - ) - ).T, - "e_zeta": np.vstack( - ( - np.zeros_like(grid_compute.nodes[:, 1]), - data_offset["x"][:, 0], - np.zeros_like(grid_compute.nodes[:, 1]), - ) - ).T, - } - for key in ["e_theta", "e_zeta"]: - np.testing.assert_allclose( - correct_data_offset[key], - data_offset[key], - atol=1e-4, - err_msg=f"Failed test at comparison of {key}", - ) # make sure that the function is not recompiled with jax.log_compiles(): R_lmn, Z_lmn, data, _ = jit(fun)(s.params_dict) From e8112049b466a53a9f06629e6af55b6062732112 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Wed, 4 Feb 2026 15:12:16 -0500 Subject: [PATCH 17/20] adjust tols --- tests/test_surfaces.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_surfaces.py b/tests/test_surfaces.py index 3252c01d24..4d004f06cd 100644 --- a/tests/test_surfaces.py +++ b/tests/test_surfaces.py @@ -174,7 +174,7 @@ def test_constant_offset_surface_circle(self): np.array([R00_offset_ind, R10_offset_ind]), ), 0, - atol=1e-14, + atol=1e-13, ) np.testing.assert_allclose( np.delete( @@ -182,7 +182,7 @@ def test_constant_offset_surface_circle(self): Zneg10_offset_ind, ), 0, - atol=1e-14, + atol=1e-13, ) grid_compute = LinearGrid(M=10, N=10) data = s.compute(["x", "e_theta", "e_zeta"], basis="rpz", grid=grid_compute) @@ -190,7 +190,7 @@ def test_constant_offset_surface_circle(self): ["x", "e_theta", "e_zeta"], basis="rpz", grid=grid_compute ) dists = np.linalg.norm(data["x"] - data_offset["x"], axis=1) - np.testing.assert_allclose(dists, 1, atol=1e-16) + np.testing.assert_allclose(dists, 1, atol=1e-14) correct_data_offset = { "e_theta": np.vstack( ( From 6079bdd7d49ba8e3052f4df831d6afb2b399b6da Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Wed, 4 Feb 2026 16:15:15 -0500 Subject: [PATCH 18/20] attempt to fix docs --- desc/geometry/surface.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/desc/geometry/surface.py b/desc/geometry/surface.py index b4571c66e5..4d4dce4f08 100644 --- a/desc/geometry/surface.py +++ b/desc/geometry/surface.py @@ -713,7 +713,8 @@ def constant_offset_surface( FourierRZToroidalSurface, created from fitting points offset from the input surface by the given constant offset. data : dict - dictionary containing the following data, in the cylindrical basis: + dictionary containing the following data, in the cylindrical basis, + as well as the transforms bases used to fit R and Z: ``n`` : (``grid.num_nodes`` x 3) array of the unit surface normal on the base_surface evaluated at the input ``grid`` ``x`` : (``grid.num_nodes`` x 3) array of coordinates on @@ -722,8 +723,6 @@ def constant_offset_surface( coordinates on the offset surface, corresponding to the ``x`` points on the base_surface (i.e. the points to which the offset surface was fit) - as well as the transforms bases used to fit R and Z. - Only returned if ``full_output`` is True info : tuple 2 element tuple containing residuals and number of iterations for each point. Only returned if ``full_output`` is True From 01015687fe8c9649acb04093c175da99b4cdeda5 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Wed, 4 Feb 2026 16:38:51 -0500 Subject: [PATCH 19/20] another doc fix attempt --- desc/geometry/surface.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/desc/geometry/surface.py b/desc/geometry/surface.py index 4d4dce4f08..603c847d9a 100644 --- a/desc/geometry/surface.py +++ b/desc/geometry/surface.py @@ -713,8 +713,7 @@ def constant_offset_surface( FourierRZToroidalSurface, created from fitting points offset from the input surface by the given constant offset. data : dict - dictionary containing the following data, in the cylindrical basis, - as well as the transforms bases used to fit R and Z: + dictionary containing the following data, in the cylindrical basis: ``n`` : (``grid.num_nodes`` x 3) array of the unit surface normal on the base_surface evaluated at the input ``grid`` ``x`` : (``grid.num_nodes`` x 3) array of coordinates on @@ -723,6 +722,8 @@ def constant_offset_surface( coordinates on the offset surface, corresponding to the ``x`` points on the base_surface (i.e. the points to which the offset surface was fit) + ``transforms`` : dict containing the Transform objects used to + fit R and Z of the info : tuple 2 element tuple containing residuals and number of iterations for each point. Only returned if ``full_output`` is True From 66f74d5883d95c464714ff6abe719a8c6293739b Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Thu, 5 Feb 2026 11:59:21 -0500 Subject: [PATCH 20/20] adjust maxiter --- desc/geometry/surface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/desc/geometry/surface.py b/desc/geometry/surface.py index 603c847d9a..97665150ea 100644 --- a/desc/geometry/surface.py +++ b/desc/geometry/surface.py @@ -1274,7 +1274,7 @@ def fun_jax(zeta_hat, theta, zeta): vecroot = jit( vmap( lambda x0, *p: root_scalar( - fun_jax, x0, jac=None, args=p, full_output=True, tol=1e-12 + fun_jax, x0, jac=None, args=p, full_output=True, tol=1e-12, maxiter=100 ) ) )