Source Guide

This page is a code-oriented map of NEO_JAX for readers who want to connect the theory and numerics pages to the implementation.

Module structure

The repository is organized into a small number of solver-facing modules:

| Module | Responsibility |
|---|---|
| neo_jax.api | High-level public API such as neo_jax.run_neo() and convenience wrappers for boozmn files and in-memory Boozer objects. |
| neo_jax.config | User-facing configuration model. |
| neo_jax.io | Loading of boozmn files and conversion from booz_xform-style objects. |
| neo_jax.fourier | Fourier reconstruction and derived geometric quantities. |
| neo_jax.surface | Surface initialization, spline construction, and \(B_{\min}\)/\(B_{\max}\) refinement. |
| neo_jax.geometry | Spline evaluation and Newton-based extremum refinement. |
| neo_jax.integrate | Field-line right-hand side (RHS), RK4 stepping, trapped-particle bookkeeping, and the JAX scan backend. |
| neo_jax.driver | Surface-loop orchestration, scaling, diagnostics, and result assembly. |
| neo_jax.pipeline | VMEC→Boozer→NEO helper workflows. |

Geometry loading

The boozmn reader is where the external geometry enters the solver:

def read_boozmn(
    path: str | Path,
    *,
    max_m_mode: int = 0,
    max_n_mode: int = 0,
    fluxs_arr: Optional[Sequence[int]] = None,
    extension: str | None = None,
) -> BoozerData:
    """Read a boozmn netCDF file and return BoozerData.

    This function follows the packing conventions used in STELLOPT's
    read_boozer_mod, but only retains the surfaces requested.

    Args:
        path: Location of the boozmn file; resolved together with
            ``extension`` by ``resolve_boozmn_path``.
        max_m_mode: Poloidal mode cutoff passed to ``_select_modes``.
            Values ``<= 0`` fall back to ``mboz_b - 1``.
        max_n_mode: Toroidal mode cutoff passed to ``_select_modes``.
            Values ``<= 0`` fall back to ``nboz_b * nfp``.
        fluxs_arr: 1-based surface labels (as stored in the file's
            ``jlist``) to load, in the order they should appear in the
            result. ``None`` loads every surface listed in ``jlist``.
            Integral floats (e.g. ``3.0``) are accepted.
        extension: Optional extension forwarded to ``resolve_boozmn_path``.

    Returns:
        BoozerData with one entry per requested surface, in request order.
        Fourier arrays keep only the modes selected by ``_select_modes``.

    Raises:
        ValueError: If a requested surface label is not an integer, or is
            not present in the file's ``jlist``.
    """
    _require_netcdf4()
    booz_path = resolve_boozmn_path(path, extension)

    with netCDF4.Dataset(booz_path) as ds:  # type: ignore[union-attr]
        nfp = int(ds.variables["nfp_b"][:])
        ns_b = int(ds.variables["ns_b"][:])
        mboz_b = int(ds.variables["mboz_b"][:])
        nboz_b = int(ds.variables["nboz_b"][:])

        ixm_b = np.array(ds.variables["ixm_b"][:], dtype=int)
        ixn_b = np.array(ds.variables["ixn_b"][:], dtype=int)

        # Radial profiles indexed by (1-based surface label - 1).
        iota_b = np.array(ds.variables["iota_b"][:], dtype=float)
        buco_b = np.array(ds.variables["buco_b"][:], dtype=float)
        bvco_b = np.array(ds.variables["bvco_b"][:], dtype=float)
        pres_b = np.array(ds.variables["pres_b"][:], dtype=float) if "pres_b" in ds.variables else None

        # Packed Fourier coefficients: one row per surface listed in jlist.
        rmnc_raw = np.array(ds.variables["rmnc_b"][:], dtype=float)
        zmns_raw = np.array(ds.variables["zmns_b"][:], dtype=float)
        pmns_raw = np.array(ds.variables["pmns_b"][:], dtype=float)
        bmnc_raw = np.array(ds.variables["bmnc_b"][:], dtype=float)
        gmn_raw = np.array(ds.variables["gmn_b"][:], dtype=float) if "gmn_b" in ds.variables else None

        if "jlist" in ds.variables:
            jlist = np.array(ds.variables["jlist"][:], dtype=int)
        else:
            # No explicit surface list: assume the packed rows are labelled 1..N.
            jlist = np.arange(1, rmnc_raw.shape[0] + 1, dtype=int)

    pack_len = rmnc_raw.shape[0]
    rmnc_pack = _transpose_if_needed(rmnc_raw, pack_len)
    zmns_pack = _transpose_if_needed(zmns_raw, pack_len)
    pmns_pack = _transpose_if_needed(pmns_raw, pack_len)
    bmnc_pack = _transpose_if_needed(bmnc_raw, pack_len)
    gmn_pack = _transpose_if_needed(gmn_raw, pack_len) if gmn_raw is not None else None

    max_m = max_m_mode if max_m_mode > 0 else mboz_b - 1
    max_n = max_n_mode if max_n_mode > 0 else nboz_b * nfp
    mode_mask = _select_modes(ixm_b, ixn_b, max_m, max_n)

    ixm = ixm_b[mode_mask]
    ixn = ixn_b[mode_mask]
    # Position of the (m, n) = (0, 0) mode in the *unfiltered* mode list; gmn
    # is indexed with it below without applying mode_mask.
    mode0 = np.where((ixm_b == 0) & (ixn_b == 0))[0]
    mode0_idx = int(mode0[0]) if len(mode0) else None

    # 1-based surface label -> row in the packed coefficient arrays.
    pack_index = {int(surf): idx for idx, surf in enumerate(jlist)}

    if fluxs_arr is not None:
        surfaces = list(fluxs_arr)
    else:
        surfaces = list(jlist)

    rmnc = []
    zmns = []
    lmns = []
    bmnc = []
    es = []
    iota = []
    curr_pol = []
    curr_tor = []
    pprime = []
    sqrtg00 = []

    hs = 1.0 / (ns_b - 1)
    for surf_raw in surfaces:
        # Normalize the label once so the packed-row lookup and the radial
        # profile indexing below use the same integer. A bare int() would
        # silently truncate e.g. 2.5, so non-integral labels are rejected.
        surf = int(surf_raw)
        if surf != surf_raw:
            raise ValueError(f"Surface index {surf_raw!r} is not an integer")

        pack_idx = pack_index.get(surf)
        if pack_idx is None:
            raise ValueError(f"Surface {surf} not found in boozmn jlist")

        rmnc.append(rmnc_pack[pack_idx, mode_mask])
        zmns.append(zmns_pack[pack_idx, mode_mask])
        # pmns -> lmns: sign flip and nfp / (2*pi) rescaling, per the
        # read_boozer_mod convention referenced in the docstring.
        lmns.append(-pmns_pack[pack_idx, mode_mask] * nfp / (2.0 * math.pi))
        bmnc.append(bmnc_pack[pack_idx, mode_mask])

        # Normalized flux label on the half-radial grid: s_j = (j - 1.5) / (ns - 1).
        es.append((surf - 1.5) * hs)
        iota.append(iota_b[surf - 1])
        curr_pol.append(bvco_b[surf - 1])
        curr_tor.append(buco_b[surf - 1])
        # Forward difference of the pressure profile; 0.0 on the outermost
        # surface (no neighbour) or when the file has no pressure.
        if pres_b is not None and surf < ns_b:
            pprime.append((pres_b[surf] - pres_b[surf - 1]) / hs)
        else:
            pprime.append(0.0)
        # (0, 0) Fourier coefficient of gmn, or 0.0 when unavailable.
        if gmn_pack is not None and mode0_idx is not None:
            sqrtg00.append(float(gmn_pack[pack_idx, mode0_idx]))
        else:
            sqrtg00.append(0.0)

    return BoozerData(
        rmnc=np.asarray(rmnc),
        zmns=np.asarray(zmns),
        lmns=np.asarray(lmns),
        bmnc=np.asarray(bmnc),
        ixm=np.asarray(ixm),
        ixn=np.asarray(ixn),
        es=np.asarray(es),
        iota=np.asarray(iota),
        curr_pol=np.asarray(curr_pol),
        curr_tor=np.asarray(curr_tor),
        nfp=nfp,
        pprime=np.asarray(pprime),
        sqrtg00=np.asarray(sqrtg00),
    )

This function:

  • resolves the requested file path

  • loads the Boozer Fourier coefficients, the rotational transform, the current profiles, and (when present) the pressure profile

  • maps the boozmn convention to the internal neo_jax.BoozerData container

  • computes the normalized toroidal-flux coordinate \(s_j = (j - 1.5)/(n_s - 1)\) of each selected surface \(j\) on the half-radial grid; this is the s used by the public API

Surface initialization

For each selected surface, NEO_JAX reconstructs the geometry, derives the metric quantities, builds the spline representation, and refines \(B_{\min}\) and \(B_{\max}\):

def init_surface(
    theta_arr: Array,
    phi_arr: Array,
    coeffs: Dict[str, Array],
    ixm: Array,
    ixn: Array,
    nfp: int,
    max_m_mode: int,
    max_n_mode: int,
    curr_pol: Array,
    curr_tor: Array,
    iota: Array,
    grid: Dict[str, float],
    calc_cur: bool = False,
    use_jax: bool = False,
    skip_mask: bool = False,
) -> SurfaceData:
    """Initialize a single flux surface: Fourier sums, derived fields, splines, B min/max.

    Processing steps:

    1. ``fourier_sums`` evaluates the Fourier coefficients in ``coeffs`` on
       the ``theta_arr`` / ``phi_arr`` grid. The optional ``lasym`` flag and
       the non-symmetric ``rmns``/``zmnc``/``lmnc``/``bmns`` entries are
       forwarded from ``coeffs`` when present.
    2. ``derived_quantities`` adds derived fields. On key clashes the derived
       entries override the Fourier ones. Splines are built over the merged
       fields.
    3. The grid points of minimum and maximum ``fields["b"]`` are selected
       as Newton starting points. Ties are broken in Fortran (column-major)
       order.
    4. ``neo_zeros2d`` refines both points on the B spline, and B is
       re-evaluated there with ``neo_eval``.

    Args:
        theta_arr, phi_arr: Angle grids. Entries are picked out by the
            extremum indices, so presumably 1-D (TODO confirm).
        coeffs: Fourier coefficient arrays keyed by name. Must contain
            ``rmnc``, ``zmns``, ``lmns`` and ``bmnc``.
        ixm, ixn: Poloidal/toroidal mode numbers matching the coefficients.
        nfp: Number of field periods.
        max_m_mode, max_n_mode: Mode cutoffs forwarded to ``fourier_sums``
            and to the NumPy extremum search.
        curr_pol, curr_tor, iota: Surface profile values forwarded to
            ``derived_quantities``.
        grid: Spline grid description. Must contain ``theta_int``,
            ``phi_int``, ``mt`` and ``mp``. Also passed through to the
            Newton solver and spline evaluator.
        calc_cur: Forwarded to ``build_splines``.
        use_jax: Use the JAX extremum-index helper instead of the NumPy one.
        skip_mask: Forwarded to ``fourier_sums``.

    Returns:
        SurfaceData holding:
        - the refined ``b_min`` and ``b_max``;
        - their ``(theta, phi)`` locations;
        - ``bmref`` (equal to the refined ``b_max``);
        - the merged field dict and the splines.

    Note:
        The Newton iteration count and residual returned by ``neo_zeros2d``
        are discarded. The refined points are used as-is.
    """
    fourier = fourier_sums(
        theta_arr,
        phi_arr,
        coeffs["rmnc"],
        coeffs["zmns"],
        coeffs["lmns"],
        coeffs["bmnc"],
        ixm,
        ixn,
        nfp=nfp,
        max_m_mode=max_m_mode,
        max_n_mode=max_n_mode,
        skip_mask=skip_mask,
        lasym=bool(coeffs.get("lasym", False)),
        rmns=coeffs.get("rmns"),
        zmnc=coeffs.get("zmnc"),
        lmnc=coeffs.get("lmnc"),
        bmns=coeffs.get("bmns"),
    )

    derived = derived_quantities(fourier, curr_pol=curr_pol, curr_tor=curr_tor, iota=iota)
    # Derived quantities take precedence over Fourier outputs on key clashes.
    fields = {**fourier, **derived}

    splines = build_splines(fields, grid["theta_int"], grid["phi_int"], grid["mt"], grid["mp"], calc_cur)

    # Find initial min/max locations on the grid, then refine with Newton.
    # Fortran's MINLOC/MAXLOC traverse arrays in column-major order
    # (theta index varies fastest). Use exact extrema and pick the
    # first index in Fortran order when there are ties.
    b = fields["b"]
    if use_jax:
        max_i, max_j = _select_extremum_index_jax(b, find_max=True)
        min_i, min_j = _select_extremum_index_jax(b, find_max=False)
    else:
        extremum_args = (
            b,
            theta_arr,
            phi_arr,
            ixm,
            ixn,
            coeffs["bmnc"],
            max_m_mode,
            max_n_mode,
        )
        max_i, max_j = _select_extremum_index(*extremum_args, find_max=True)
        min_i, min_j = _select_extremum_index(*extremum_args, find_max=False)

    theta_bmin = theta_arr[min_i]
    phi_bmin = phi_arr[min_j]
    theta_bmax = theta_arr[max_i]
    phi_bmax = phi_arr[max_j]

    # Newton refinement on the B spline.
    newton_tol = 1.0e-10
    newton_maxiter = 100
    theta_bmin, phi_bmin, _it_min, _err_min = neo_zeros2d(
        theta_bmin, phi_bmin, newton_tol, newton_maxiter, splines["b_spl"], grid
    )
    theta_bmax, phi_bmax, _it_max, _err_max = neo_zeros2d(
        theta_bmax, phi_bmax, newton_tol, newton_maxiter, splines["b_spl"], grid
    )

    from .geometry import neo_eval

    def _b_at(theta, phi):
        """Return B evaluated from the surface splines at (theta, phi)."""
        b_val, *_ = neo_eval(
            theta, phi,
            splines["b_spl"],
            splines["g_spl"],
            splines["k_spl"],
            splines["p_spl"],
            splines.get("q_spl"),
            grid,
        )
        return b_val

    # Evaluate B at the refined points (min first, then max).
    b_min = _b_at(theta_bmin, phi_bmin)
    b_max = _b_at(theta_bmax, phi_bmax)

    return SurfaceData(
        b_min=b_min,
        b_max=b_max,
        theta_bmin=theta_bmin,
        phi_bmin=phi_bmin,
        theta_bmax=theta_bmax,
        phi_bmax=phi_bmax,
        # The refined B maximum is the reference field for this surface.
        bmref=b_max,
        fields=fields,
        splines=splines,
    )

Field-line RHS

The core continuous model enters through neo_jax.integrate.rhs_bo1():

def rhs_bo1(phi: Array, y: Array, state: RhsState, env: RhsEnv) -> Tuple[Array, RhsState]:
    """Right-hand side for the field-line ODE (port of rhs_bo1.f90).

    Evaluates the spline quantities at ``(theta, phi)`` with ``theta = y[0]``.
    Returns the derivative vector together with updated particle bookkeeping.
    ``y`` itself is not modified. The bookkeeping update travels through the
    returned ``RhsState``.

    Args:
        phi: Toroidal angle at which the right-hand side is evaluated (the
            independent variable of the field-line ODE).
        y: State vector with ``y[0] = theta``. With ``npart = env.eta.shape[0]``,
            the slots ``y[NPQ:NPQ + npart]`` and ``y[NPQ + npart:NPQ + 2 * npart]``
            hold the per-eta I and H integrals.
        state: Bookkeeping carried between calls: ``isw``, ``ipa``, ``icount``
            (per eta level), ``ipmax`` and ``pard0`` (previous ``pardeb``).
        env: Splines, grid, eta levels, reference field ``bmod0`` and ``iota``.

    Returns:
        ``(dery, new_state)``. ``dery`` has the same shape and dtype as ``y``.

    NOTE(review): The arithmetic order mirrors the Fortran original and should
    not be reshuffled (strict-parity mode relies on it). ``flint_bo_jax``
    contains a second copy of this body (``rhs_inline``). Changes here must
    be mirrored there.
    """
    theta = y[0]

    # B and the other spline quantities at the current point; q is unused here.
    # NOTE(review): pardeb is presumably the derivative of B along the field
    # line -- confirm against neo_eval.
    bmod, gval, geodcu, pardeb, _qval = neo_eval(
        theta,
        phi,
        env.splines["b_spl"],
        env.splines["g_spl"],
        env.splines["k_spl"],
        env.splines["p_spl"],
        env.splines.get("q_spl"),
        env.grid,
    )

    bmodm2 = 1.0 / (bmod * bmod)  # B^-2
    bmodm3 = bmodm2 / bmod  # B^-3
    bra = bmod / env.bmod0  # B normalized by the reference field

    # ipass = 1 when pardeb has changed sign (or touched zero) since the
    # previous call and is now positive.
    ipass = jnp.where((pardeb * state.pard0 <= 0) & (pardeb > 0), 1, 0).astype(state.isw.dtype)
    # ipmax latches to 1 the first time pardeb crosses into negative values;
    # once set it is left unchanged here.
    ipmax = jnp.where(
        (state.ipmax == 0) & (pardeb * state.pard0 <= 0) & (pardeb < 0),
        1,
        state.ipmax,
    ).astype(state.isw.dtype)

    # Remember this pardeb for the next sign-change test.
    pard0 = pardeb

    # Derivatives of the first four state components:
    # theta' = iota, y2' = B^-2, y3' = gval * B^-2, y4' = geodcu * B^-3.
    dery = jnp.zeros_like(y)
    dery = dery.at[0].set(env.iota)
    dery = dery.at[1].set(bmodm2)
    dery = dery.at[2].set(bmodm2 * gval)
    dery = dery.at[3].set(geodcu * bmodm3)

    # Replace non-positive eta levels with 1.0 so the reciprocals below stay
    # finite. Those levels are excluded again via eta_pos.
    eta_pos = env.eta > 0
    eta_safe = jnp.where(eta_pos, env.eta, jnp.asarray(1.0, dtype=env.eta.dtype))
    inv_eta = jnp.where(eta_pos, 1.0 / eta_safe, 0.0)
    sqeta = jnp.sqrt(eta_safe)
    inv_sqeta = jnp.where(eta_pos, 1.0 / sqeta, 0.0)

    # Active eta levels: those with B / bmod0 < eta (i.e. 1 - bra / eta > 0).
    subsq = 1.0 - bra * inv_eta
    mask = (subsq > 0) & eta_pos

    # Clamp the sqrt argument to 0 on inactive levels so they yield 0, not NaN.
    safe_subsq = jnp.where(mask, subsq, 0.0)
    sq = jnp.sqrt(safe_subsq) * bmodm2
    # Integrands for the I and H accumulators; zero on inactive levels.
    p_i = jnp.where(mask, sq, 0.0)
    p_h = jnp.where(mask, sq * (4.0 / bra - inv_eta) * geodcu * inv_sqeta, 0.0)

    # Update particle state
    # isw per eta level: 1 while active. Inactive levels go 1 -> 2, stay at 2,
    # and any other value becomes 0.
    one_i = jnp.array(1, dtype=state.isw.dtype)
    two_i = jnp.array(2, dtype=state.isw.dtype)
    zero_i = jnp.array(0, dtype=state.isw.dtype)
    isw = jnp.where(mask, one_i, jnp.where(state.isw == 1, two_i, jnp.where(state.isw == 2, two_i, zero_i)))
    # icount counts active evaluations. ipa counts positive pardeb sign
    # changes seen while a level is active.
    icount = state.icount + mask.astype(state.icount.dtype)
    ipa = state.ipa + ipass * mask.astype(state.ipa.dtype)

    # Write the per-eta integrands into the I and H slots of the derivative.
    dery = dery.at[NPQ : NPQ + env.eta.shape[0]].set(p_i)
    dery = dery.at[NPQ + env.eta.shape[0] : NPQ + 2 * env.eta.shape[0]].set(p_h)

    new_state = RhsState(isw=isw, ipa=ipa, icount=icount, ipmax=ipmax, pard0=pard0)
    return dery, new_state

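This function evaluates the right-hand side of the field-line ODE for the state vector \((\theta, y_2, y_3, y_4, I_j, H_j)\), which the RK4 stepper then uses to advance the state. It also updates the trapped-particle bookkeeping (isw, icount, ipa, ipmax).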

JAX scan backend

The compiled backend lives in neo_jax.integrate.flint_bo_jax():

def flint_bo_jax(
    surface,
    params: FlintParams,
    env: RhsEnv,
    nfp: int,
    rt0: float | None = None,
    *,
    Rmajor: float | None = None,
    diagnostic_callback=None,
    diagnostic_trap_callback=None,
    diagnostic_snapshot: tuple[int, int] | None = None,
    diagnostic_snapshot_callback=None,
    convergence_callback=None,
    convergence_period_callback=None,
    convergence_step_callback=None,
    convergence_reset_callback=None,
    strict_parity: bool = False,
    skip_rational_correction: bool = False,
):
    """JAX-friendly integration loop with rational-surface correction."""
    if rt0 is None:
        if Rmajor is None:
            raise ValueError("Either rt0 or Rmajor must be provided.")
        rt0 = float(Rmajor)
    elif Rmajor is not None and float(Rmajor) != float(rt0):
        raise ValueError("rt0 and Rmajor must match if both are provided.")
    npart = params.npart
    multra = params.multra
    ndim = NPQ + 2 * npart

    # Particle grids
    etamin = surface.b_min / surface.bmref
    etamax = surface.b_max / surface.bmref
    heta = (etamax - etamin) / (npart - 1)
    etamin = etamin + heta / 2.0
    eta = etamin + heta * jnp.arange(npart)

    env = RhsEnv(
        splines=env.splines,
        grid=env.grid,
        eta=eta,
        bmod0=surface.bmref,
        iota=env.iota,
        curr_pol=env.curr_pol,
        curr_tor=env.curr_tor,
    )

    coeps = jnp.pi * rt0 * rt0 * heta / (8.0 * jnp.sqrt(2.0))
    if env.curr_pol is None or env.curr_tor is None:
        j_iota_i = 0.0
    else:
        j_iota_i = env.curr_pol + env.iota * env.curr_tor

    y = jnp.zeros(ndim)
    y = y.at[0].set(surface.theta_bmax)
    phi0 = surface.phi_bmax
    phi = phi0

    state = RhsState(
        isw=jnp.zeros(npart, dtype=jnp.int32),
        ipa=jnp.zeros(npart, dtype=jnp.int32),
        icount=jnp.zeros(npart, dtype=jnp.int32),
        ipmax=jnp.array(0, dtype=jnp.int32),
        pard0=jnp.array(0.0),
    )

    _bmod, _gval, _geodcu, pard0, _qval = neo_eval(
        surface.theta_bmax,
        surface.phi_bmax,
        env.splines["b_spl"],
        env.splines["g_spl"],
        env.splines["k_spl"],
        env.splines["p_spl"],
        env.splines.get("q_spl"),
        env.grid,
    )
    state = RhsState(state.isw, state.ipa, state.icount, state.ipmax, pard0)

    iswst = jnp.zeros(npart, dtype=jnp.int32)
    bigint = jnp.zeros(multra)
    adimax = jnp.array(0.0)
    aditot = jnp.array(0.0)

    hphi = 2.0 * jnp.pi / (params.nstep_per * nfp)
    theta0 = surface.theta_bmax

    theta_d_min = 2.0 * jnp.pi
    n_iota = jnp.array(1, dtype=jnp.int32)
    m_iota = jnp.array(1, dtype=jnp.int32)
    iota_bar_fp = jnp.array(0.0)

    nstep_max_c = jnp.array(params.nstep_max, dtype=jnp.int32)
    hit_rat = jnp.array(0, dtype=jnp.int32)
    exist_first_ratfl = jnp.array(0, dtype=jnp.int32)
    nfp_rat = jnp.array(0, dtype=jnp.int32)
    nfl_rat = jnp.array(0, dtype=jnp.int32)
    delta_theta_rat = jnp.array(0.0)
    n_gap = jnp.array(0, dtype=jnp.int32)

    stop = jnp.array(False)
    n = jnp.array(0, dtype=jnp.int32)

    if diagnostic_snapshot is None:
        snap_n = None
        snap_j = None
    else:
        snap_n = jnp.array(diagnostic_snapshot[0], dtype=jnp.int32)
        snap_j = jnp.array(diagnostic_snapshot[1], dtype=jnp.int32)

    def integrate_period(carry, emit_diag: bool, step_index):
        phi, y, state, iswst, bigint, adimax, aditot = carry

        def rhs_inline(phi_local: Array, y_local: Array, state_local: RhsState) -> Tuple[Array, RhsState]:
            """Inline RHS evaluation to help XLA fuse neo_eval + RK4."""
            theta = y_local[0]

            bmod, gval, geodcu, pardeb, _qval = neo_eval(
                theta,
                phi_local,
                env.splines["b_spl"],
                env.splines["g_spl"],
                env.splines["k_spl"],
                env.splines["p_spl"],
                env.splines.get("q_spl"),
                env.grid,
            )

            bmodm2 = 1.0 / (bmod * bmod)
            bmodm3 = bmodm2 / bmod
            bra = bmod / env.bmod0

            ipass = jnp.where(
                (pardeb * state_local.pard0 <= 0) & (pardeb > 0), 1, 0
            ).astype(state_local.isw.dtype)
            ipmax = jnp.where(
                (state_local.ipmax == 0) & (pardeb * state_local.pard0 <= 0) & (pardeb < 0),
                1,
                state_local.ipmax,
            ).astype(state_local.isw.dtype)

            pard0 = pardeb

            dery = jnp.zeros_like(y_local)
            dery = dery.at[0].set(env.iota)
            dery = dery.at[1].set(bmodm2)
            dery = dery.at[2].set(bmodm2 * gval)
            dery = dery.at[3].set(geodcu * bmodm3)

            eta_pos = env.eta > 0
            eta_safe = jnp.where(eta_pos, env.eta, jnp.asarray(1.0, dtype=env.eta.dtype))
            inv_eta = jnp.where(eta_pos, 1.0 / eta_safe, 0.0)
            sqeta = jnp.sqrt(eta_safe)
            inv_sqeta = jnp.where(eta_pos, 1.0 / sqeta, 0.0)

            subsq = 1.0 - bra * inv_eta
            mask = (subsq > 0) & eta_pos

            safe_subsq = jnp.where(mask, subsq, 0.0)
            sq = jnp.sqrt(safe_subsq) * bmodm2
            p_i = jnp.where(mask, sq, 0.0)
            p_h = jnp.where(mask, sq * (4.0 / bra - inv_eta) * geodcu * inv_sqeta, 0.0)

            one_i = jnp.array(1, dtype=state_local.isw.dtype)
            two_i = jnp.array(2, dtype=state_local.isw.dtype)
            zero_i = jnp.array(0, dtype=state_local.isw.dtype)
            isw = jnp.where(
                mask,
                one_i,
                jnp.where(
                    state_local.isw == 1,
                    two_i,
                    jnp.where(state_local.isw == 2, two_i, zero_i),
                ),
            )
            icount = state_local.icount + mask.astype(state_local.icount.dtype)
            ipa = state_local.ipa + ipass * mask.astype(state_local.ipa.dtype)

            dery = dery.at[NPQ : NPQ + env.eta.shape[0]].set(p_i)
            dery = dery.at[NPQ + env.eta.shape[0] : NPQ + 2 * env.eta.shape[0]].set(p_h)

            new_state = RhsState(isw=isw, ipa=ipa, icount=icount, ipmax=ipmax, pard0=pard0)
            return dery, new_state

        def rk4_step_inline(
            phi_local: Array, y_local: Array, state_local: RhsState
        ) -> Tuple[Array, Array, RhsState]:
            """Inline RK4 to increase fusion with neo_eval."""
            hh = hphi / 2.0
            h6 = hphi / 6.0

            k1, state1 = rhs_inline(phi_local, y_local, state_local)
            y1 = y_local + hh * k1

            k2, state2 = rhs_inline(phi_local + hh, y1, state1)
            y2 = y_local + hh * k2

            k3, state3 = rhs_inline(phi_local + hh, y2, state2)
            y3 = y_local + hphi * k3

            k4, state4 = rhs_inline(phi_local + hphi, y3, state3)

            y_new = y_local + h6 * (k1 + k4 + 2.0 * (k2 + k3))
            phi_new = phi_local + hphi
            return phi_new, y_new, state4

        def inner_step(j, inner):
            phi, y, state, iswst, bigint, adimax, aditot = inner
            if strict_parity:
                phi, y, state = rk4_step(phi, y, state, env, hphi)
            else:
                phi, y, state = rk4_step_inline(phi, y, state)
            p_i = y[NPQ : NPQ + npart]
            p_h = y[NPQ + npart : NPQ + 2 * npart]
            if diagnostic_callback is not None and emit_diag:
                mask2 = state.isw == 2
                event_mask = mask2 & (iswst == 1)
                safe_pi = jnp.where(p_i == 0, jnp.array(1.0, dtype=p_i.dtype), p_i)
                add_on = (p_h * p_h) / safe_pi * iswst

                def _emit(_):
                    jax.debug.callback(diagnostic_callback, event_mask, state.icount, state.ipa, add_on, ordered=True)
                    return None

                _ = jax.lax.cond(jnp.any(event_mask), _emit, lambda _: None, operand=None)
            if diagnostic_trap_callback is not None and emit_diag:
                mask2 = state.isw == 2
                event_mask = mask2 & (iswst == 1)

                def _emit_trap(_):
                    jax.debug.callback(
                        diagnostic_trap_callback,
                        event_mask,
                        state.isw,
                        iswst,
                        p_i,
                        p_h,
                        state.icount,
                        state.ipa,
                        phi,
                        j,
                        step_index,
                        ordered=True,
                    )
                    return None

                _ = jax.lax.cond(jnp.any(event_mask), _emit_trap, lambda _: None, operand=None)
            if diagnostic_snapshot_callback is not None and emit_diag and snap_n is not None and snap_j is not None:
                def _emit_snap(_):
                    jax.debug.callback(
                        diagnostic_snapshot_callback,
                        state.isw,
                        iswst,
                        p_i,
                        p_h,
                        state.icount,
                        state.ipa,
                        phi,
                        j,
                        step_index,
                        ordered=True,
                    )
                    return None

                snap_cond = (step_index == snap_n) & (j == snap_j)
                _ = jax.lax.cond(snap_cond, _emit_snap, lambda _: None, operand=None)
            if convergence_step_callback is not None:
                jax.debug.callback(
                    convergence_step_callback,
                    step_index + 1,
                    j + 1,
                    state.ipmax,
                    state.isw,
                    state.ipa,
                    p_i,
                    ordered=True,
                )

            if strict_parity:
                state, iswst, p_i, p_h, bigint, adimax = _process_trapped(
                    state, iswst, p_i, p_h, bigint, adimax, multra
                )
            else:
                # Inline trapped-particle update to improve fusion in scan body.
                mask2 = state.isw == 2
                m_cl = jnp.clip(state.ipa, 1, multra).astype(state.ipa.dtype)

                def body(i, carry):
                    bigint_acc, adimax_acc = carry

                    def add_fn(carry):
                        bigint_acc, adimax_acc = carry
                        safe_pi = jnp.where(p_i[i] == 0, jnp.array(1.0, dtype=p_i.dtype), p_i[i])
                        add_on = (p_h[i] * p_h[i]) / safe_pi * iswst[i]
                        idx = m_cl[i] - 1
                        bigint_acc = bigint_acc.at[idx].add(add_on)
                        adimax_acc = jnp.where(state.ipa[i] == 1, p_i[i], adimax_acc)
                        return bigint_acc, adimax_acc

                    return jax.lax.cond(mask2[i], add_fn, lambda c: c, carry)

                bigint, adimax = jax.lax.fori_loop(0, p_i.shape[0], body, (bigint, adimax))

                iswst = jnp.where(mask2, 1, iswst)
                p_h = jnp.where(mask2, 0.0, p_h)
                p_i = jnp.where(mask2, 0.0, p_i)
                zero_int = jnp.zeros_like(state.isw)
                isw = jnp.where(mask2, zero_int, state.isw)
                icount = jnp.where(mask2, zero_int, state.icount)
                ipa = jnp.where(mask2, zero_int, state.ipa)
                state = RhsState(isw, ipa, icount, state.ipmax, state.pard0)

            y = y.at[NPQ : NPQ + npart].set(p_i)
            y = y.at[NPQ + npart : NPQ + 2 * npart].set(p_h)

            aditot = jnp.where(state.ipmax == 1, aditot + adimax, aditot)
            ipmax = jnp.where(
                state.ipmax == 1, jnp.array(0, dtype=state.ipmax.dtype), state.ipmax
            )
            state = RhsState(state.isw, state.ipa, state.icount, ipmax, state.pard0)
            return (phi, y, state, iswst, bigint, adimax, aditot)

        return jax.lax.fori_loop(0, params.nstep_per, inner_step, carry)

    def convergence_row(n_val, bigint_val, y_val, aditot_val):
        epspar_check = coeps * bigint_val * y_val[1] / (y_val[2] * y_val[2])
        epstot_check = jnp.sum(epspar_check)
        return jnp.asarray(
            [
                n_val.astype(y_val.dtype),
                epstot_check,
                y_val[3],
                y_val[NPQ + npart - 1] / y_val[1],
                aditot_val / y_val[1],
            ]
        )

    def emit_convergence(n_val, bigint_val, y_val, aditot_val):
        if convergence_callback is not None:
            row = convergence_row(n_val, bigint_val, y_val, aditot_val)
            jax.debug.callback(convergence_callback, row, ordered=True)
        if convergence_period_callback is not None:
            epspar_check = coeps * bigint_val * y_val[1] / (y_val[2] * y_val[2])
            epstot_check = jnp.sum(epspar_check)
            period_row = jnp.asarray(
                [
                    n_val.astype(y_val.dtype),
                    epstot_check,
                    y_val[3],
                    y_val[NPQ + npart - 1] / y_val[1],
                    y_val[1],
                ]
            )
            jax.debug.callback(convergence_period_callback, period_row, ordered=True)

    def update_theta_min(n_val, theta, theta_d_min, n_iota, m_iota, iota_bar_fp):
        twopi = 2.0 * jnp.pi

        def body(_):
            theta_rs = theta - theta0
            iota_bar_fp_new = jnp.where(n_val == 1, theta_rs / twopi, iota_bar_fp)
            m = jnp.floor(theta_rs / twopi)
            theta_rs_mod = theta_rs - m * twopi
            theta_d = jnp.where(theta_rs_mod <= jnp.pi, theta_rs_mod, theta_rs_mod - twopi)
            update = jnp.abs(theta_d) < jnp.abs(theta_d_min)
            theta_d_min_new = jnp.where(update, theta_d, theta_d_min)
            n_iota_new = jnp.where(update, n_val, n_iota)
            m_iota_new = jnp.where(
                update, jnp.where(theta_d >= 0, m.astype(jnp.int32), (m + 1).astype(jnp.int32)), m_iota
            )
            return theta_d_min_new, n_iota_new, m_iota_new, iota_bar_fp_new

        return jax.lax.cond(
            n_val <= params.nstep_min, body, lambda _: (theta_d_min, n_iota, m_iota, iota_bar_fp), operand=None
        )

    def update_rational(
        n_val,
        theta_d_min,
        n_iota,
        iota_bar_fp,
        nstep_max_c,
        hit_rat,
        exist_first_ratfl,
        nfp_rat,
        nfl_rat,
        delta_theta_rat,
        n_gap,
    ):
        twopi = 2.0 * jnp.pi

        def body(_):
            theta_d_min_safe = jnp.where(theta_d_min == 0.0, 1.0e-12, theta_d_min)
            theta_gap = twopi / n_iota
            n_gap_new = n_iota * jnp.floor(jnp.abs(theta_gap / theta_d_min_safe)).astype(jnp.int32)
            nstep_max_c_new = jnp.where(
                n_gap_new > params.nstep_min,
                n_gap_new,
                n_gap_new * jnp.ceil(params.nstep_min / n_gap_new).astype(jnp.int32),
            )

            hit_rat_new = jnp.where(nstep_max_c_new > params.nstep_max, 1, 0).astype(jnp.int32)
            nfp_rat_new = jnp.where(
                (hit_rat_new == 1) & (iota_bar_fp != 0.0),
                jnp.ceil(1.0 / params.acc_req / iota_bar_fp).astype(jnp.int32),
                0,
            )
            nfp_rat_new = jnp.where(
                (hit_rat_new == 1) & (nfp_rat_new % n_iota != 0),
                nfp_rat_new + n_iota - (nfp_rat_new % n_iota),
                nfp_rat_new,
            )

            exist_first_ratfl_new = jnp.where(nfp_rat_new >= params.nstep_min, 1, 0).astype(jnp.int32)
            nstep_max_c_new = jnp.where(exist_first_ratfl_new == 1, nfp_rat_new, nstep_max_c_new)

            nfl_rat_new = jnp.where(
                hit_rat_new == 1,
                jnp.ceil(params.no_bins / n_iota).astype(jnp.int32),
                nfl_rat,
            )
            delta_theta_rat_new = jnp.where(
                hit_rat_new == 1,
                theta_gap / (nfl_rat_new + 1),
                delta_theta_rat,
            )

            hit_rat_new = jnp.where(params.calc_nstep_max == 1, 0, hit_rat_new).astype(jnp.int32)

            return (
                nstep_max_c_new,
                hit_rat_new,
                exist_first_ratfl_new,
                nfp_rat_new,
                nfl_rat_new,
                delta_theta_rat_new,
                n_gap_new,
            )

        return jax.lax.cond(
            n_val == params.nstep_min,
            body,
            lambda _: (nstep_max_c, hit_rat, exist_first_ratfl, nfp_rat, nfl_rat, delta_theta_rat, n_gap),
            operand=None,
        )

    def scan_body(carry, _):
        (
            phi,
            y,
            state,
            iswst,
            bigint,
            adimax,
            aditot,
            n,
            theta_d_min,
            n_iota,
            m_iota,
            iota_bar_fp,
            nstep_max_c,
            hit_rat,
            exist_first_ratfl,
            nfp_rat,
            nfl_rat,
            delta_theta_rat,
            n_gap,
            stop,
        ) = carry

        def do_step(_carry):
            (
                phi,
                y,
                state,
                iswst,
                bigint,
                adimax,
                aditot,
                n,
                theta_d_min,
                n_iota,
                m_iota,
                iota_bar_fp,
                nstep_max_c,
                hit_rat,
                exist_first_ratfl,
                nfp_rat,
                nfl_rat,
                delta_theta_rat,
                n_gap,
                stop,
            ) = _carry

            phi, y, state, iswst, bigint, adimax, aditot = integrate_period(
                (phi, y, state, iswst, bigint, adimax, aditot),
                True,
                n,
            )
            n_new = n + 1
            emit_convergence(n_new.astype(y.dtype), bigint, y, aditot)
            theta = y[0]
            theta_d_min_new, n_iota_new, m_iota_new, iota_bar_fp_new = update_theta_min(
                n_new, theta, theta_d_min, n_iota, m_iota, iota_bar_fp
            )
            (
                nstep_max_c_new,
                hit_rat_new,
                exist_first_ratfl_new,
                nfp_rat_new,
                nfl_rat_new,
                delta_theta_rat_new,
                n_gap_new,
            ) = update_rational(
                n_new,
                theta_d_min_new,
                n_iota_new,
                iota_bar_fp_new,
                nstep_max_c,
                hit_rat,
                exist_first_ratfl,
                nfp_rat,
                nfl_rat,
                delta_theta_rat,
                n_gap,
            )

            stop_new = jnp.where(
                (params.calc_nstep_max == 0) & (n_new == nstep_max_c_new),
                True,
                False,
            )
            stop_new = jnp.where(
                (hit_rat_new == 1) & (exist_first_ratfl_new == 0), True, stop_new
            )
            return (
                phi,
                y,
                state,
                iswst,
                bigint,
                adimax,
                aditot,
                n_new,
                theta_d_min_new,
                n_iota_new,
                m_iota_new,
                iota_bar_fp_new,
                nstep_max_c_new,
                hit_rat_new,
                exist_first_ratfl_new,
                nfp_rat_new,
                nfl_rat_new,
                delta_theta_rat_new,
                n_gap_new,
                stop_new,
            )

        return jax.lax.cond(stop, lambda c: c, do_step, carry), None

    init_carry = (
        phi,
        y,
        state,
        iswst,
        bigint,
        adimax,
        aditot,
        n,
        theta_d_min,
        n_iota,
        m_iota,
        iota_bar_fp,
        nstep_max_c,
        hit_rat,
        exist_first_ratfl,
        nfp_rat,
        nfl_rat,
        delta_theta_rat,
        n_gap,
        stop,
    )

    final_carry, _ = jax.lax.scan(scan_body, init_carry, xs=None, length=params.nstep_max)
    (
        phi,
        y,
        state,
        iswst,
        bigint,
        adimax,
        aditot,
        n,
        theta_d_min,
        n_iota,
        m_iota,
        iota_bar_fp,
        nstep_max_c,
        hit_rat,
        exist_first_ratfl,
        nfp_rat,
        nfl_rat,
        delta_theta_rat,
        n_gap,
        stop,
    ) = final_carry

    y2 = y[1]
    y3 = y[2]
    y4 = y[3]
    y3npart = y[NPQ + npart - 1]

    def rational_correction(_):
        if convergence_reset_callback is not None:
            def _reset(_):
                jax.debug.callback(convergence_reset_callback, ordered=True)
                return None

            _ = jax.lax.cond(exist_first_ratfl == 0, _reset, lambda _: None, operand=None)

        def reset_accumulators(_):
            zero_bigint = jnp.zeros_like(bigint)
            return zero_bigint, jnp.array(0.0), jnp.array(0.0), jnp.array(0.0), jnp.array(0.0), jnp.array(0.0)

        def keep_accumulators(_):
            return bigint, aditot, y2, y3, y4, y3npart

        bigint0, aditot0, y20, y30, y40, y3npart0 = jax.lax.cond(
            exist_first_ratfl == 0, reset_accumulators, keep_accumulators, operand=None
        )

        nfl_count = jnp.maximum(nfl_rat + 1 - exist_first_ratfl, 0).astype(jnp.int32)

        def nfl_cond(carry):
            idx, *_ = carry
            return idx < nfl_count

        def nfl_body(carry):
            idx, bigint_acc, aditot_acc, y2_acc, y3_acc, y4_acc, y3npart_acc = carry
            nfl = idx + exist_first_ratfl

            phi_local = phi0
            y_local = jnp.zeros(ndim)
            y_local = y_local.at[0].set(theta0 + nfl * delta_theta_rat)

            state_local = RhsState(
                isw=jnp.zeros(npart, dtype=jnp.int32),
                ipa=jnp.zeros(npart, dtype=jnp.int32),
                icount=jnp.zeros(npart, dtype=jnp.int32),
                ipmax=jnp.array(0, dtype=jnp.int32),
                pard0=jnp.array(0.0),
            )

            iswst_local = jnp.zeros(npart, dtype=jnp.int32)
            bigint_s = jnp.zeros(multra)
            adimax_s = jnp.array(0.0)
            aditot_s = jnp.array(0.0)

            def n_cond(ncarry):
                n_idx, *_ = ncarry
                return n_idx < nfp_rat

            def n_body(ncarry):
                n_idx, phi_l, y_l, state_l, iswst_l, bigint_s, adimax_s, aditot_s = ncarry

                _bmod, _gval, _geodcu, pard0, _qval = neo_eval(
                    y_l[0],
                    phi_l,
                    env.splines["b_spl"],
                    env.splines["g_spl"],
                    env.splines["k_spl"],
                    env.splines["p_spl"],
                    env.splines.get("q_spl"),
                    env.grid,
                )
                state_l = RhsState(state_l.isw, state_l.ipa, state_l.icount, state_l.ipmax, pard0)

                phi_l, y_l, state_l, iswst_l, bigint_s, adimax_s, aditot_s = integrate_period(
                    (phi_l, y_l, state_l, iswst_l, bigint_s, adimax_s, aditot_s),
                    False,
                    n_idx,
                )
                emit_convergence((nfl * nfp_rat + n_idx + 1).astype(y_l.dtype), bigint_s, y_l, aditot_s)

                return (n_idx + 1, phi_l, y_l, state_l, iswst_l, bigint_s, adimax_s, aditot_s)

            n_init = (
                jnp.array(0, dtype=jnp.int32),
                phi_local,
                y_local,
                state_local,
                iswst_local,
                bigint_s,
                adimax_s,
                aditot_s,
            )
            n_final = jax.lax.while_loop(n_cond, n_body, n_init)
            _, phi_local, y_local, state_local, iswst_local, bigint_s, adimax_s, aditot_s = n_final

            y2_s = y_local[1]
            y3_s = y_local[2]
            y4_s = y_local[3]
            y3npart_s = y_local[NPQ + npart - 1]

            return (
                idx + 1,
                bigint_acc + bigint_s,
                aditot_acc + aditot_s,
                y2_acc + y2_s,
                y3_acc + y3_s,
                y4_acc + y4_s,
                y3npart_acc + y3npart_s,
            )

        nfl_init = (jnp.array(0, dtype=jnp.int32), bigint0, aditot0, y20, y30, y40, y3npart0)
        nfl_final = jax.lax.while_loop(nfl_cond, nfl_body, nfl_init)
        _, bigint_out, aditot_out, y2_out, y3_out, y4_out, y3npart_out = nfl_final

        return bigint_out, aditot_out, y2_out, y3_out, y4_out, y3npart_out

    def rational_skip(_):
        return bigint, aditot, y2, y3, y4, y3npart

    do_rational = (hit_rat == 1) & (nfp_rat > 0) & jnp.logical_not(jnp.asarray(skip_rational_correction))
    bigint, aditot, y2, y3, y4, y3npart = jax.lax.cond(
        do_rational, rational_correction, rational_skip, operand=None
    )

    nintfp = n

    epspar = jnp.zeros(multra)
    epstot = jnp.array(0.0)
    for m_cl in range(1, multra + 1):
        epspar = epspar.at[m_cl - 1].set(coeps * bigint[m_cl - 1] * y2 / (y3 * y3))
        epstot = epstot + epspar[m_cl - 1]

    ctrone = aditot / y2
    ctrtot = y3npart / y2
    bareph = (jnp.pi * ctrone) ** 2 / 8.0
    barept = (jnp.pi * ctrtot) ** 2 / 8.0
    drdpsi = y2 / y3
    yps = y4 * j_iota_i

    return {
        "epspar": epspar,
        "epstot": epstot,
        "ctrone": ctrone,
        "ctrtot": ctrtot,
        "bareph": bareph,
        "barept": barept,
        "drdpsi": drdpsi,
        "yps": yps,
        "y2": y2,
        "y3": y3,
        "y4": y4,
        "y3npart": y3npart,
        "bigint": bigint,
        "nintfp": nintfp,
        "hit_rat": hit_rat,
        "n_iota": n_iota,
        "m_iota": m_iota,
        "n_gap": n_gap,
        "final_n": jnp.where(hit_rat == 1, nfp_rat * (nfl_rat + 1), nintfp),
        "nfp_rat": nfp_rat,
        "nfl_rat": nfl_rat,
    }

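This backend keeps the dominant loops on the device: the field-line integration runs inside jax.lax.scan and the rational-surface correction inside jax.lax.while_loop. Keeping these loops on the device is what makes JIT reuse and batched surface evaluation possible.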

Public configuration model

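The public configuration model is intentionally compact. It is a single frozen dataclass, NeoConfig, which can be converted to and from the CLI-style ControlParams: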

@dataclass(frozen=True)
class NeoConfig:
    """Immutable, user-facing settings for a NEO_JAX run.

    ``surfaces`` accepts a mix of:

    - integers, read as 1-based NEO surface indices, and
    - floats in ``[0, 1]``, read as normalized toroidal flux ``s``.

    ``max_rational_field_periods`` guards against near-zero-iota surfaces;
    a value of ``0`` switches the guard off explicitly.

    ``rational_surface_policy`` decides what happens when the estimated
    rational-surface correction would exceed that limit:

    - ``"error"`` (default): stop immediately with a detailed diagnostic.
    - ``"approximate"``: skip the costly rational-surface correction and
      return the base integration result together with explicit diagnostics.

    The two rational-surface fields are not exchanged with ``ControlParams``:
    :meth:`from_control` leaves them at their defaults and :meth:`to_control`
    does not pass them on.
    """

    # NOTE: field order defines the positional constructor; keep it stable.
    surfaces: Sequence[int | float] | None = None
    theta_n: int = 64
    phi_n: int = 64
    max_m_mode: int = 0
    max_n_mode: int = 0
    npart: int = 40
    multra: int = 2
    acc_req: float = 0.02
    no_bins: int = 50
    nstep_per: int = 20
    nstep_min: int = 200
    nstep_max: int = 500
    calc_nstep_max: int = 0
    max_rational_field_periods: Optional[int] = 100_000
    rational_surface_policy: str = "error"
    ref_swi: int = 2
    write_progress: bool = False
    write_diagnostic: bool = False

    # Attribute names that are copied unchanged between NeoConfig and
    # ControlParams. Deliberately unannotated so the dataclass machinery does
    # not treat it as a field.
    _CONTROL_FIELDS = (
        "theta_n",
        "phi_n",
        "max_m_mode",
        "max_n_mode",
        "npart",
        "multra",
        "acc_req",
        "no_bins",
        "nstep_per",
        "nstep_min",
        "nstep_max",
        "calc_nstep_max",
        "ref_swi",
    )

    @classmethod
    def from_control(cls, control: ControlParams) -> "NeoConfig":
        """Create a config from a CLI-style ``ControlParams`` object.

        The shared numeric settings are copied unchanged, ``fluxs_arr``
        becomes ``surfaces``, and the two write switches are coerced to
        ``bool``.
        """
        copied = {name: getattr(control, name) for name in cls._CONTROL_FIELDS}
        return cls(
            surfaces=control.fluxs_arr,
            write_progress=bool(control.write_progress),
            write_diagnostic=bool(control.write_diagnostic),
            **copied,
        )

    def to_control(self, *, in_file: str = "boozmn", out_file: str = "neo_out") -> ControlParams:
        """Translate this config into a ``ControlParams`` (for CLI compatibility).

        ``surfaces`` is passed as a fresh list (or ``None``), the write
        switches are coerced to ``int``, and every ``ControlParams`` setting
        without a NeoConfig counterpart is pinned to the fixed value below.
        """
        copied = {name: getattr(self, name) for name in self._CONTROL_FIELDS}
        fluxs = None if self.surfaces is None else list(self.surfaces)
        return ControlParams(
            in_file=in_file,
            out_file=out_file,
            fluxs_arr=fluxs,
            write_progress=int(self.write_progress),
            write_diagnostic=int(self.write_diagnostic),
            eout_swi=1,
            lab_swi=0,
            inp_swi=0,
            write_output_files=0,
            spline_test=0,
            write_integrate=0,
            calc_cur=0,
            cur_file="neo_cur",
            npart_cur=0,
            alpha_cur=0.0,
            write_cur_inte=0,
            **copied,
        )

For user-level guidance on these fields, see Configuration and Runtime Controls.

Pipeline entrypoints

The reusable VMEC→Boozer→NEO path is built in neo_jax.pipeline:

def build_vmec_boozer_neo_jax(
    vmec_run: Any,
    *,
    booz_kwargs: dict | None = None,
    neo_config: NeoConfig | None = None,
    jit: bool = True,
):
    """Return a callable `solve(state)` for the JAX-native VMEC→Boozer→NEO path.

    Everything that depends only on the static VMEC setup is computed once
    here: trigonometric tables, Boozer transform constants, flux and profile
    data, the selected surfaces and the retained Boozer mode indices. The
    returned function is therefore suitable for repeated calls (and optional
    JIT).

    Parameters
    ----------
    vmec_run:
        VMEC run object; its ``state``, ``static``, ``indata`` and ``signgs``
        attributes are read.
    booz_kwargs:
        Optional Boozer resolution overrides. Only ``"mboz"`` and ``"nboz"``
        are read; a missing or falsy value falls back to ``max(xm) + 1`` and
        ``max(|xn|) // nfp`` respectively.
    neo_config:
        NEO settings; defaults to ``NeoConfig()``. Float entries of
        ``surfaces`` in ``[0, 1]`` select the half-grid surface whose ``s`` is
        nearest; any other entry is a 1-based half-grid surface index.
    jit:
        If true, wrap the returned callable in ``jax.jit``.

    Returns
    -------
    callable
        ``solve(state)`` returning the result of ``run_neo_from_boozer_jax``.

    Raises
    ------
    ImportError
        If ``booz_xform_jax`` or ``vmec_jax`` is not importable.
    ValueError
        If an integer surface index lies outside ``1..len(s_half)``.
    """
    booz_kwargs = booz_kwargs or {}
    cfg = neo_config or NeoConfig()

    try:
        import jax
        import jax.numpy as jnp
        from booz_xform_jax.jax_api import prepare_booz_xform_constants, booz_xform_jax_impl
    except ImportError as exc:  # pragma: no cover
        raise ImportError("booz_xform_jax is required for build_vmec_boozer_neo_jax") from exc

    try:
        from vmec_jax.booz_input import booz_xform_inputs_from_state
        from vmec_jax.energy import flux_profiles_from_indata
        from vmec_jax.profiles import eval_profiles
        from vmec_jax.vmec_tomnsp import vmec_trig_tables
    except ImportError as exc:  # pragma: no cover
        raise ImportError("vmec_jax with booz_input is required for build_vmec_boozer_neo_jax") from exc

    from .driver import run_neo_from_boozer_jax
    from .io import booz_xform_to_boozerdata_jax

    # Precompute static Boozer constants from the current state.
    inputs0 = booz_xform_inputs_from_state(
        state=vmec_run.state,
        static=vmec_run.static,
        indata=vmec_run.indata,
        signgs=int(vmec_run.signgs),
    )

    nyq_m = np.asarray(inputs0.xm_nyq)
    nyq_n = np.asarray(inputs0.xn_nyq)
    nfp_int = int(inputs0.nfp)
    mmax = int(np.max(nyq_m)) if nyq_m.size else 0
    nmax = int(np.max(np.abs(nyq_n))) // nfp_int if nyq_n.size else 0
    trig = vmec_trig_tables(
        ntheta=int(vmec_run.static.cfg.ntheta),
        nzeta=int(vmec_run.static.cfg.nzeta),
        nfp=nfp_int,
        mmax=mmax,
        nmax=nmax,
        lasym=bool(vmec_run.static.cfg.lasym),
        dtype=np.asarray(inputs0.rmnc).dtype,
        cache=True,
    )

    mboz_val = int(booz_kwargs.get("mboz") or (np.max(np.asarray(inputs0.xm)) + 1))
    nboz_val = int(
        booz_kwargs.get("nboz")
        or (np.max(np.abs(np.asarray(inputs0.xn))) // int(inputs0.nfp))
    )

    constants, grids = prepare_booz_xform_constants(
        nfp=int(inputs0.nfp),
        mboz=mboz_val,
        nboz=nboz_val,
        asym=bool(vmec_run.static.cfg.lasym),
        xm=np.asarray(inputs0.xm),
        xn=np.asarray(inputs0.xn),
        xm_nyq=np.asarray(inputs0.xm_nyq),
        xn_nyq=np.asarray(inputs0.xn_nyq),
    )

    ns_full = int(inputs0.rmnc.shape[0])
    s_half_full = jnp.asarray(0.5 * (vmec_run.static.s[:-1] + vmec_run.static.s[1:]))
    s_half_eval = jnp.concatenate([vmec_run.static.s[:1], s_half_full], axis=0)
    profiles_half = eval_profiles(vmec_run.indata, s_half_eval)
    flux = flux_profiles_from_indata(vmec_run.indata, vmec_run.static.s, signgs=int(vmec_run.signgs))
    if cfg.surfaces is None:
        surface_indices = None
        s_selected = s_half_full
    else:
        # Candidate surfaces are the half-grid values; search only over those.
        # They number one fewer than the full-grid s values, so a search over
        # range(ns_full) would index past the end of the half grid.
        s_half_np = np.asarray(s_half_full)
        n_half = int(s_half_np.shape[0])
        surface_indices_list = []
        for val in cfg.surfaces:
            # np.floating covers NumPy scalars such as np.float32, which are
            # not subclasses of the builtin float.
            if isinstance(val, (float, np.floating)) and 0.0 <= val <= 1.0:
                # argmin returns the first minimum, matching min() tie-breaking.
                idx = int(np.argmin(np.abs(s_half_np - val)))
            else:
                idx = int(val) - 1
                if not 0 <= idx < n_half:
                    raise ValueError(
                        f"surface index {val!r} is out of range; expected an integer in 1..{n_half}"
                    )
            surface_indices_list.append(idx)
        surface_indices = jnp.asarray(surface_indices_list, dtype=jnp.int32)
        s_selected = jnp.take(s_half_full, surface_indices, axis=0)

    control = cfg.to_control()
    # Modes already filtered in the JAX pipeline; skip re-masking in Fourier sums.
    control = replace(control, max_m_mode=-1, max_n_mode=-1)

    # Precompute mode indices for JIT-safe slicing.
    xm_b_np = np.asarray(grids.xm_b)
    xn_b_np = np.asarray(grids.xn_b)
    max_m = int(cfg.max_m_mode) if cfg.max_m_mode > 0 else int(np.max(np.abs(xm_b_np)))
    max_n = int(cfg.max_n_mode) if cfg.max_n_mode > 0 else int(np.max(np.abs(xn_b_np)))
    mode_indices = np.where((np.abs(xm_b_np) <= max_m) & (np.abs(xn_b_np) <= max_n))[0]

    def _solve(state):
        # Only the VMEC state varies per call; everything else is closed over.
        inputs = booz_xform_inputs_from_state(
            state=state,
            static=vmec_run.static,
            indata=vmec_run.indata,
            signgs=int(vmec_run.signgs),
            trig=trig,
            flux=flux,
            profiles_half=profiles_half,
        )
        booz_out = booz_xform_jax_impl(
            rmnc=inputs.rmnc,
            zmns=inputs.zmns,
            lmns=inputs.lmns,
            bmnc=inputs.bmnc,
            bsubumnc=inputs.bsubumnc,
            bsubvmnc=inputs.bsubvmnc,
            iota=inputs.iota,
            xm=inputs.xm,
            xn=inputs.xn,
            xm_nyq=inputs.xm_nyq,
            xn_nyq=inputs.xn_nyq,
            constants=constants,
            grids=grids,
            bmns=inputs.bmns,
            bsubumns=inputs.bsubumns,
            bsubvmns=inputs.bsubvmns,
            surface_indices=surface_indices,
        )
        booz_out["s_b"] = s_selected
        booz_out["ns_b"] = ns_full
        if surface_indices is not None:
            booz_out["jlist"] = surface_indices + 1
        booz = booz_xform_to_boozerdata_jax(
            booz_out,
            max_m_mode=cfg.max_m_mode,
            max_n_mode=cfg.max_n_mode,
            nfp_override=int(inputs0.nfp),
            mode_indices=mode_indices,
        )
        return run_neo_from_boozer_jax(booz, control, skip_fourier_mask=True)

    if jit:
        _solve = jax.jit(_solve)
    return _solve

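The returned callable is the preferred entry point for repeated JAX-native studies that reuse the same static geometry setup many times. The Boozer constants, trigonometric tables, surface selection and mode indices are computed once, and only the VMEC state changes between calls.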

Reference crosswalk

For readers comparing NEO_JAX to the established STELLOPT implementation, the existing routine-level crosswalk remains available in Reference Crosswalk.