Source Guide¶
This page is a code-oriented map of NEO_JAX for readers who want to connect the theory and numerics pages to the implementation.
Module structure¶
The repository is organized into a small number of solver-facing modules:
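| Module | Responsibility |
|---|---|
| *(module name lost in extraction)* | High-level public API. |
| *(module name lost in extraction)* | User-facing configuration model. |
| *(module name lost in extraction)* | *(description lost in extraction)* |
| *(module name lost in extraction)* | Fourier reconstruction and derived geometric quantities. |
| *(module name lost in extraction)* | Surface initialization, spline construction, and \(B_{\min}\)/\(B_{\max}\) refinement. |
| `neo_jax.geometry` | Spline evaluation and Newton-based extremum refinement. |
| `neo_jax.integrate` | Field-line RHS, RK4 stepping, trapped-particle bookkeeping, and the JAX scan backend. |
| `neo_jax.driver` | Surface loop orchestration, scaling, diagnostics, and result assembly. |
| `neo_jax.pipeline` | VMEC→Boozer→NEO helper workflows. |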
Geometry loading¶
The boozmn reader is where the external geometry enters the solver:
def read_boozmn(
    path: str | Path,
    *,
    max_m_mode: int = 0,
    max_n_mode: int = 0,
    fluxs_arr: Optional[Sequence[int]] = None,
    extension: str | None = None,
) -> BoozerData:
    """Read a boozmn netCDF file and return BoozerData.

    This function follows the packing conventions used in STELLOPT's
    read_boozer_mod, but only retains the surfaces requested.

    Args:
        path: Location of the boozmn file; resolved together with
            ``extension`` through ``resolve_boozmn_path``.
        max_m_mode: Largest poloidal mode number to keep. Values ``<= 0``
            fall back to ``mboz_b - 1`` from the file.
        max_n_mode: Largest toroidal mode number to keep. Values ``<= 0``
            fall back to ``nboz_b * nfp`` from the file.
        fluxs_arr: 1-based surface indices to extract. ``None`` selects every
            surface listed in the file's ``jlist``.
        extension: Forwarded to ``resolve_boozmn_path``.

    Returns:
        A ``BoozerData`` whose per-surface arrays have one row per requested
        surface, in the order requested.

    Raises:
        ValueError: If the file reports fewer than two radial points, if a
            requested surface is not an integer, or if it is absent from
            ``jlist``.
    """
    _require_netcdf4()
    booz_path = resolve_boozmn_path(path, extension)
    with netCDF4.Dataset(booz_path) as ds:  # type: ignore[union-attr]
        nfp = int(ds.variables["nfp_b"][:])
        ns_b = int(ds.variables["ns_b"][:])
        mboz_b = int(ds.variables["mboz_b"][:])
        nboz_b = int(ds.variables["nboz_b"][:])
        ixm_b = np.array(ds.variables["ixm_b"][:], dtype=int)
        ixn_b = np.array(ds.variables["ixn_b"][:], dtype=int)
        iota_b = np.array(ds.variables["iota_b"][:], dtype=float)
        buco_b = np.array(ds.variables["buco_b"][:], dtype=float)
        bvco_b = np.array(ds.variables["bvco_b"][:], dtype=float)
        # pres_b and gmn_b are optional in the file.
        pres_b = np.array(ds.variables["pres_b"][:], dtype=float) if "pres_b" in ds.variables else None
        rmnc_raw = np.array(ds.variables["rmnc_b"][:], dtype=float)
        zmns_raw = np.array(ds.variables["zmns_b"][:], dtype=float)
        pmns_raw = np.array(ds.variables["pmns_b"][:], dtype=float)
        bmnc_raw = np.array(ds.variables["bmnc_b"][:], dtype=float)
        gmn_raw = np.array(ds.variables["gmn_b"][:], dtype=float) if "gmn_b" in ds.variables else None
        if "jlist" in ds.variables:
            jlist = np.array(ds.variables["jlist"][:], dtype=int)
        else:
            # Without a jlist, assume the packed rows are surfaces 1..N in order.
            jlist = np.arange(1, rmnc_raw.shape[0] + 1, dtype=int)

    if ns_b < 2:
        # The radial spacing below is 1 / (ns_b - 1); fail with a clear message
        # instead of a bare ZeroDivisionError.
        raise ValueError(f"boozmn file reports ns_b={ns_b}; at least 2 radial points are required")

    # NOTE(review): pack_len is taken from rmnc_raw itself, so whether
    # _transpose_if_needed ever transposes depends on that helper — confirm.
    pack_len = rmnc_raw.shape[0]
    rmnc_pack = _transpose_if_needed(rmnc_raw, pack_len)
    zmns_pack = _transpose_if_needed(zmns_raw, pack_len)
    pmns_pack = _transpose_if_needed(pmns_raw, pack_len)
    bmnc_pack = _transpose_if_needed(bmnc_raw, pack_len)
    gmn_pack = _transpose_if_needed(gmn_raw, pack_len) if gmn_raw is not None else None

    max_m = max_m_mode if max_m_mode > 0 else mboz_b - 1
    max_n = max_n_mode if max_n_mode > 0 else nboz_b * nfp
    mode_mask = _select_modes(ixm_b, ixn_b, max_m, max_n)
    ixm = ixm_b[mode_mask]
    ixn = ixn_b[mode_mask]

    # Position of the (m, n) = (0, 0) mode in the unmasked mode list.
    mode0 = np.where((ixm_b == 0) & (ixn_b == 0))[0]
    mode0_idx = int(mode0[0]) if len(mode0) else None

    # Map a 1-based surface index to its row in the packed arrays.
    pack_index = {int(surf): idx for idx, surf in enumerate(jlist)}
    surfaces = list(fluxs_arr) if fluxs_arr is not None else list(jlist)

    rmnc = []
    zmns = []
    lmns = []
    bmnc = []
    es = []
    iota = []
    curr_pol = []
    curr_tor = []
    pprime = []
    sqrtg00 = []
    hs = 1.0 / (ns_b - 1)
    for surf in surfaces:
        # Normalise once so the jlist lookup and the profile indexing below
        # always use the same integer.
        surf_i = int(surf)
        if surf_i != surf:
            raise ValueError(f"Surface index must be an integer, got {surf!r}")
        pack_idx = pack_index.get(surf_i)
        if pack_idx is None:
            raise ValueError(f"Surface {surf} not found in boozmn jlist")
        rmnc.append(rmnc_pack[pack_idx, mode_mask])
        zmns.append(zmns_pack[pack_idx, mode_mask])
        # pmns is rescaled by -nfp / (2*pi) to obtain lmns.
        lmns.append(-pmns_pack[pack_idx, mode_mask] * nfp / (2.0 * math.pi))
        bmnc.append(bmnc_pack[pack_idx, mode_mask])
        es.append((surf_i - 1.5) * hs)
        # Profile arrays are indexed by the 1-based surface number.
        iota.append(iota_b[surf_i - 1])
        curr_pol.append(bvco_b[surf_i - 1])
        curr_tor.append(buco_b[surf_i - 1])
        if pres_b is not None and surf_i < ns_b:
            # Forward difference of the pressure profile over one radial step.
            pprime.append((pres_b[surf_i] - pres_b[surf_i - 1]) / hs)
        else:
            pprime.append(0.0)
        if gmn_pack is not None and mode0_idx is not None:
            sqrtg00.append(float(gmn_pack[pack_idx, mode0_idx]))
        else:
            sqrtg00.append(0.0)

    return BoozerData(
        rmnc=np.asarray(rmnc),
        zmns=np.asarray(zmns),
        lmns=np.asarray(lmns),
        bmnc=np.asarray(bmnc),
        ixm=np.asarray(ixm),
        ixn=np.asarray(ixn),
        es=np.asarray(es),
        iota=np.asarray(iota),
        curr_pol=np.asarray(curr_pol),
        curr_tor=np.asarray(curr_tor),
        nfp=nfp,
        pprime=np.asarray(pprime),
        sqrtg00=np.asarray(sqrtg00),
    )
This function:
- resolves the requested file path
- loads the Boozer Fourier coefficients and current profiles
- maps the `boozmn` convention to the internal `neo_jax.BoozerData` container
- computes the normalized toroidal-flux coordinate `s` used by the public API
Surface initialization¶
For each selected surface, NEO_JAX reconstructs the geometry, derives the metric quantities, builds the spline representation, and refines \(B_{\min}\) and \(B_{\max}\):
def init_surface(
    theta_arr: Array,
    phi_arr: Array,
    coeffs: Dict[str, Array],
    ixm: Array,
    ixn: Array,
    nfp: int,
    max_m_mode: int,
    max_n_mode: int,
    curr_pol: Array,
    curr_tor: Array,
    iota: Array,
    grid: Dict[str, float],
    calc_cur: bool = False,
    use_jax: bool = False,
    skip_mask: bool = False,
) -> SurfaceData:
    """Initialize a single flux surface: Fourier sums, derived fields, splines, B min/max.

    Args:
        theta_arr: Poloidal grid on which the fields are reconstructed.
        phi_arr: Toroidal grid on which the fields are reconstructed.
        coeffs: Fourier coefficients. ``"rmnc"``, ``"zmns"``, ``"lmns"`` and
            ``"bmnc"`` are required; ``"lasym"``, ``"rmns"``, ``"zmnc"``,
            ``"lmnc"`` and ``"bmns"`` are read if present.
        ixm: Poloidal mode numbers matching ``coeffs``.
        ixn: Toroidal mode numbers matching ``coeffs``.
        nfp: Forwarded to ``fourier_sums``.
        max_m_mode: Poloidal mode cut-off forwarded to the Fourier sums and
            the NumPy-path extremum search.
        max_n_mode: Toroidal counterpart of ``max_m_mode``.
        curr_pol: Forwarded to ``derived_quantities``.
        curr_tor: Forwarded to ``derived_quantities``.
        iota: Forwarded to ``derived_quantities``.
        grid: Spline grid description; ``"theta_int"``, ``"phi_int"``,
            ``"mt"`` and ``"mp"`` are read here and the whole dict is passed
            on to ``neo_zeros2d`` and ``neo_eval``.
        calc_cur: Forwarded to ``build_splines``.
        use_jax: Select the JAX extremum-index helper instead of the
            reference one.
        skip_mask: Forwarded to ``fourier_sums``.

    Returns:
        ``SurfaceData`` holding the refined ``b_min``/``b_max`` values and
        locations, ``bmref`` (equal to ``b_max``), the field dict and the
        spline dict.
    """
    fourier = fourier_sums(
        theta_arr,
        phi_arr,
        coeffs["rmnc"],
        coeffs["zmns"],
        coeffs["lmns"],
        coeffs["bmnc"],
        ixm,
        ixn,
        nfp=nfp,
        max_m_mode=max_m_mode,
        max_n_mode=max_n_mode,
        skip_mask=skip_mask,
        lasym=bool(coeffs.get("lasym", False)),
        rmns=coeffs.get("rmns"),
        zmnc=coeffs.get("zmnc"),
        lmnc=coeffs.get("lmnc"),
        bmns=coeffs.get("bmns"),
    )
    derived = derived_quantities(fourier, curr_pol=curr_pol, curr_tor=curr_tor, iota=iota)
    fields = {**fourier, **derived}
    splines = build_splines(fields, grid["theta_int"], grid["phi_int"], grid["mt"], grid["mp"], calc_cur)

    # Find initial min/max from grid, then refine with Newton.
    # Fortran's MINLOC/MAXLOC traverse arrays in column-major order
    # (theta index varies fastest). Use exact extrema and pick the
    # first index in Fortran order when there are ties.
    b = fields["b"]
    if use_jax:
        max_i, max_j = _select_extremum_index_jax(b, find_max=True)
        min_i, min_j = _select_extremum_index_jax(b, find_max=False)
    else:
        max_i, max_j = _select_extremum_index(
            b,
            theta_arr,
            phi_arr,
            ixm,
            ixn,
            coeffs["bmnc"],
            max_m_mode,
            max_n_mode,
            find_max=True,
        )
        min_i, min_j = _select_extremum_index(
            b,
            theta_arr,
            phi_arr,
            ixm,
            ixn,
            coeffs["bmnc"],
            max_m_mode,
            max_n_mode,
            find_max=False,
        )

    # Refine the grid extrema on the B spline. The literals 1.0e-10 and 100
    # are presumably the tolerance and iteration cap — confirm against
    # neo_zeros2d.
    theta_bmin, phi_bmin, _it_min, _err_min = neo_zeros2d(
        theta_arr[min_i], phi_arr[min_j], 1.0e-10, 100, splines["b_spl"], grid
    )
    theta_bmax, phi_bmax, _it_max, _err_max = neo_zeros2d(
        theta_arr[max_i], phi_arr[max_j], 1.0e-10, 100, splines["b_spl"], grid
    )

    # Evaluate B at refined points.
    from .geometry import neo_eval

    def _b_at(theta, phi):
        """Return the first output of neo_eval at (theta, phi)."""
        b_val, *_ = neo_eval(
            theta,
            phi,
            splines["b_spl"],
            splines["g_spl"],
            splines["k_spl"],
            splines["p_spl"],
            splines.get("q_spl"),
            grid,
        )
        return b_val

    b_min = _b_at(theta_bmin, phi_bmin)
    b_max = _b_at(theta_bmax, phi_bmax)

    return SurfaceData(
        b_min=b_min,
        b_max=b_max,
        theta_bmin=theta_bmin,
        phi_bmin=phi_bmin,
        theta_bmax=theta_bmax,
        phi_bmax=phi_bmax,
        # The reference field is the refined maximum.
        bmref=b_max,
        fields=fields,
        splines=splines,
    )
Field-line RHS¶
The core continuous model enters through neo_jax.integrate.rhs_bo1():
def rhs_bo1(phi: Array, y: Array, state: RhsState, env: RhsEnv) -> Tuple[Array, RhsState]:
    """Evaluate the field-line ODE right-hand side (port of rhs_bo1.f90).

    Returns the derivative vector (same shape and dtype as ``y``) together
    with the updated per-particle bookkeeping state.
    """
    spl = env.splines
    bmod, gval, geodcu, pardeb, _qval = neo_eval(
        y[0],
        phi,
        spl["b_spl"],
        spl["g_spl"],
        spl["k_spl"],
        spl["p_spl"],
        spl.get("q_spl"),
        env.grid,
    )

    int_dtype = state.isw.dtype
    inv_b2 = 1.0 / (bmod * bmod)
    inv_b3 = inv_b2 / bmod
    b_ratio = bmod / env.bmod0

    # Sign change (or zero) of pardeb relative to the previous evaluation.
    flipped = pardeb * state.pard0 <= 0
    ipass = jnp.where(flipped & (pardeb > 0), 1, 0).astype(int_dtype)
    ipmax = jnp.where(
        (state.ipmax == 0) & flipped & (pardeb < 0),
        1,
        state.ipmax,
    ).astype(int_dtype)

    # Per-eta integrands; non-positive eta entries are masked out so that no
    # division by zero or sqrt of a negative number is ever evaluated.
    eta = env.eta
    n_eta = eta.shape[0]
    eta_pos = eta > 0
    eta_safe = jnp.where(eta_pos, eta, jnp.asarray(1.0, dtype=eta.dtype))
    inv_eta = jnp.where(eta_pos, 1.0 / eta_safe, 0.0)
    inv_sqeta = jnp.where(eta_pos, 1.0 / jnp.sqrt(eta_safe), 0.0)
    subsq = 1.0 - b_ratio * inv_eta
    mask = (subsq > 0) & eta_pos
    sq = jnp.sqrt(jnp.where(mask, subsq, 0.0)) * inv_b2
    p_i = jnp.where(mask, sq, 0.0)
    p_h = jnp.where(mask, sq * (4.0 / b_ratio - inv_eta) * geodcu * inv_sqeta, 0.0)

    # Update particle state: 1 while inside the mask, 2 once a previously
    # active (1 or 2) particle leaves it, 0 otherwise.
    one_i = jnp.array(1, dtype=int_dtype)
    two_i = jnp.array(2, dtype=int_dtype)
    zero_i = jnp.array(0, dtype=int_dtype)
    previously_active = (state.isw == 1) | (state.isw == 2)
    isw = jnp.where(mask, one_i, jnp.where(previously_active, two_i, zero_i))
    icount = state.icount + mask.astype(state.icount.dtype)
    ipa = state.ipa + ipass * mask.astype(state.ipa.dtype)

    # Assemble the derivative vector: four scalar slots, then the two
    # per-eta blocks starting at NPQ.
    dery = jnp.zeros_like(y)
    dery = dery.at[0].set(env.iota)
    dery = dery.at[1].set(inv_b2)
    dery = dery.at[2].set(inv_b2 * gval)
    dery = dery.at[3].set(geodcu * inv_b3)
    dery = dery.at[NPQ : NPQ + n_eta].set(p_i)
    dery = dery.at[NPQ + n_eta : NPQ + 2 * n_eta].set(p_h)

    return dery, RhsState(isw=isw, ipa=ipa, icount=icount, ipmax=ipmax, pard0=pardeb)
This is where the state vector \((\theta, y_2, y_3, y_4, I_j, H_j)\) is advanced and where the trapped particle masks are updated.
JAX scan backend¶
The compiled backend lives in neo_jax.integrate.flint_bo_jax():
# flint_bo_jax: integrate the field-line ODE starting from the B_max point of
# `surface`, accumulating per-eta integrals over at most params.nstep_max
# scan iterations (each iteration runs params.nstep_per RK4 steps), and
# optionally applying a rational-surface correction pass afterwards.
#
# Arguments, as used in the body:
#   surface  - supplies b_min, b_max, bmref, theta_bmax and phi_bmax.
#   params   - supplies npart, multra, nstep_per, nstep_min, nstep_max,
#              acc_req, no_bins and calc_nstep_max.
#   env      - splines/grid/iota/curr_pol/curr_tor are reused; eta and bmod0
#              are replaced by values derived from `surface`.
#   nfp      - together with nstep_per sets the step hphi = 2*pi / (nstep_per * nfp).
#   rt0, Rmajor - two spellings of the same scalar; at least one is required
#              and they must be equal when both are given.
#   *_callback - optional host callbacks, invoked via jax.debug.callback.
#   diagnostic_snapshot - optional (step_index, j) pair selecting the single
#              inner step at which diagnostic_snapshot_callback fires.
#   strict_parity - use rk4_step / _process_trapped instead of the inlined
#              equivalents defined below.
#   skip_rational_correction - suppress the rational-surface correction pass.
# Returns a dict of results; see the final `return` for the keys.
# Raises ValueError if neither rt0 nor Rmajor is given, or if they differ.
def flint_bo_jax(
surface,
params: FlintParams,
env: RhsEnv,
nfp: int,
rt0: float | None = None,
*,
Rmajor: float | None = None,
diagnostic_callback=None,
diagnostic_trap_callback=None,
diagnostic_snapshot: tuple[int, int] | None = None,
diagnostic_snapshot_callback=None,
convergence_callback=None,
convergence_period_callback=None,
convergence_step_callback=None,
convergence_reset_callback=None,
strict_parity: bool = False,
skip_rational_correction: bool = False,
):
"""JAX-friendly integration loop with rational-surface correction."""
# Resolve the length scale: rt0 and Rmajor are aliases of one another.
if rt0 is None:
if Rmajor is None:
raise ValueError("Either rt0 or Rmajor must be provided.")
rt0 = float(Rmajor)
elif Rmajor is not None and float(Rmajor) != float(rt0):
raise ValueError("rt0 and Rmajor must match if both are provided.")
npart = params.npart
multra = params.multra
# State vector length: NPQ leading slots plus two blocks of npart entries.
ndim = NPQ + 2 * npart
# Particle grids
# eta: npart points with spacing heta, starting half a step above
# b_min / bmref.
etamin = surface.b_min / surface.bmref
etamax = surface.b_max / surface.bmref
heta = (etamax - etamin) / (npart - 1)
etamin = etamin + heta / 2.0
eta = etamin + heta * jnp.arange(npart)
# Rebuild env with the eta grid and bmod0 = surface.bmref; the remaining
# fields are carried over unchanged.
env = RhsEnv(
splines=env.splines,
grid=env.grid,
eta=eta,
bmod0=surface.bmref,
iota=env.iota,
curr_pol=env.curr_pol,
curr_tor=env.curr_tor,
)
coeps = jnp.pi * rt0 * rt0 * heta / (8.0 * jnp.sqrt(2.0))
# j_iota_i falls back to 0.0 when either current is unavailable.
if env.curr_pol is None or env.curr_tor is None:
j_iota_i = 0.0
else:
j_iota_i = env.curr_pol + env.iota * env.curr_tor
# Initial condition: theta = theta_bmax in slot 0, everything else zero.
y = jnp.zeros(ndim)
y = y.at[0].set(surface.theta_bmax)
phi0 = surface.phi_bmax
phi = phi0
state = RhsState(
isw=jnp.zeros(npart, dtype=jnp.int32),
ipa=jnp.zeros(npart, dtype=jnp.int32),
icount=jnp.zeros(npart, dtype=jnp.int32),
ipmax=jnp.array(0, dtype=jnp.int32),
pard0=jnp.array(0.0),
)
# Seed state.pard0 with the fourth neo_eval output at the starting point.
_bmod, _gval, _geodcu, pard0, _qval = neo_eval(
surface.theta_bmax,
surface.phi_bmax,
env.splines["b_spl"],
env.splines["g_spl"],
env.splines["k_spl"],
env.splines["p_spl"],
env.splines.get("q_spl"),
env.grid,
)
state = RhsState(state.isw, state.ipa, state.icount, state.ipmax, pard0)
iswst = jnp.zeros(npart, dtype=jnp.int32)
bigint = jnp.zeros(multra)
adimax = jnp.array(0.0)
aditot = jnp.array(0.0)
# nstep_per steps of size hphi advance phi by 2*pi / nfp.
hphi = 2.0 * jnp.pi / (params.nstep_per * nfp)
# Carry values consumed by update_theta_min / update_rational, plus the
# early-stop flag and the iteration counter n.
theta0 = surface.theta_bmax
theta_d_min = 2.0 * jnp.pi
n_iota = jnp.array(1, dtype=jnp.int32)
m_iota = jnp.array(1, dtype=jnp.int32)
iota_bar_fp = jnp.array(0.0)
nstep_max_c = jnp.array(params.nstep_max, dtype=jnp.int32)
hit_rat = jnp.array(0, dtype=jnp.int32)
exist_first_ratfl = jnp.array(0, dtype=jnp.int32)
nfp_rat = jnp.array(0, dtype=jnp.int32)
nfl_rat = jnp.array(0, dtype=jnp.int32)
delta_theta_rat = jnp.array(0.0)
n_gap = jnp.array(0, dtype=jnp.int32)
stop = jnp.array(False)
n = jnp.array(0, dtype=jnp.int32)
if diagnostic_snapshot is None:
snap_n = None
snap_j = None
else:
snap_n = jnp.array(diagnostic_snapshot[0], dtype=jnp.int32)
snap_j = jnp.array(diagnostic_snapshot[1], dtype=jnp.int32)
# integrate_period: advance `carry` by params.nstep_per RK4 steps.
# emit_diag is a plain Python bool gating the diagnostic callbacks;
# step_index is forwarded to the callbacks.
def integrate_period(carry, emit_diag: bool, step_index):
phi, y, state, iswst, bigint, adimax, aditot = carry
# Same computation as rhs_bo1, but closing over env instead of taking it
# as an argument.
def rhs_inline(phi_local: Array, y_local: Array, state_local: RhsState) -> Tuple[Array, RhsState]:
"""Inline RHS evaluation to help XLA fuse neo_eval + RK4."""
theta = y_local[0]
bmod, gval, geodcu, pardeb, _qval = neo_eval(
theta,
phi_local,
env.splines["b_spl"],
env.splines["g_spl"],
env.splines["k_spl"],
env.splines["p_spl"],
env.splines.get("q_spl"),
env.grid,
)
bmodm2 = 1.0 / (bmod * bmod)
bmodm3 = bmodm2 / bmod
bra = bmod / env.bmod0
ipass = jnp.where(
(pardeb * state_local.pard0 <= 0) & (pardeb > 0), 1, 0
).astype(state_local.isw.dtype)
ipmax = jnp.where(
(state_local.ipmax == 0) & (pardeb * state_local.pard0 <= 0) & (pardeb < 0),
1,
state_local.ipmax,
).astype(state_local.isw.dtype)
pard0 = pardeb
dery = jnp.zeros_like(y_local)
dery = dery.at[0].set(env.iota)
dery = dery.at[1].set(bmodm2)
dery = dery.at[2].set(bmodm2 * gval)
dery = dery.at[3].set(geodcu * bmodm3)
# Non-positive eta entries are masked so no division by zero is evaluated.
eta_pos = env.eta > 0
eta_safe = jnp.where(eta_pos, env.eta, jnp.asarray(1.0, dtype=env.eta.dtype))
inv_eta = jnp.where(eta_pos, 1.0 / eta_safe, 0.0)
sqeta = jnp.sqrt(eta_safe)
inv_sqeta = jnp.where(eta_pos, 1.0 / sqeta, 0.0)
subsq = 1.0 - bra * inv_eta
mask = (subsq > 0) & eta_pos
safe_subsq = jnp.where(mask, subsq, 0.0)
sq = jnp.sqrt(safe_subsq) * bmodm2
p_i = jnp.where(mask, sq, 0.0)
p_h = jnp.where(mask, sq * (4.0 / bra - inv_eta) * geodcu * inv_sqeta, 0.0)
one_i = jnp.array(1, dtype=state_local.isw.dtype)
two_i = jnp.array(2, dtype=state_local.isw.dtype)
zero_i = jnp.array(0, dtype=state_local.isw.dtype)
# isw: 1 inside the mask; 2 when a previously 1/2 entry leaves it; else 0.
isw = jnp.where(
mask,
one_i,
jnp.where(
state_local.isw == 1,
two_i,
jnp.where(state_local.isw == 2, two_i, zero_i),
),
)
icount = state_local.icount + mask.astype(state_local.icount.dtype)
ipa = state_local.ipa + ipass * mask.astype(state_local.ipa.dtype)
dery = dery.at[NPQ : NPQ + env.eta.shape[0]].set(p_i)
dery = dery.at[NPQ + env.eta.shape[0] : NPQ + 2 * env.eta.shape[0]].set(p_h)
new_state = RhsState(isw=isw, ipa=ipa, icount=icount, ipmax=ipmax, pard0=pard0)
return dery, new_state
# Four-stage RK4 with step hphi; the RhsState is threaded through all four
# RHS evaluations and the state after the last one is returned.
def rk4_step_inline(
phi_local: Array, y_local: Array, state_local: RhsState
) -> Tuple[Array, Array, RhsState]:
"""Inline RK4 to increase fusion with neo_eval."""
hh = hphi / 2.0
h6 = hphi / 6.0
k1, state1 = rhs_inline(phi_local, y_local, state_local)
y1 = y_local + hh * k1
k2, state2 = rhs_inline(phi_local + hh, y1, state1)
y2 = y_local + hh * k2
k3, state3 = rhs_inline(phi_local + hh, y2, state2)
y3 = y_local + hphi * k3
k4, state4 = rhs_inline(phi_local + hphi, y3, state3)
y_new = y_local + h6 * (k1 + k4 + 2.0 * (k2 + k3))
phi_new = phi_local + hphi
return phi_new, y_new, state4
# One RK4 step followed by the per-particle bookkeeping. j is the step
# index within the current period.
def inner_step(j, inner):
phi, y, state, iswst, bigint, adimax, aditot = inner
# strict_parity is a plain Python branch, so only one variant ends up in
# the loop body.
if strict_parity:
phi, y, state = rk4_step(phi, y, state, env, hphi)
else:
phi, y, state = rk4_step_inline(phi, y, state)
p_i = y[NPQ : NPQ + npart]
p_h = y[NPQ + npart : NPQ + 2 * npart]
# Diagnostic callbacks fire only when some particle has isw == 2 while
# its iswst flag is 1.
if diagnostic_callback is not None and emit_diag:
mask2 = state.isw == 2
event_mask = mask2 & (iswst == 1)
safe_pi = jnp.where(p_i == 0, jnp.array(1.0, dtype=p_i.dtype), p_i)
add_on = (p_h * p_h) / safe_pi * iswst
def _emit(_):
jax.debug.callback(diagnostic_callback, event_mask, state.icount, state.ipa, add_on, ordered=True)
return None
_ = jax.lax.cond(jnp.any(event_mask), _emit, lambda _: None, operand=None)
if diagnostic_trap_callback is not None and emit_diag:
mask2 = state.isw == 2
event_mask = mask2 & (iswst == 1)
def _emit_trap(_):
jax.debug.callback(
diagnostic_trap_callback,
event_mask,
state.isw,
iswst,
p_i,
p_h,
state.icount,
state.ipa,
phi,
j,
step_index,
ordered=True,
)
return None
_ = jax.lax.cond(jnp.any(event_mask), _emit_trap, lambda _: None, operand=None)
# Snapshot callback fires at exactly one (step_index, j) pair.
if diagnostic_snapshot_callback is not None and emit_diag and snap_n is not None and snap_j is not None:
def _emit_snap(_):
jax.debug.callback(
diagnostic_snapshot_callback,
state.isw,
iswst,
p_i,
p_h,
state.icount,
state.ipa,
phi,
j,
step_index,
ordered=True,
)
return None
snap_cond = (step_index == snap_n) & (j == snap_j)
_ = jax.lax.cond(snap_cond, _emit_snap, lambda _: None, operand=None)
# Unlike the diagnostics above, this callback is not gated by emit_diag and
# reports 1-based indices.
if convergence_step_callback is not None:
jax.debug.callback(
convergence_step_callback,
step_index + 1,
j + 1,
state.ipmax,
state.isw,
state.ipa,
p_i,
ordered=True,
)
if strict_parity:
state, iswst, p_i, p_h, bigint, adimax = _process_trapped(
state, iswst, p_i, p_h, bigint, adimax, multra
)
else:
# Inline trapped-particle update to improve fusion in scan body.
# For each particle with isw == 2: add p_h**2 / p_i * iswst into
# bigint[clip(ipa, 1, multra) - 1]; adimax takes p_i when ipa == 1.
mask2 = state.isw == 2
m_cl = jnp.clip(state.ipa, 1, multra).astype(state.ipa.dtype)
def body(i, carry):
bigint_acc, adimax_acc = carry
def add_fn(carry):
bigint_acc, adimax_acc = carry
safe_pi = jnp.where(p_i[i] == 0, jnp.array(1.0, dtype=p_i.dtype), p_i[i])
add_on = (p_h[i] * p_h[i]) / safe_pi * iswst[i]
idx = m_cl[i] - 1
bigint_acc = bigint_acc.at[idx].add(add_on)
adimax_acc = jnp.where(state.ipa[i] == 1, p_i[i], adimax_acc)
return bigint_acc, adimax_acc
return jax.lax.cond(mask2[i], add_fn, lambda c: c, carry)
bigint, adimax = jax.lax.fori_loop(0, p_i.shape[0], body, (bigint, adimax))
# Reset the flagged particles: set iswst, zero their integrals and counters.
iswst = jnp.where(mask2, 1, iswst)
p_h = jnp.where(mask2, 0.0, p_h)
p_i = jnp.where(mask2, 0.0, p_i)
zero_int = jnp.zeros_like(state.isw)
isw = jnp.where(mask2, zero_int, state.isw)
icount = jnp.where(mask2, zero_int, state.icount)
ipa = jnp.where(mask2, zero_int, state.ipa)
state = RhsState(isw, ipa, icount, state.ipmax, state.pard0)
# Write the (possibly reset) per-eta integrals back into the state vector.
y = y.at[NPQ : NPQ + npart].set(p_i)
y = y.at[NPQ + npart : NPQ + 2 * npart].set(p_h)
# When ipmax is set, fold adimax into aditot and clear the flag.
aditot = jnp.where(state.ipmax == 1, aditot + adimax, aditot)
ipmax = jnp.where(
state.ipmax == 1, jnp.array(0, dtype=state.ipmax.dtype), state.ipmax
)
state = RhsState(state.isw, state.ipa, state.icount, ipmax, state.pard0)
return (phi, y, state, iswst, bigint, adimax, aditot)
return jax.lax.fori_loop(0, params.nstep_per, inner_step, carry)
# Row layout: [n, sum(epspar), y[3], y[NPQ + npart - 1] / y[1], aditot / y[1]].
def convergence_row(n_val, bigint_val, y_val, aditot_val):
epspar_check = coeps * bigint_val * y_val[1] / (y_val[2] * y_val[2])
epstot_check = jnp.sum(epspar_check)
return jnp.asarray(
[
n_val.astype(y_val.dtype),
epstot_check,
y_val[3],
y_val[NPQ + npart - 1] / y_val[1],
aditot_val / y_val[1],
]
)
# Send convergence rows to whichever convergence callbacks were supplied.
# The period row differs from convergence_row only in its last entry
# (y[1] instead of aditot / y[1]).
def emit_convergence(n_val, bigint_val, y_val, aditot_val):
if convergence_callback is not None:
row = convergence_row(n_val, bigint_val, y_val, aditot_val)
jax.debug.callback(convergence_callback, row, ordered=True)
if convergence_period_callback is not None:
epspar_check = coeps * bigint_val * y_val[1] / (y_val[2] * y_val[2])
epstot_check = jnp.sum(epspar_check)
period_row = jnp.asarray(
[
n_val.astype(y_val.dtype),
epstot_check,
y_val[3],
y_val[NPQ + npart - 1] / y_val[1],
y_val[1],
]
)
jax.debug.callback(convergence_period_callback, period_row, ordered=True)
# While n_val <= params.nstep_min, track the smallest |theta - theta0|
# modulo 2*pi (wrapped into (-pi, pi]) together with the n_val and the
# winding count m at which it occurred. iota_bar_fp is recorded at n_val == 1.
def update_theta_min(n_val, theta, theta_d_min, n_iota, m_iota, iota_bar_fp):
twopi = 2.0 * jnp.pi
def body(_):
theta_rs = theta - theta0
iota_bar_fp_new = jnp.where(n_val == 1, theta_rs / twopi, iota_bar_fp)
m = jnp.floor(theta_rs / twopi)
theta_rs_mod = theta_rs - m * twopi
theta_d = jnp.where(theta_rs_mod <= jnp.pi, theta_rs_mod, theta_rs_mod - twopi)
update = jnp.abs(theta_d) < jnp.abs(theta_d_min)
theta_d_min_new = jnp.where(update, theta_d, theta_d_min)
n_iota_new = jnp.where(update, n_val, n_iota)
m_iota_new = jnp.where(
update, jnp.where(theta_d >= 0, m.astype(jnp.int32), (m + 1).astype(jnp.int32)), m_iota
)
return theta_d_min_new, n_iota_new, m_iota_new, iota_bar_fp_new
return jax.lax.cond(
n_val <= params.nstep_min, body, lambda _: (theta_d_min, n_iota, m_iota, iota_bar_fp), operand=None
)
# Evaluated only at n_val == params.nstep_min: derive the effective step
# limit from the closest-return data and decide whether the rational-surface
# path (hit_rat) is needed. All other iterations pass the inputs through.
def update_rational(
n_val,
theta_d_min,
n_iota,
iota_bar_fp,
nstep_max_c,
hit_rat,
exist_first_ratfl,
nfp_rat,
nfl_rat,
delta_theta_rat,
n_gap,
):
twopi = 2.0 * jnp.pi
def body(_):
# Guard the division below against an exact zero.
theta_d_min_safe = jnp.where(theta_d_min == 0.0, 1.0e-12, theta_d_min)
theta_gap = twopi / n_iota
n_gap_new = n_iota * jnp.floor(jnp.abs(theta_gap / theta_d_min_safe)).astype(jnp.int32)
# Use n_gap_new directly when it exceeds nstep_min, otherwise a multiple
# of it scaled by ceil(nstep_min / n_gap_new).
nstep_max_c_new = jnp.where(
n_gap_new > params.nstep_min,
n_gap_new,
n_gap_new * jnp.ceil(params.nstep_min / n_gap_new).astype(jnp.int32),
)
hit_rat_new = jnp.where(nstep_max_c_new > params.nstep_max, 1, 0).astype(jnp.int32)
nfp_rat_new = jnp.where(
(hit_rat_new == 1) & (iota_bar_fp != 0.0),
jnp.ceil(1.0 / params.acc_req / iota_bar_fp).astype(jnp.int32),
0,
)
# Round nfp_rat up to the next multiple of n_iota.
nfp_rat_new = jnp.where(
(hit_rat_new == 1) & (nfp_rat_new % n_iota != 0),
nfp_rat_new + n_iota - (nfp_rat_new % n_iota),
nfp_rat_new,
)
exist_first_ratfl_new = jnp.where(nfp_rat_new >= params.nstep_min, 1, 0).astype(jnp.int32)
nstep_max_c_new = jnp.where(exist_first_ratfl_new == 1, nfp_rat_new, nstep_max_c_new)
nfl_rat_new = jnp.where(
hit_rat_new == 1,
jnp.ceil(params.no_bins / n_iota).astype(jnp.int32),
nfl_rat,
)
delta_theta_rat_new = jnp.where(
hit_rat_new == 1,
theta_gap / (nfl_rat_new + 1),
delta_theta_rat,
)
# calc_nstep_max == 1 forces hit_rat back to 0.
hit_rat_new = jnp.where(params.calc_nstep_max == 1, 0, hit_rat_new).astype(jnp.int32)
return (
nstep_max_c_new,
hit_rat_new,
exist_first_ratfl_new,
nfp_rat_new,
nfl_rat_new,
delta_theta_rat_new,
n_gap_new,
)
return jax.lax.cond(
n_val == params.nstep_min,
body,
lambda _: (nstep_max_c, hit_rat, exist_first_ratfl, nfp_rat, nfl_rat, delta_theta_rat, n_gap),
operand=None,
)
# One scan iteration: run integrate_period once and update the bookkeeping,
# or pass the carry through unchanged once `stop` is set.
def scan_body(carry, _):
(
phi,
y,
state,
iswst,
bigint,
adimax,
aditot,
n,
theta_d_min,
n_iota,
m_iota,
iota_bar_fp,
nstep_max_c,
hit_rat,
exist_first_ratfl,
nfp_rat,
nfl_rat,
delta_theta_rat,
n_gap,
stop,
) = carry
def do_step(_carry):
(
phi,
y,
state,
iswst,
bigint,
adimax,
aditot,
n,
theta_d_min,
n_iota,
m_iota,
iota_bar_fp,
nstep_max_c,
hit_rat,
exist_first_ratfl,
nfp_rat,
nfl_rat,
delta_theta_rat,
n_gap,
stop,
) = _carry
phi, y, state, iswst, bigint, adimax, aditot = integrate_period(
(phi, y, state, iswst, bigint, adimax, aditot),
True,
n,
)
n_new = n + 1
emit_convergence(n_new.astype(y.dtype), bigint, y, aditot)
theta = y[0]
theta_d_min_new, n_iota_new, m_iota_new, iota_bar_fp_new = update_theta_min(
n_new, theta, theta_d_min, n_iota, m_iota, iota_bar_fp
)
(
nstep_max_c_new,
hit_rat_new,
exist_first_ratfl_new,
nfp_rat_new,
nfl_rat_new,
delta_theta_rat_new,
n_gap_new,
) = update_rational(
n_new,
theta_d_min_new,
n_iota_new,
iota_bar_fp_new,
nstep_max_c,
hit_rat,
exist_first_ratfl,
nfp_rat,
nfl_rat,
delta_theta_rat,
n_gap,
)
# Stop when the computed step limit is reached (only if calc_nstep_max == 0),
# or when hit_rat is set while exist_first_ratfl is 0.
stop_new = jnp.where(
(params.calc_nstep_max == 0) & (n_new == nstep_max_c_new),
True,
False,
)
stop_new = jnp.where(
(hit_rat_new == 1) & (exist_first_ratfl_new == 0), True, stop_new
)
return (
phi,
y,
state,
iswst,
bigint,
adimax,
aditot,
n_new,
theta_d_min_new,
n_iota_new,
m_iota_new,
iota_bar_fp_new,
nstep_max_c_new,
hit_rat_new,
exist_first_ratfl_new,
nfp_rat_new,
nfl_rat_new,
delta_theta_rat_new,
n_gap_new,
stop_new,
)
return jax.lax.cond(stop, lambda c: c, do_step, carry), None
init_carry = (
phi,
y,
state,
iswst,
bigint,
adimax,
aditot,
n,
theta_d_min,
n_iota,
m_iota,
iota_bar_fp,
nstep_max_c,
hit_rat,
exist_first_ratfl,
nfp_rat,
nfl_rat,
delta_theta_rat,
n_gap,
stop,
)
# Fixed-length scan over params.nstep_max iterations; iterations after
# `stop` is set leave the carry untouched.
final_carry, _ = jax.lax.scan(scan_body, init_carry, xs=None, length=params.nstep_max)
(
phi,
y,
state,
iswst,
bigint,
adimax,
aditot,
n,
theta_d_min,
n_iota,
m_iota,
iota_bar_fp,
nstep_max_c,
hit_rat,
exist_first_ratfl,
nfp_rat,
nfl_rat,
delta_theta_rat,
n_gap,
stop,
) = final_carry
y2 = y[1]
y3 = y[2]
y4 = y[3]
y3npart = y[NPQ + npart - 1]
# Rational-surface correction: integrate additional field lines started at
# theta0 + nfl * delta_theta_rat, each for nfp_rat periods, and sum their
# accumulators.
def rational_correction(_):
if convergence_reset_callback is not None:
def _reset(_):
jax.debug.callback(convergence_reset_callback, ordered=True)
return None
_ = jax.lax.cond(exist_first_ratfl == 0, _reset, lambda _: None, operand=None)
# With exist_first_ratfl == 0 the main-scan result is discarded and the
# sums start from zero; otherwise it is kept as the first contribution.
def reset_accumulators(_):
zero_bigint = jnp.zeros_like(bigint)
return zero_bigint, jnp.array(0.0), jnp.array(0.0), jnp.array(0.0), jnp.array(0.0), jnp.array(0.0)
def keep_accumulators(_):
return bigint, aditot, y2, y3, y4, y3npart
bigint0, aditot0, y20, y30, y40, y3npart0 = jax.lax.cond(
exist_first_ratfl == 0, reset_accumulators, keep_accumulators, operand=None
)
# Number of extra field lines: nfl_rat + 1, minus one if the main scan
# already counts as the first.
nfl_count = jnp.maximum(nfl_rat + 1 - exist_first_ratfl, 0).astype(jnp.int32)
def nfl_cond(carry):
idx, *_ = carry
return idx < nfl_count
def nfl_body(carry):
idx, bigint_acc, aditot_acc, y2_acc, y3_acc, y4_acc, y3npart_acc = carry
nfl = idx + exist_first_ratfl
# Fresh state for this field line, starting at phi0 with a shifted theta.
phi_local = phi0
y_local = jnp.zeros(ndim)
y_local = y_local.at[0].set(theta0 + nfl * delta_theta_rat)
state_local = RhsState(
isw=jnp.zeros(npart, dtype=jnp.int32),
ipa=jnp.zeros(npart, dtype=jnp.int32),
icount=jnp.zeros(npart, dtype=jnp.int32),
ipmax=jnp.array(0, dtype=jnp.int32),
pard0=jnp.array(0.0),
)
iswst_local = jnp.zeros(npart, dtype=jnp.int32)
bigint_s = jnp.zeros(multra)
adimax_s = jnp.array(0.0)
aditot_s = jnp.array(0.0)
def n_cond(ncarry):
n_idx, *_ = ncarry
return n_idx < nfp_rat
def n_body(ncarry):
n_idx, phi_l, y_l, state_l, iswst_l, bigint_s, adimax_s, aditot_s = ncarry
# pard0 is re-seeded from neo_eval at the start of every period here.
_bmod, _gval, _geodcu, pard0, _qval = neo_eval(
y_l[0],
phi_l,
env.splines["b_spl"],
env.splines["g_spl"],
env.splines["k_spl"],
env.splines["p_spl"],
env.splines.get("q_spl"),
env.grid,
)
state_l = RhsState(state_l.isw, state_l.ipa, state_l.icount, state_l.ipmax, pard0)
# Diagnostic callbacks are disabled (emit_diag=False) in this pass.
phi_l, y_l, state_l, iswst_l, bigint_s, adimax_s, aditot_s = integrate_period(
(phi_l, y_l, state_l, iswst_l, bigint_s, adimax_s, aditot_s),
False,
n_idx,
)
emit_convergence((nfl * nfp_rat + n_idx + 1).astype(y_l.dtype), bigint_s, y_l, aditot_s)
return (n_idx + 1, phi_l, y_l, state_l, iswst_l, bigint_s, adimax_s, aditot_s)
n_init = (
jnp.array(0, dtype=jnp.int32),
phi_local,
y_local,
state_local,
iswst_local,
bigint_s,
adimax_s,
aditot_s,
)
n_final = jax.lax.while_loop(n_cond, n_body, n_init)
_, phi_local, y_local, state_local, iswst_local, bigint_s, adimax_s, aditot_s = n_final
y2_s = y_local[1]
y3_s = y_local[2]
y4_s = y_local[3]
y3npart_s = y_local[NPQ + npart - 1]
return (
idx + 1,
bigint_acc + bigint_s,
aditot_acc + aditot_s,
y2_acc + y2_s,
y3_acc + y3_s,
y4_acc + y4_s,
y3npart_acc + y3npart_s,
)
nfl_init = (jnp.array(0, dtype=jnp.int32), bigint0, aditot0, y20, y30, y40, y3npart0)
nfl_final = jax.lax.while_loop(nfl_cond, nfl_body, nfl_init)
_, bigint_out, aditot_out, y2_out, y3_out, y4_out, y3npart_out = nfl_final
return bigint_out, aditot_out, y2_out, y3_out, y4_out, y3npart_out
def rational_skip(_):
return bigint, aditot, y2, y3, y4, y3npart
# The correction runs only when hit_rat is set, nfp_rat is positive and the
# caller did not ask to skip it.
do_rational = (hit_rat == 1) & (nfp_rat > 0) & jnp.logical_not(jnp.asarray(skip_rational_correction))
bigint, aditot, y2, y3, y4, y3npart = jax.lax.cond(
do_rational, rational_correction, rational_skip, operand=None
)
nintfp = n
# epspar[m] = coeps * bigint[m] * y2 / y3**2; epstot is the sum over m.
epspar = jnp.zeros(multra)
epstot = jnp.array(0.0)
for m_cl in range(1, multra + 1):
epspar = epspar.at[m_cl - 1].set(coeps * bigint[m_cl - 1] * y2 / (y3 * y3))
epstot = epstot + epspar[m_cl - 1]
ctrone = aditot / y2
ctrtot = y3npart / y2
bareph = (jnp.pi * ctrone) ** 2 / 8.0
barept = (jnp.pi * ctrtot) ** 2 / 8.0
drdpsi = y2 / y3
yps = y4 * j_iota_i
# final_n reports nfp_rat * (nfl_rat + 1) on the rational path, otherwise the
# number of scan iterations actually executed.
return {
"epspar": epspar,
"epstot": epstot,
"ctrone": ctrone,
"ctrtot": ctrtot,
"bareph": bareph,
"barept": barept,
"drdpsi": drdpsi,
"yps": yps,
"y2": y2,
"y3": y3,
"y4": y4,
"y3npart": y3npart,
"bigint": bigint,
"nintfp": nintfp,
"hit_rat": hit_rat,
"n_iota": n_iota,
"m_iota": m_iota,
"n_gap": n_gap,
"final_n": jnp.where(hit_rat == 1, nfp_rat * (nfl_rat + 1), nintfp),
"nfp_rat": nfp_rat,
"nfl_rat": nfl_rat,
}
This backend keeps the dominant loops on device, which is the basis for JIT reuse and batched surface evaluation.
Public configuration model¶
The public configuration surface is intentionally compact:
@dataclass(frozen=True)
class NeoConfig:
    """User-facing settings for a NEO_JAX run.

    ``surfaces`` accepts either of two forms:

    - integers, read as 1-based NEO surface indices, or
    - floats in [0, 1], read as normalized toroidal flux ``s``.

    ``max_rational_field_periods`` protects against near-zero-iota surfaces;
    setting it to ``0`` switches the guard off explicitly.

    ``rational_surface_policy`` selects the reaction when the estimated
    rational-surface correction exceeds that limit:

    - ``"error"``: fail fast with a detailed diagnostic (default)
    - ``"approximate"``: skip the expensive rational-surface correction and
      return the base integration result with explicit diagnostics
    """

    surfaces: Sequence[int | float] | None = None
    theta_n: int = 64
    phi_n: int = 64
    max_m_mode: int = 0
    max_n_mode: int = 0
    npart: int = 40
    multra: int = 2
    acc_req: float = 0.02
    no_bins: int = 50
    nstep_per: int = 20
    nstep_min: int = 200
    nstep_max: int = 500
    calc_nstep_max: int = 0
    max_rational_field_periods: Optional[int] = 100_000
    rational_surface_policy: str = "error"
    ref_swi: int = 2
    write_progress: bool = False
    write_diagnostic: bool = False

    # Names copied verbatim between NeoConfig and ControlParams. Deliberately
    # unannotated so the dataclass machinery does not treat it as a field.
    _mirrored = (
        "theta_n",
        "phi_n",
        "max_m_mode",
        "max_n_mode",
        "npart",
        "multra",
        "acc_req",
        "no_bins",
        "nstep_per",
        "nstep_min",
        "nstep_max",
        "calc_nstep_max",
        "ref_swi",
    )

    @classmethod
    def from_control(cls, control: ControlParams) -> "NeoConfig":
        """Build a NeoConfig from a ControlParams object.

        Fields that ControlParams does not supply here
        (``max_rational_field_periods``, ``rational_surface_policy``) keep
        their defaults.
        """
        copied = {name: getattr(control, name) for name in cls._mirrored}
        return cls(
            surfaces=control.fluxs_arr,
            write_progress=bool(control.write_progress),
            write_diagnostic=bool(control.write_diagnostic),
            **copied,
        )

    def to_control(self, *, in_file: str = "boozmn", out_file: str = "neo_out") -> ControlParams:
        """Convert to a ControlParams object (for CLI compatibility)."""
        # Switches that NeoConfig does not expose are pinned to fixed values.
        pinned = {
            "eout_swi": 1,
            "lab_swi": 0,
            "inp_swi": 0,
            "write_output_files": 0,
            "spline_test": 0,
            "write_integrate": 0,
            "calc_cur": 0,
            "cur_file": "neo_cur",
            "npart_cur": 0,
            "alpha_cur": 0.0,
            "write_cur_inte": 0,
        }
        copied = {name: getattr(self, name) for name in self._mirrored}
        surface_list = None if self.surfaces is None else list(self.surfaces)
        return ControlParams(
            in_file=in_file,
            out_file=out_file,
            fluxs_arr=surface_list,
            write_progress=int(self.write_progress),
            write_diagnostic=int(self.write_diagnostic),
            **copied,
            **pinned,
        )
For user-level guidance on these fields, see Configuration and Runtime Controls.
Pipeline entrypoints¶
The reusable VMEC→Boozer→NEO path is built in neo_jax.pipeline:
def build_vmec_boozer_neo_jax(
    vmec_run: Any,
    *,
    booz_kwargs: dict | None = None,
    neo_config: NeoConfig | None = None,
    jit: bool = True,
):
    """Return a callable `solve(state)` for the JAX-native VMEC→Boozer→NEO path.

    This precomputes Boozer constants and surface selections so the returned
    function is suitable for repeated calls (and optional JIT).

    Args:
        vmec_run: Object exposing ``state``, ``static``, ``indata`` and
            ``signgs``; only those attributes are read here.
        booz_kwargs: Optional overrides; the keys ``"mboz"`` and ``"nboz"``
            are honoured when present and truthy.
        neo_config: NEO settings; defaults to ``NeoConfig()``.
        jit: Wrap the returned callable in ``jax.jit`` when true.

    Returns:
        A function ``solve(state)`` that runs the Boozer transform on
        ``state`` and feeds the result to ``run_neo_from_boozer_jax``.

    Raises:
        ImportError: If ``jax``/``booz_xform_jax`` or the required
            ``vmec_jax`` modules cannot be imported.
    """
    booz_kwargs = booz_kwargs or {}
    cfg = neo_config or NeoConfig()
    try:
        import jax
        import jax.numpy as jnp
        from booz_xform_jax.jax_api import prepare_booz_xform_constants, booz_xform_jax_impl
    except ImportError as exc:  # pragma: no cover
        raise ImportError("booz_xform_jax is required for build_vmec_boozer_neo_jax") from exc
    try:
        from vmec_jax.booz_input import booz_xform_inputs_from_state
        from vmec_jax.energy import flux_profiles_from_indata
        from vmec_jax.profiles import eval_profiles
        from vmec_jax.vmec_tomnsp import vmec_trig_tables
    except ImportError as exc:  # pragma: no cover
        raise ImportError("vmec_jax with booz_input is required for build_vmec_boozer_neo_jax") from exc
    from .driver import run_neo_from_boozer_jax
    from .io import booz_xform_to_boozerdata_jax

    # Precompute static Boozer constants from the current state.
    inputs0 = booz_xform_inputs_from_state(
        state=vmec_run.state,
        static=vmec_run.static,
        indata=vmec_run.indata,
        signgs=int(vmec_run.signgs),
    )
    nyq_m = np.asarray(inputs0.xm_nyq)
    nyq_n = np.asarray(inputs0.xn_nyq)
    nfp_int = int(inputs0.nfp)
    mmax = int(np.max(nyq_m)) if nyq_m.size else 0
    nmax = int(np.max(np.abs(nyq_n))) // nfp_int if nyq_n.size else 0
    trig = vmec_trig_tables(
        ntheta=int(vmec_run.static.cfg.ntheta),
        nzeta=int(vmec_run.static.cfg.nzeta),
        nfp=nfp_int,
        mmax=mmax,
        nmax=nmax,
        lasym=bool(vmec_run.static.cfg.lasym),
        dtype=np.asarray(inputs0.rmnc).dtype,
        cache=True,
    )
    # Boozer resolution: explicit override, else derived from the VMEC modes.
    mboz_val = int(booz_kwargs.get("mboz") or (np.max(np.asarray(inputs0.xm)) + 1))
    nboz_val = int(
        booz_kwargs.get("nboz")
        or (np.max(np.abs(np.asarray(inputs0.xn))) // nfp_int)
    )
    constants, grids = prepare_booz_xform_constants(
        nfp=nfp_int,
        mboz=mboz_val,
        nboz=nboz_val,
        asym=bool(vmec_run.static.cfg.lasym),
        xm=np.asarray(inputs0.xm),
        xn=np.asarray(inputs0.xn),
        xm_nyq=np.asarray(inputs0.xm_nyq),
        xn_nyq=np.asarray(inputs0.xn_nyq),
    )
    ns_full = int(inputs0.rmnc.shape[0])
    # Midpoints of consecutive entries of static.s (one fewer than static.s).
    s_half_full = jnp.asarray(0.5 * (vmec_run.static.s[:-1] + vmec_run.static.s[1:]))
    s_half_eval = jnp.concatenate([vmec_run.static.s[:1], s_half_full], axis=0)
    profiles_half = eval_profiles(vmec_run.indata, s_half_eval)
    flux = flux_profiles_from_indata(vmec_run.indata, vmec_run.static.s, signgs=int(vmec_run.signgs))

    if cfg.surfaces is None:
        surface_indices = None
        s_selected = s_half_full
    else:
        s_vals = list(np.asarray(s_half_full))

        def _nearest_half_index(target: float) -> int:
            """Index of the half-grid point closest to ``target`` (first on ties)."""
            # Search exactly the entries of s_vals. Iterating range(ns_full)
            # instead would index past the end of s_vals whenever ns_full
            # exceeds the half-grid length.
            return min(range(len(s_vals)), key=lambda i: abs(s_vals[i] - target))

        surface_indices_list = []
        for val in cfg.surfaces:
            if isinstance(val, float) and 0.0 <= val <= 1.0:
                # Float in [0, 1]: normalized flux, snapped to the nearest point.
                surface_indices_list.append(_nearest_half_index(val))
            else:
                # Anything else: 1-based surface index.
                surface_indices_list.append(int(val) - 1)
        surface_indices = jnp.asarray(surface_indices_list, dtype=jnp.int32)
        s_selected = jnp.take(s_half_full, surface_indices, axis=0)

    control = cfg.to_control()
    # Modes already filtered in the JAX pipeline; skip re-masking in Fourier sums.
    control = replace(control, max_m_mode=-1, max_n_mode=-1)

    # Precompute mode indices for JIT-safe slicing.
    xm_b_np = np.asarray(grids.xm_b)
    xn_b_np = np.asarray(grids.xn_b)
    max_m = int(cfg.max_m_mode) if cfg.max_m_mode > 0 else int(np.max(np.abs(xm_b_np)))
    max_n = int(cfg.max_n_mode) if cfg.max_n_mode > 0 else int(np.max(np.abs(xn_b_np)))
    mode_indices = np.where((np.abs(xm_b_np) <= max_m) & (np.abs(xn_b_np) <= max_n))[0]

    def _solve(state):
        """Run Boozer transform + NEO for one VMEC state using the cached setup."""
        inputs = booz_xform_inputs_from_state(
            state=state,
            static=vmec_run.static,
            indata=vmec_run.indata,
            signgs=int(vmec_run.signgs),
            trig=trig,
            flux=flux,
            profiles_half=profiles_half,
        )
        booz_out = booz_xform_jax_impl(
            rmnc=inputs.rmnc,
            zmns=inputs.zmns,
            lmns=inputs.lmns,
            bmnc=inputs.bmnc,
            bsubumnc=inputs.bsubumnc,
            bsubvmnc=inputs.bsubvmnc,
            iota=inputs.iota,
            xm=inputs.xm,
            xn=inputs.xn,
            xm_nyq=inputs.xm_nyq,
            xn_nyq=inputs.xn_nyq,
            constants=constants,
            grids=grids,
            bmns=inputs.bmns,
            bsubumns=inputs.bsubumns,
            bsubvmns=inputs.bsubvmns,
            surface_indices=surface_indices,
        )
        booz_out["s_b"] = s_selected
        booz_out["ns_b"] = ns_full
        if surface_indices is not None:
            # jlist is 1-based, surface_indices is 0-based.
            booz_out["jlist"] = surface_indices + 1
        booz = booz_xform_to_boozerdata_jax(
            booz_out,
            max_m_mode=cfg.max_m_mode,
            max_n_mode=cfg.max_n_mode,
            nfp_override=nfp_int,
            mode_indices=mode_indices,
        )
        return run_neo_from_boozer_jax(booz, control, skip_fourier_mask=True)

    if jit:
        _solve = jax.jit(_solve)
    return _solve
This callable is the preferred entrypoint for repeated JAX-native studies where the same static geometry setup is reused many times.
Reference crosswalk¶
For readers comparing NEO_JAX to the established STELLOPT implementation, the existing routine-level crosswalk remains available in Reference Crosswalk.