Geometry Inputs and Pipelines

NEO operates in Boozer coordinates, where the magnetic field can be written in contravariant and covariant forms,

\[\mathbf{B} = \nabla \psi \times \nabla \theta_B + \iota \nabla \phi_B \times \nabla \psi = I(\psi) \nabla \phi_B + G(\psi) \nabla \theta_B + B_\psi \nabla \psi.\]

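Here \(\psi\) labels the flux surfaces, \(\theta_B\) and \(\phi_B\) are the Boozer poloidal and toroidal angles, \(\iota\) is the rotational transform, and \(I\) and \(G\) are the Boozer currents. This representation is the basis for the Boozer transform and the boozmn file format consumed by NEO_JAX. [2, 4]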

File-based and in-memory geometry paths into NEO_JAX

Required data for NEO

NEO expects the following Boozer-coordinate data, with Fourier coefficients and profile values given per flux surface:

  • rmnc: cosine coefficients of cylindrical radius.

  • zmns: sine coefficients of vertical coordinate.

  • lmns: sine coefficients of the Boozer toroidal angle shift.

  • bmnc: cosine coefficients of magnetic field magnitude.

  • ixm, ixn: poloidal and toroidal mode numbers.

  • iota: rotational transform profile.

  • curr_pol and curr_tor: Boozer currents \(I\) and \(G\), respectively.

  • nfp: number of field periods.
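As a rough illustration of how these arrays fit together, the sketch below sums a cosine series for the field magnitude on one surface. The phase convention cos(m θ − n φ), with ixn already containing the factor nfp, is an assumption borrowed from BOOZ_XFORM-style files and may not match the convention used inside the solver. The function name eval_bmod and the sample coefficients are hypothetical and not part of NEO_JAX.

```python
import numpy as np


def eval_bmod(bmnc, ixm, ixn, theta, phi):
    """Evaluate sum_k bmnc[k] * cos(ixm[k]*theta - ixn[k]*phi) on a grid.

    Assumed convention: ixn already includes the factor nfp.
    Returns an array of shape (len(theta), len(phi)).
    """
    bmnc = np.asarray(bmnc, dtype=float)
    ixm = np.asarray(ixm)
    ixn = np.asarray(ixn)
    theta = np.asarray(theta, dtype=float)
    phi = np.asarray(phi, dtype=float)
    # Broadcast to shape (n_theta, n_phi, n_modes).
    angle = ixm * theta[:, None, None] - ixn * phi[None, :, None]
    return np.sum(bmnc * np.cos(angle), axis=-1)


# Made-up single-surface data for a device with nfp = 5.
nfp = 5
ixm = np.array([0, 1, 0])
ixn = np.array([0, 0, nfp])
bmnc = np.array([1.0, 0.1, 0.05])

# One field period in phi, a full turn in theta.
theta = np.linspace(0.0, 2.0 * np.pi, 8, endpoint=False)
phi = np.linspace(0.0, 2.0 * np.pi / nfp, 8, endpoint=False)
bmod = eval_bmod(bmnc, ixm, ixn, theta, phi)
```

On a uniform grid covering one field period, the grid average of bmod recovers the (0, 0) coefficient, which is a quick sanity check on the mode numbering.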

Mapping from boozmn

The standard boozmn netCDF file (from BOOZ_XFORM) provides arrays such as rmnc_b, zmns_b, bmnc_b, pmns_b, ixm_b, ixn_b, iota_b, bvco_b, and buco_b. The Fourier arrays rmnc_b, zmns_b, and bmnc_b carry over unchanged, while pmns_b is rescaled to the lmns coefficients, written \(\lambda_{mn}\) here:

\[\lambda_{mn} = - \mathrm{pmns\_b}_{mn} \frac{n_{\mathrm{fp}}}{2\pi}.\]

The Boozer currents are mapped as curr_pol = bvco_b and curr_tor = buco_b. The file reader also computes the normalized toroidal flux coordinate

\[s = \frac{j - 1.5}{n_{s,b}-1},\]

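where \(j\) is the 1-based surface index taken from jlist and \(n_{s,b}\) is the number of radial surfaces (ns_b). This is the radial coordinate exposed in the public API, stored in the es field of BoozerData.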
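The two formulas above can be checked with a few lines of NumPy. This is a minimal sketch with made-up input values; map_boozmn_surface is a hypothetical helper name, not a NEO_JAX function.

```python
import numpy as np


def map_boozmn_surface(pmns_b, nfp, j, ns_b):
    """Apply the boozmn -> NEO mapping for one surface.

    pmns_b : angle-shift coefficients on the surface
    nfp    : number of field periods
    j      : 1-based surface index
    ns_b   : number of radial surfaces in the file
    """
    # lambda_mn = -pmns_b * nfp / (2*pi)
    lmns = -np.asarray(pmns_b, dtype=float) * nfp / (2.0 * np.pi)
    # s = (j - 1.5) / (ns_b - 1)
    s = (j - 1.5) / (ns_b - 1)
    return lmns, s


# Illustrative values only.
lmns, s = map_boozmn_surface([0.0, 2.0 * np.pi, -np.pi], nfp=4, j=2, ns_b=51)
_, s_last = map_boozmn_surface([0.0], nfp=4, j=51, ns_b=51)
```

With ns_b = 51, the first stored surface (j = 2) sits at s = 0.01 and the last (j = 51) at s = 0.99, so the coordinate never reaches the axis or the boundary.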

The main file-based reader is:

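# Imports assumed by this excerpt. The helpers (_require_netcdf4,
# resolve_boozmn_path, _transpose_if_needed, _select_modes) and BoozerData
# are defined elsewhere in the package and are not shown here.
import math
from pathlib import Path
from typing import Optional, Sequence

import netCDF4
import numpy as np

def read_boozmn(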
    path: str | Path,
    *,
    max_m_mode: int = 0,
    max_n_mode: int = 0,
    fluxs_arr: Optional[Sequence[int]] = None,
    extension: str | None = None,
) -> BoozerData:
    """Read a boozmn netCDF file and return BoozerData.

    This function follows the packing conventions used in STELLOPT's
    read_boozer_mod, but only retains the surfaces requested.
    """
    _require_netcdf4()
    booz_path = resolve_boozmn_path(path, extension)

    with netCDF4.Dataset(booz_path) as ds:  # type: ignore[union-attr]
        nfp = int(ds.variables["nfp_b"][:])
        ns_b = int(ds.variables["ns_b"][:])
        mboz_b = int(ds.variables["mboz_b"][:])
        nboz_b = int(ds.variables["nboz_b"][:])

        ixm_b = np.array(ds.variables["ixm_b"][:], dtype=int)
        ixn_b = np.array(ds.variables["ixn_b"][:], dtype=int)

        iota_b = np.array(ds.variables["iota_b"][:], dtype=float)
        buco_b = np.array(ds.variables["buco_b"][:], dtype=float)
        bvco_b = np.array(ds.variables["bvco_b"][:], dtype=float)
        pres_b = np.array(ds.variables["pres_b"][:], dtype=float) if "pres_b" in ds.variables else None

        rmnc_raw = np.array(ds.variables["rmnc_b"][:], dtype=float)
        zmns_raw = np.array(ds.variables["zmns_b"][:], dtype=float)
        pmns_raw = np.array(ds.variables["pmns_b"][:], dtype=float)
        bmnc_raw = np.array(ds.variables["bmnc_b"][:], dtype=float)
        gmn_raw = np.array(ds.variables["gmn_b"][:], dtype=float) if "gmn_b" in ds.variables else None

        if "jlist" in ds.variables:
            jlist = np.array(ds.variables["jlist"][:], dtype=int)
        else:
            jlist = np.arange(1, rmnc_raw.shape[0] + 1, dtype=int)

    pack_len = rmnc_raw.shape[0]
    rmnc_pack = _transpose_if_needed(rmnc_raw, pack_len)
    zmns_pack = _transpose_if_needed(zmns_raw, pack_len)
    pmns_pack = _transpose_if_needed(pmns_raw, pack_len)
    bmnc_pack = _transpose_if_needed(bmnc_raw, pack_len)
    gmn_pack = _transpose_if_needed(gmn_raw, pack_len) if gmn_raw is not None else None

    max_m = max_m_mode if max_m_mode > 0 else mboz_b - 1
    max_n = max_n_mode if max_n_mode > 0 else nboz_b * nfp
    mode_mask = _select_modes(ixm_b, ixn_b, max_m, max_n)

    ixm = ixm_b[mode_mask]
    ixn = ixn_b[mode_mask]
    mode0 = np.where((ixm_b == 0) & (ixn_b == 0))[0]
    mode0_idx = int(mode0[0]) if len(mode0) else None

    pack_index = {int(surf): idx for idx, surf in enumerate(jlist)}

    if fluxs_arr is not None:
        surfaces = list(fluxs_arr)
    else:
        surfaces = list(jlist)

    rmnc = []
    zmns = []
    lmns = []
    bmnc = []
    es = []
    iota = []
    curr_pol = []
    curr_tor = []
    pprime = []
    sqrtg00 = []

    hs = 1.0 / (ns_b - 1)  # uniform radial grid spacing in s
    for surf in surfaces:  # surf is a 1-based surface index, as stored in jlist
        pack_idx = pack_index.get(int(surf))
        if pack_idx is None:
            raise ValueError(f"Surface {surf} not found in boozmn jlist")

        rmnc.append(rmnc_pack[pack_idx, mode_mask])
        zmns.append(zmns_pack[pack_idx, mode_mask])
        lmns.append(-pmns_pack[pack_idx, mode_mask] * nfp / (2.0 * math.pi))
        bmnc.append(bmnc_pack[pack_idx, mode_mask])

        es.append((surf - 1.5) * hs)
        iota.append(iota_b[surf - 1])
        curr_pol.append(bvco_b[surf - 1])
        curr_tor.append(buco_b[surf - 1])
        if pres_b is not None and surf < ns_b:
            pprime.append((pres_b[surf] - pres_b[surf - 1]) / hs)
        else:
            pprime.append(0.0)
        if gmn_pack is not None and mode0_idx is not None:
            sqrtg00.append(float(gmn_pack[pack_idx, mode0_idx]))
        else:
            sqrtg00.append(0.0)

    return BoozerData(
        rmnc=np.asarray(rmnc),
        zmns=np.asarray(zmns),
        lmns=np.asarray(lmns),
        bmnc=np.asarray(bmnc),
        ixm=np.asarray(ixm),
        ixn=np.asarray(ixn),
        es=np.asarray(es),
        iota=np.asarray(iota),
        curr_pol=np.asarray(curr_pol),
        curr_tor=np.asarray(curr_tor),
        nfp=nfp,
        pprime=np.asarray(pprime),
        sqrtg00=np.asarray(sqrtg00),
    )
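The helper _select_modes is not shown in the listing. Given how max_m and max_n are derived (the default for max_n is nboz_b * nfp, so the comparison is presumably made on mode numbers that include nfp), a plausible reading is a boolean mask keeping modes with m ≤ max_m and |n| ≤ max_n. The sketch below implements that assumption; select_modes_sketch is a hypothetical name and the real helper may behave differently.

```python
import numpy as np


def select_modes_sketch(ixm, ixn, max_m, max_n):
    """Boolean mask of modes with m <= max_m and |n| <= max_n (assumed rule)."""
    ixm = np.asarray(ixm)
    ixn = np.asarray(ixn)
    return (ixm <= max_m) & (np.abs(ixn) <= max_n)


# Made-up mode table for nfp = 2; ixn_b already carries the nfp factor.
ixm_b = np.array([0, 0, 1, 1, 1, 2])
ixn_b = np.array([0, 2, -2, 0, 2, 4])
mask = select_modes_sketch(ixm_b, ixn_b, max_m=1, max_n=2)

# The same mask truncates the mode list and every (surface, mode) array.
ixm, ixn = ixm_b[mask], ixn_b[mask]
coeffs = np.arange(12.0).reshape(2, 6)  # stand-in for rmnc_pack and friends
truncated = coeffs[:, mask]
```

Applying one mask to the mode numbers and to all coefficient arrays keeps their columns aligned, which is what the indexing with mode_mask in read_boozmn relies on.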

End-to-end JAX pipeline

NEO_JAX is designed to consume the outputs of vmec_jax and booz_xform_jax directly, avoiding intermediate files and enabling end-to-end differentiation. [3, 4]

When using this pipeline, ensure the following:

  • The Boozer transform uses the same field-period convention as NEO.

  • Mode truncation matches max_m_mode and max_n_mode from the control file.

  • The current profiles supplied to NEO are consistent with the Boozer convention.
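The first two items can be checked cheaply before a solve. The sketch below is a hypothetical helper, not part of NEO_JAX, and it assumes that ixn carries the factor nfp, as in boozmn files. Consistency of the current profiles depends on sign and normalization conventions and is not checked here.

```python
import numpy as np


def check_boozer_modes(ixm, ixn, nfp, max_m, max_n):
    """Return a list of human-readable problems (empty if none are found)."""
    ixm = np.asarray(ixm)
    ixn = np.asarray(ixn)
    problems = []
    if ixm.shape != ixn.shape:
        # Nothing else can be compared mode by mode.
        return ["ixm and ixn have different lengths"]
    # Field-period convention: every toroidal mode number is a multiple of nfp.
    if np.any(ixn % nfp != 0):
        problems.append("some ixn values are not multiples of nfp")
    # Mode truncation: nothing outside the requested (max_m, max_n) window.
    if np.any(ixm < 0) or np.any(ixm > max_m):
        problems.append("some ixm values lie outside [0, max_m]")
    if np.any(np.abs(ixn) > max_n):
        problems.append("some |ixn| values exceed max_n")
    return problems


ok = check_boozer_modes([0, 1, 1], [0, -5, 5], nfp=5, max_m=1, max_n=5)
bad = check_boozer_modes([0, 1], [0, 3], nfp=5, max_m=1, max_n=5)
```

Returning a list of messages rather than raising on the first failure makes it easier to report all convention mismatches at once when wiring two codes together.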

NEO_JAX provides neo_jax.io.booz_xform_to_boozerdata() to convert arrays from a booz_xform-style object into the BoozerData container used by the solver. To run the solver directly on such an object with a high-level configuration, use neo_jax.run_neo() (or neo_jax.run_booz_xform()).

When the Boozer transform returns JAX arrays (from booz_xform_jax.jax_api), use neo_jax.io.booz_xform_to_boozerdata_jax() to preserve device arrays and keep the pipeline differentiable.

For pipeline workflows, neo_jax.run_boozer_to_neo() and neo_jax.run_vmec_boozer_neo() are convenience wrappers.

For a JAX-native VMEC state → Boozer adapter plus a JAX surface scan, use neo_jax.run_vmec_boozer_neo_jax(). This path avoids NumPy in the VMEC→Boozer interface and is suitable for autodiff experiments.

For repeated solves (e.g., optimization loops), build a reusable pipeline:

from neo_jax import NeoConfig, build_vmec_boozer_neo_jax

# `run` is assumed to be an existing vmec_jax run object; it is not created in this excerpt.
solver = build_vmec_boozer_neo_jax(
    run,
    booz_kwargs=dict(mboz=8, nboz=8),
    neo_config=NeoConfig(surfaces=[0.5]),
    jit=True,
)
outputs = solver(run.state)

Example: vmec_jax → booz_xform_jax → neo_jax

from neo_jax import NeoConfig, run_vmec_boozer_neo

config = NeoConfig(surfaces=[0.25, 0.5, 0.75], theta_n=32, phi_n=32)
results = run_vmec_boozer_neo(
    "path/to/input.vmec",
    vmec_kwargs=dict(max_iter=1, use_initial_guess=True, vmec_project=False),
    booz_kwargs=dict(mboz=8, nboz=8),
    neo_config=config,
)

Operational notes:

  • For accurate surface mapping, let the Boozer step compute all VMEC half-grid surfaces, then use NeoConfig.surfaces to select the subset that NEO solves on.

  • The in-memory path avoids file I/O entirely.

  • The JAX-native path is the preferred route for repeated solves and differentiation experiments.
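To make the first note concrete: when the Boozer step returns every half-grid surface, selecting a subset reduces to matching each requested s to an available surface. The sketch below assumes the half-grid formula s_j = (j − 1.5)/(ns − 1) used by the file reader, surface indices running from 2 to ns, and a nearest-neighbour rule. The actual selection logic behind NeoConfig.surfaces may differ, and nearest_half_grid_surfaces is a hypothetical name.

```python
import numpy as np


def nearest_half_grid_surfaces(s_requested, ns):
    """Map requested s values to the nearest half-grid surface.

    Returns (j, s_half) where j holds 1-based surface indices in 2..ns
    and s_half holds the corresponding half-grid coordinates.
    """
    j = np.arange(2, ns + 1)          # assumed range of stored surfaces
    s_half = (j - 1.5) / (ns - 1)     # half-grid coordinate of each surface
    s_requested = np.atleast_1d(np.asarray(s_requested, dtype=float))
    # Distance table of shape (n_requested, n_surfaces); pick the closest.
    idx = np.abs(s_half[None, :] - s_requested[:, None]).argmin(axis=1)
    return j[idx], s_half[idx]


# Illustrative request on a coarse grid with ns = 11.
j_sel, s_sel = nearest_half_grid_surfaces([0.27, 0.52, 0.77], ns=11)
```

Because the requested values rarely coincide with half-grid points, it is worth recording the s actually used (s_sel here) alongside the results rather than the requested value.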