API

Driver

High-level driver for NEO_JAX using Boozer data.

neo_jax.driver.compute_reference_jax(booz: BoozerData)

JAX-friendly reference values.

neo_jax.driver.run_neo_from_boozer_jax(booz: BoozerData, control: ControlParams, *, skip_fourier_mask: bool = False, max_rational_field_periods: int | None = 100000, rational_surface_policy: str | None = None) → NeoOutputs

JAX surface scan over all requested surfaces (no Python loop).

High-level API

User-facing API helpers for NEO_JAX.

neo_jax.api.load_boozmn(boozmn_path: str | Path, *, max_m_mode: int = 0, max_n_mode: int = 0, surfaces: Sequence[int] | None = None) → BoozerData

Load a boozmn file into BoozerData for custom workflows.

neo_jax.api.run_booz_xform(booz: object, *, config: NeoConfig | None = None, surfaces: Sequence[int] | None = None, use_jax: bool = True, progress: bool | None = None, max_m_mode: int | None = None, max_n_mode: int | None = None, jax_surface_scan: bool = False) → NeoResults | NeoOutputs

Run NEO_JAX from a booz_xform_jax-style object or mapping.

neo_jax.api.run_boozer(booz: BoozerData, *, config: NeoConfig | None = None, surfaces: Sequence[int] | None = None, use_jax: bool = True, progress: bool | None = None, jax_surface_scan: bool = False) → NeoResults | NeoOutputs

Run NEO_JAX from a BoozerData object (e.g., booz_xform_jax output).

neo_jax.api.run_boozmn(boozmn_path: str | Path, *, config: NeoConfig | None = None, surfaces: Sequence[int] | None = None, use_jax: bool = True, progress: bool | None = None, jax_surface_scan: bool = False) → NeoResults | NeoOutputs

Run NEO_JAX from a boozmn file using a simplified configuration.

neo_jax.api.run_neo(source: BoozerData | str | Path | object, *, config: NeoConfig | None = None, surfaces: Sequence[int | float] | None = None, use_jax: bool = True, progress: bool | None = None, max_m_mode: int | None = None, max_n_mode: int | None = None, jax_surface_scan: bool = False) → NeoResults | NeoOutputs

Run NEO_JAX from a boozmn path, BoozerData, or booz_xform_jax-like object.

Configuration

User-friendly configuration for NEO_JAX runs.

class neo_jax.config.NeoConfig(surfaces: Sequence[int | float] | None = None, theta_n: int = 64, phi_n: int = 64, max_m_mode: int = 0, max_n_mode: int = 0, npart: int = 40, multra: int = 2, acc_req: float = 0.02, no_bins: int = 50, nstep_per: int = 20, nstep_min: int = 200, nstep_max: int = 500, calc_nstep_max: int = 0, max_rational_field_periods: int | None = 100000, rational_surface_policy: str = 'error', ref_swi: int = 2, write_progress: bool = False, write_diagnostic: bool = False)

High-level configuration for NEO_JAX runs.

Surface selections may be specified as:

  • Integers (1-based NEO surface indices), or

  • Floats in [0, 1], interpreted as normalized toroidal flux s.
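Because the two selection forms differ only by type, a mixed list can be resolved by checking each entry's type. The stand-alone sketch below illustrates this; it is not the NeoConfig implementation, and both the helper name resolve_surfaces and the rule that a float selects the surface with the nearest available s value are assumptions.

```python
import numpy as np

def resolve_surfaces(surfaces, s_available):
    """Map a mixed list of surface selections to 1-based surface indices.

    Hypothetical helper: integers are taken as 1-based NEO indices,
    floats in [0, 1] are matched to the nearest available s value (assumed rule).
    """
    s_available = np.asarray(s_available, dtype=float)
    indices = []
    for sel in surfaces:
        if isinstance(sel, bool):  # bool is a subclass of int; reject it explicitly
            raise TypeError("bool is not a valid surface selection")
        if isinstance(sel, (int, np.integer)):
            if not 1 <= sel <= s_available.size:
                raise ValueError(f"surface index {sel} out of range")
            indices.append(int(sel))
        else:
            s = float(sel)
            if not 0.0 <= s <= 1.0:
                raise ValueError(f"s={s} is outside [0, 1]")
            # nearest available surface, converted to a 1-based index
            indices.append(int(np.argmin(np.abs(s_available - s))) + 1)
    return indices

s_grid = np.linspace(0.05, 0.95, 10)  # example normalized-flux values
print(resolve_surfaces([1, 0.52, 10], s_grid))  # [1, 6, 10]
```

Under this rule, 1 and 1.0 select different surfaces: the integer is the first NEO surface, while the float means s = 1 and picks the outermost available surface.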

max_rational_field_periods is a safeguard for surfaces with near-zero iota, where the rational-surface correction can require an excessive number of field periods. Set it to 0 to disable the guard explicitly.

rational_surface_policy controls what happens if the estimated rational-surface correction exceeds that limit:

  • "error": fail fast with a detailed diagnostic (default)

  • "approximate": skip the expensive rational-surface correction and return the base integration result with explicit diagnostics
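The interplay between the limit and the policy can be summarized as a small decision function. This is a hypothetical sketch of the documented behaviour, not NEO_JAX code: how the correction size is estimated is not covered here, warnings.warn stands in for the diagnostics NEO_JAX returns, and treating None like 0 is an assumption.

```python
import warnings

def apply_rational_guard(estimated_periods, max_periods, policy="error"):
    """Return True if the rational-surface correction should run (sketch).

    max_periods == 0 disables the guard (documented); None is assumed to do the same.
    policy "error" raises, "approximate" skips the correction.
    """
    if not max_periods:  # 0 or None: guard disabled
        return True
    if estimated_periods <= max_periods:
        return True
    msg = (f"rational-surface correction needs ~{estimated_periods} field periods, "
           f"limit is {max_periods}")
    if policy == "error":
        raise RuntimeError(msg)
    if policy == "approximate":
        # stand-in for the explicit diagnostics attached to the result
        warnings.warn(msg + "; returning the base integration result")
        return False
    raise ValueError(f"unknown rational_surface_policy: {policy!r}")
```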

to_control(*, in_file: str = 'boozmn', out_file: str = 'neo_out') → ControlParams

Convert to a ControlParams object (for CLI compatibility).

Results

User-friendly result containers for NEO_JAX.

class neo_jax.results.NeoResults(results: Iterable[NeoSurfaceResult])

Container for multiple surface results with convenience accessors.

class neo_jax.results.NeoSurfaceResult(flux_index: int, s: float, r_eff: float, iota: float, b_ref: float, r_ref: float, epsilon_effective: float, epsilon_effective_by_class: ndarray, ctrone: float, ctrtot: float, bareph: float, barept: float, yps: float, diagnostics: Mapping[str, object])

Results for a single flux surface.

neo_jax.results.neo_outputs_to_results(outputs: NeoOutputs, *, flux_indices: Sequence[int] | None = None) → NeoResults

Convert JAX-friendly NeoOutputs into NeoResults.

Plotting

Plotting helpers for NEO_JAX outputs.

neo_jax.plotting.plot_epsilon_effective(results: NeoResults, *, ax=None, x: str = 's', label: str | None = None) → Tuple[object, object]

Plot epsilon effective against a radial coordinate.

Parameters:
  • x ({"s", "sqrt_s", "r_eff"}) – Radial coordinate on the x-axis. Default is s.

Returns:
  (fig, ax). matplotlib is imported lazily.
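The x options are simple transformations of per-surface data. The hypothetical stand-in below takes plain arrays instead of NeoResults and mirrors the documented behaviour (it returns (fig, ax) and imports matplotlib only inside the plotting function); it is a sketch, not the neo_jax implementation.

```python
import numpy as np

def radial_coordinate(s, r_eff, x="s"):
    """Return x-axis values for an epsilon-effective plot."""
    s = np.asarray(s, dtype=float)
    if x == "s":
        return s
    if x == "sqrt_s":
        return np.sqrt(s)
    if x == "r_eff":
        return np.asarray(r_eff, dtype=float)
    raise ValueError(f"x must be 's', 'sqrt_s' or 'r_eff', got {x!r}")

def plot_eps_eff_sketch(s, r_eff, eps_eff, *, ax=None, x="s", label=None):
    """Hypothetical plotting stand-in; returns (fig, ax)."""
    import matplotlib.pyplot as plt  # lazy import: only needed when plotting
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure
    ax.plot(radial_coordinate(s, r_eff, x), eps_eff, marker="o", label=label)
    ax.set_xlabel(x)
    ax.set_ylabel("epsilon_effective")
    if label is not None:
        ax.legend()
    return fig, ax
```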

I/O

I/O helpers for NEO_JAX.

neo_jax.io.booz_xform_to_boozerdata(booz: object, *, max_m_mode: int = 0, max_n_mode: int = 0, fluxs_arr: Sequence[int] | None = None, use_jax: bool | None = None) → BoozerData

Convert booz_xform-style arrays into BoozerData.

The input object can be a mapping or an object with attributes matching the boozmn variable names (e.g., rmnc_b, zmns_b, pmns_b).
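Supporting both input styles usually comes down to a single accessor. The helper below is a hypothetical sketch (its name and the choice to report any missing variable as KeyError are assumptions), not the neo_jax code.

```python
from collections.abc import Mapping
from types import SimpleNamespace

def get_booz_field(booz, name):
    """Fetch a boozmn-named variable from a mapping or an attribute-style object."""
    try:
        return booz[name] if isinstance(booz, Mapping) else getattr(booz, name)
    except (KeyError, AttributeError):
        # normalize both failure modes into one error type
        raise KeyError(f"booz_xform output is missing {name!r}") from None

mapping_input = {"rmnc_b": [1.0], "zmns_b": [0.0], "pmns_b": [0.0]}
object_input = SimpleNamespace(**mapping_input)
for src in (mapping_input, object_input):
    print([get_booz_field(src, key) for key in ("rmnc_b", "zmns_b", "pmns_b")])
```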

neo_jax.io.booz_xform_to_boozerdata_jax(booz: object, *, max_m_mode: int = 0, max_n_mode: int = 0, fluxs_arr: Sequence[int] | None = None, nfp_override: int | None = None, mode_indices: Sequence[int] | None = None) → BoozerData

JAX-friendly conversion from booz_xform outputs to BoozerData.

neo_jax.io.read_boozmn(path: str | Path, *, max_m_mode: int = 0, max_n_mode: int = 0, fluxs_arr: Sequence[int] | None = None, extension: str | None = None) → BoozerData

Read a boozmn netCDF file and return BoozerData.

This function follows the packing conventions used in STELLOPT’s read_boozer_mod but retains only the requested surfaces.

neo_jax.io.read_boozmn_metadata(path: str | Path) → dict

Read minimal metadata (ns_b, jlist) from a boozmn file.

neo_jax.io.resolve_boozmn_path(base: str | Path, extension: str | None = None) → Path

Resolve a boozmn file from a base name and optional extension.

neo_jax.io.resolve_control_path(extension: str | None = None) → Path

Resolve NEO control file paths following xneo conventions.

Control

Control file parsing for NEO_JAX.

class neo_jax.control.ControlParams(in_file: str, out_file: str, fluxs_arr: Optional[List[int]], theta_n: int, phi_n: int, max_m_mode: int, max_n_mode: int, npart: int, multra: int, acc_req: float, no_bins: int, nstep_per: int, nstep_min: int, nstep_max: int, calc_nstep_max: int, eout_swi: int, lab_swi: int, inp_swi: int, ref_swi: int, write_progress: int, write_output_files: int, spline_test: int, write_integrate: int, write_diagnostic: int, calc_cur: int, cur_file: str, npart_cur: int, alpha_cur: float, write_cur_inte: int)

Integration

Integration routines for NEO_JAX.

class neo_jax.integrate.FlintParams(npart: int, multra: int, nstep_per: int, nstep_min: int, nstep_max: int, acc_req: float, no_bins: int, calc_nstep_max: int)

class neo_jax.integrate.RhsEnv(splines: dict, grid: dict, eta: Array, bmod0: Array, iota: Array, curr_pol: Array | None = None, curr_tor: Array | None = None)

class neo_jax.integrate.RhsState(isw: Array, ipa: Array, icount: Array, ipmax: Array, pard0: Array)

neo_jax.integrate.flint_bo(surface, params: FlintParams, env: RhsEnv, nfp: int, rt0: float | None = None, *, Rmajor: float | None = None, diagnostic: bool = False, diagnostic_trap: bool = False, diagnostic_trap_path: str = 'diagnostic_first_trap.dat', diagnostic_snapshot: tuple[int, int] | None = None, diagnostic_snapshot_path: str = 'diagnostic_snapshot.dat', collect_convergence: bool = False, skip_rational_correction: bool = False)

Python-loop port of flint_bo.f90 (not yet JIT-optimized).

neo_jax.integrate.flint_bo_jax(surface, params: FlintParams, env: RhsEnv, nfp: int, rt0: float | None = None, *, Rmajor: float | None = None, diagnostic_callback=None, diagnostic_trap_callback=None, diagnostic_snapshot: tuple[int, int] | None = None, diagnostic_snapshot_callback=None, convergence_callback=None, convergence_period_callback=None, convergence_step_callback=None, convergence_reset_callback=None, strict_parity: bool = False, skip_rational_correction: bool = False)

JAX-friendly integration loop with rational-surface correction.

neo_jax.integrate.rhs_bo1(phi: Array, y: Array, state: RhsState, env: RhsEnv) → Tuple[Array, RhsState]

Right-hand side for the field-line ODE (port of rhs_bo1.f90).

neo_jax.integrate.rk4_step(phi: Array, y: Array, state: RhsState, env: RhsEnv, h: float) → Tuple[Array, Array, RhsState]

Run a single RK4 step, threading RHS state (port of rk4d_bo1.f90).
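Threading state matters because the right-hand side carries bookkeeping (in rhs_bo1, RhsState fields such as icount) that must be updated in order across the four RK4 stages; under JAX transformations that state has to be passed in and returned explicitly rather than mutated. The sketch below shows the pattern with a toy ODE and a dict standing in for RhsState. The toy RHS and the assumed (phi, y, state) return order are illustrative, not taken from the NEO_JAX source.

```python
import numpy as np

def rhs(phi, y, state):
    """Toy right-hand side dy/dphi = -y; the state counts RHS evaluations."""
    return -y, {"icount": state["icount"] + 1}

def rk4_step(phi, y, state, h):
    """Classical RK4 step that threads `state` through all four RHS calls."""
    k1, state = rhs(phi, y, state)
    k2, state = rhs(phi + 0.5 * h, y + 0.5 * h * k1, state)
    k3, state = rhs(phi + 0.5 * h, y + 0.5 * h * k2, state)
    k4, state = rhs(phi + h, y + h * k3, state)
    y_new = y + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
    return phi + h, y_new, state

phi, y, state = 0.0, np.array([1.0]), {"icount": 0}
for _ in range(100):
    phi, y, state = rk4_step(phi, y, state, 0.01)
print(phi, y[0], state["icount"])  # y[0] ≈ exp(-1), icount == 400
```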

Geometry

Geometry evaluation and root finding for NEO_JAX.

neo_jax.geometry.neo_bderiv(theta: Array, phi: Array, b_spl: Array, grid: dict) → Tuple[Array, Array, Array, Array, Array, Array]

Compute first and second derivatives of B using spline coefficients.

neo_jax.geometry.neo_eval(theta: Array, phi: Array, b_spl: Array, g_spl: Array, k_spl: Array, p_spl: Array, q_spl: Array | None, grid: dict) → Tuple[Array, Array, Array, Array, Array]

Evaluate spline fields at a point.

Returns (bval, gval, kval, pval, qval).

neo_jax.geometry.neo_zeros2d(theta: Array, phi: Array, eps: float, iter_ma: int, b_spl: Array, grid: dict) → Tuple[Array, Array, Array, Array]

Newton solver for finding extrema of B.

Returns (theta, phi, iter, error) where error=0 on success.
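At an extremum of B both first derivatives vanish, so Newton's method can be applied to ∇B = 0 using the Hessian (the kind of first- and second-derivative information neo_bderiv provides). The sketch below uses an analytic model field in place of splines. The function names, the model field, and the nonzero error code for non-convergence are assumptions; only the (theta, phi, iter, error) return layout follows the documentation above.

```python
import numpy as np

def newton_extremum(grad_hess, theta, phi, eps=1e-12, iter_max=50):
    """Newton iteration on grad B = 0.

    grad_hess(theta, phi) -> (g, H): g is the 2-vector of first derivatives,
    H the 2x2 Hessian. Returns (theta, phi, iterations, error), with
    error = 0 on convergence and 1 otherwise (the nonzero code is assumed).
    """
    for it in range(1, iter_max + 1):
        g, H = grad_hess(theta, phi)
        step = np.linalg.solve(H, g)
        theta, phi = theta - step[0], phi - step[1]
        if np.max(np.abs(step)) < eps:
            return theta, phi, it, 0
    return theta, phi, iter_max, 1

def model_b(theta, phi):
    """Model B = 1 + 0.1 cos(theta) + 0.05 cos(phi); its minimum is at (pi, pi)."""
    g = np.array([-0.1 * np.sin(theta), -0.05 * np.sin(phi)])
    H = np.array([[-0.1 * np.cos(theta), 0.0],
                  [0.0, -0.05 * np.cos(phi)]])
    return g, H

theta_min, phi_min, n_iter, err = newton_extremum(model_b, 3.0, 3.0)
```

Newton's method converges to whichever critical point is nearest the starting guess, so a minimum, maximum, or saddle still has to be told apart afterwards, for example from the signs of the Hessian eigenvalues.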

Fourier

Fourier summations and derived quantities for Boozer geometry.

neo_jax.fourier.derived_quantities(fourier: Dict[str, Array], curr_pol: Array, curr_tor: Array, iota: Array) → Dict[str, Array]

Compute derived quantities from Fourier sums (neo_fourier.f90).

neo_jax.fourier.fourier_sums(theta_arr: Array, phi_arr: Array, rmnc: Array, zmns: Array, lmns: Array, bmnc: Array, ixm: Array, ixn: Array, nfp: int, max_m_mode: int, max_n_mode: int, *, skip_mask: bool = False, lasym: bool = False, rmns: Array | None = None, zmnc: Array | None = None, lmnc: Array | None = None, bmns: Array | None = None) → Dict[str, Array]

Compute Fourier sums for a single flux surface.

This mirrors neo_fourier.f90 but uses direct trigonometric evaluation. Set NEO_JAX_FOURIER_MODE=streamed to reduce memory use by avoiding theta×phi×mode temporaries.
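The memory trade-off behind the streamed mode can be seen in two equivalent functions: the first broadcasts to a full theta × phi × mode array, the second accumulates one mode at a time so only theta × phi temporaries exist. The sketch assumes a Boozer cosine series R(θ, φ) = Σ rmnc cos(mθ − nφ) with n already including the field-period factor; the function names are hypothetical.

```python
import numpy as np

def fourier_cos_sum_full(theta, phi, coeffs, m, n):
    """Sum c_k cos(m_k*theta - n_k*phi) via a theta x phi x mode temporary."""
    arg = (m[None, None, :] * theta[:, None, None]
           - n[None, None, :] * phi[None, :, None])
    return np.sum(coeffs * np.cos(arg), axis=-1)

def fourier_cos_sum_streamed(theta, phi, coeffs, m, n):
    """Same sum accumulated mode by mode (only theta x phi temporaries)."""
    out = np.zeros((theta.size, phi.size))
    for c, mk, nk in zip(coeffs, m, n):
        out += c * np.cos(mk * theta[:, None] - nk * phi[None, :])
    return out

theta = np.linspace(0.0, 2.0 * np.pi, 16)
phi = np.linspace(0.0, 2.0 * np.pi / 5, 12)
coeffs = np.array([1.0, 0.1, 0.05])
m = np.array([0, 1, 1])
n = np.array([0, 0, 5])
full = fourier_cos_sum_full(theta, phi, coeffs, m, n)
streamed = fourier_cos_sum_streamed(theta, phi, coeffs, m, n)
```

In JAX, the per-mode loop would typically be written with jax.lax.fori_loop or jax.lax.scan rather than a Python loop.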

Splines

Cubic spline utilities ported from STELLOPT NEO.

neo_jax.splines.eva2d(spl: Array, ix: int, iy: int, dx: float, dy: float) → float

Evaluate 2D spline at given cell (port of eva2d.f90).

neo_jax.splines.eva2d_fd(spl: Array, ix: int, iy: int, dx: float, dy: float) → Array

Evaluate first derivatives of 2D spline (port of eva2d_fd.f90).

neo_jax.splines.eva2d_fd_jax(spl: Array, ix: Array, iy: Array, dx: Array, dy: Array) → Array

JAX-friendly first derivatives of spline.

neo_jax.splines.eva2d_jax(spl: Array, ix: Array, iy: Array, dx: Array, dy: Array) → Array

JAX-friendly spline evaluation.

neo_jax.splines.eva2d_sd(spl: Array, ix: int, iy: int, dx: float, dy: float) → Array

Evaluate second derivatives of 2D spline (port of eva2d_sd.f90).

neo_jax.splines.eva2d_sd_jax(spl: Array, ix: Array, iy: Array, dx: Array, dy: Array) → Array

JAX-friendly second derivatives of spline.
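spl2d returns coefficients of shape (4, 4, nx, ny), that is, 16 coefficients per cell. Assuming they are stored in a local power basis in the cell offsets dx and dy (an assumption; the storage order is not documented here), evaluating a value and its first derivatives reduces to small dot products, as in this sketch with hypothetical function names:

```python
import numpy as np

def eva2d_sketch(spl, ix, iy, dx, dy):
    """Value of a bicubic patch: sum_ij spl[i, j, ix, iy] * dx**i * dy**j."""
    c = spl[:, :, ix, iy]
    px = dx ** np.arange(4)  # [1, dx, dx^2, dx^3]
    py = dy ** np.arange(4)
    return px @ c @ py

def eva2d_fd_sketch(spl, ix, iy, dx, dy):
    """First derivatives (d/dx, d/dy) of the same patch."""
    c = spl[:, :, ix, iy]
    k = np.arange(4)
    px, py = dx ** k, dy ** k
    dpx = np.concatenate(([0.0], k[1:] * dx ** (k[1:] - 1)))  # [0, 1, 2dx, 3dx^2]
    dpy = np.concatenate(([0.0], k[1:] * dy ** (k[1:] - 1)))
    return np.array([dpx @ c @ py, px @ c @ dpy])

# one cell holding f = 1 + 2x + 3y + 4xy + x^3
spl = np.zeros((4, 4, 1, 1))
spl[0, 0, 0, 0], spl[1, 0, 0, 0], spl[0, 1, 0, 0] = 1.0, 2.0, 3.0
spl[1, 1, 0, 0], spl[3, 0, 0, 0] = 4.0, 1.0
print(eva2d_sketch(spl, 0, 0, 0.5, 0.25))     # 3.375
print(eva2d_fd_sketch(spl, 0, 0, 0.5, 0.25))  # [3.75 5.  ]
```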

neo_jax.splines.poi2d(hx: float, hy: float, mx: int, my: int, xmin: float, xmax: float, ymin: float, ymax: float, x: float, y: float)

Pointer calculation for spline evaluation (port of poi2d.f90).

Returns (ix, iy, dx, dy, ierr) where ix, iy are 0-based indices.

neo_jax.splines.poi2d_jax(hx: float, hy: float, mx: int, my: int, xmin: float, xmax: float, ymin: float, ymax: float, x: Array, y: Array)

JAX-friendly pointer calculation for spline evaluation.
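Pointer calculation maps a point to the cell that contains it and to the offsets within that cell, which the eva2d routines then consume. The roles of mx and my are not documented here, so this sketch uses its own explicit periodic_x/periodic_y flags instead. The function names, the wrap-around in periodic directions, and ierr = 1 for out-of-range points are assumptions.

```python
import math

def locate_cell(x, xmin, xmax, h, periodic):
    """1D lookup: (0-based cell index, offset from cell start, error flag)."""
    n_cells = int(round((xmax - xmin) / h))
    if periodic:
        x = xmin + (x - xmin) % (xmax - xmin)  # wrap into [xmin, xmax)
    elif not xmin <= x <= xmax:
        return 0, 0.0, 1
    i = min(int(math.floor((x - xmin) / h)), n_cells - 1)  # clamp at the upper edge
    return i, x - (xmin + i * h), 0

def poi2d_sketch(hx, hy, xmin, xmax, ymin, ymax, x, y,
                 periodic_x=True, periodic_y=True):
    """Return (ix, iy, dx, dy, ierr) with 0-based indices."""
    ix, dx, ex = locate_cell(x, xmin, xmax, hx, periodic_x)
    iy, dy, ey = locate_cell(y, ymin, ymax, hy, periodic_y)
    return ix, iy, dx, dy, int(ex or ey)

print(poi2d_sketch(1.0, 0.5, 0.0, 4.0, 0.0, 2.0, 5.25, 0.75))  # (1, 1, 0.25, 0.25, 0)
```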

neo_jax.splines.spfper(np1: int, dtype=jax.numpy.float64) → Tuple[Array, Array, Array]

Helper routine for periodic splines (port of spfper.f90).

neo_jax.splines.spl2d(f: Array, hx: float, hy: float, mx: int, my: int) → Array

Two-dimensional cubic spline coefficients.

Returns array of shape (4, 4, nx, ny).

neo_jax.splines.splper(y: Array, h: float) → Tuple[Array, Array, Array]

Periodic cubic spline coefficients.

Mirrors splper.f90.
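A periodic cubic spline can be built by solving a cyclic tridiagonal system for the second derivatives and converting them into per-interval coefficients. The sketch below returns three arrays (b, c, d), like splper, but its coefficient convention is an assumption, and it uses a dense np.linalg.solve for brevity where a production routine would use a cyclic tridiagonal solver.

```python
import numpy as np

def periodic_cubic_spline(y, h):
    """Coefficients (b, c, d) of a periodic cubic spline on a uniform grid.

    On interval i: s(x) = y[i] + b[i]*t + c[i]*t**2 + d[i]*t**3, t = x - x[i].
    y holds one period of samples without repeating the endpoint (sketch convention).
    """
    y = np.asarray(y, dtype=float)
    n = y.size
    eye = np.eye(n)
    # cyclic system for second derivatives M:
    # M[i-1] + 4 M[i] + M[i+1] = 6/h^2 (y[i-1] - 2 y[i] + y[i+1])
    a = 4.0 * eye + np.roll(eye, 1, axis=1) + np.roll(eye, -1, axis=1)
    rhs = 6.0 / h**2 * (np.roll(y, 1) - 2.0 * y + np.roll(y, -1))
    m = np.linalg.solve(a, rhs)
    m_next, y_next = np.roll(m, -1), np.roll(y, -1)
    b = (y_next - y) / h - h * (2.0 * m + m_next) / 6.0
    c = m / 2.0
    d = (m_next - m) / (6.0 * h)
    return b, c, d

n = 32
h = 2.0 * np.pi / n
x = h * np.arange(n)
y = np.sin(x)
b, c, d = periodic_cubic_spline(y, h)
t = h / 2.0  # midpoint of the first interval
approx = y[0] + b[0] * t + c[0] * t**2 + d[0] * t**3
```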

neo_jax.splines.splreg(y: Array, h: float) → Tuple[Array, Array, Array]

Regular (non-periodic) cubic spline coefficients.

Mirrors splreg.f90.

Grids

Grid preparation utilities.

neo_jax.grids.prepare_grids(theta_n: int, phi_n: int, nfp: int) → Dict[str, Array | float | int]

Prepare theta/phi grids and spacing.

Mirrors neo_prep.f90 grid construction.
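A minimal sketch of the grid construction, assuming (not confirmed above) that theta spans [0, 2π] and phi spans one field period [0, 2π/nfp], both with endpoints included. The function name and dictionary keys are hypothetical.

```python
import numpy as np

def prepare_grids_sketch(theta_n, phi_n, nfp):
    """Uniform theta/phi grids over one field period (assumed layout)."""
    theta_end = 2.0 * np.pi
    phi_end = 2.0 * np.pi / nfp
    theta = np.linspace(0.0, theta_end, theta_n)  # endpoint included
    phi = np.linspace(0.0, phi_end, phi_n)
    return {
        "theta_arr": theta,
        "phi_arr": phi,
        "theta_int": theta_end / (theta_n - 1),  # grid spacing
        "phi_int": phi_end / (phi_n - 1),
        "nfp": nfp,
    }

grids = prepare_grids_sketch(64, 64, 5)
```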

Surface

Surface initialization and spline construction.

class neo_jax.surface.SurfaceData(b_min: Array, b_max: Array, theta_bmin: Array, phi_bmin: Array, theta_bmax: Array, phi_bmax: Array, bmref: Array, fields: Dict[str, Array], splines: Dict[str, Array])

neo_jax.surface.init_surface(theta_arr: Array, phi_arr: Array, coeffs: Dict[str, Array], ixm: Array, ixn: Array, nfp: int, max_m_mode: int, max_n_mode: int, curr_pol: Array, curr_tor: Array, iota: Array, grid: Dict[str, float], calc_cur: bool = False, use_jax: bool = False, skip_mask: bool = False) → SurfaceData

Initialize a single flux surface: Fourier sums, derived fields, splines, B min/max.

Pipeline

Pipeline helpers for vmec_jax -> booz_xform_jax -> neo_jax.

neo_jax.pipeline.booz_xform_from_vmec_state_jax(*, vmec_run: Any, mboz: int | None = None, nboz: int | None = None, surfaces: Sequence[int | float] | None = None, jit: bool = True) → Mapping[str, Any]

JAX-native VMEC state -> Boozer transform using booz_xform_jax.

neo_jax.pipeline.booz_xform_from_vmec_wout(wout: Any, *, mboz: int | None = None, nboz: int | None = None, surfaces: Sequence[int | float] | None = None, flux: bool = False, jit: bool = True) → Mapping[str, Any]

Run booz_xform_jax on an in-memory VMEC wout object.

Parameters:
  • wout – VMEC wout-like object (for example vmec_jax.WoutData).

  • mboz – Boozer poloidal resolution. If None, defaults to the VMEC mpol value.

  • nboz – Boozer toroidal resolution. If None, defaults to the VMEC ntor value.

  • surfaces – Optional surface indices or s values in [0, 1]. If omitted, all VMEC half-grid surfaces are used.

  • flux – If True, attempt to load flux profile arrays from wout.

  • jit – If True, jit-compile the Boozer transform.
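For reference, VMEC's half grid consists of the midpoints between consecutive full-grid surfaces s_j = j/(ns − 1), j = 0, …, ns − 1, so a run with ns surfaces has ns − 1 half-grid values. The helper below is a hypothetical illustration, not part of neo_jax.pipeline.

```python
import numpy as np

def vmec_half_grid_s(ns):
    """Normalized toroidal flux on VMEC's half grid for ns full-grid surfaces."""
    s_full = np.linspace(0.0, 1.0, ns)
    return 0.5 * (s_full[1:] + s_full[:-1])

print(vmec_half_grid_s(5))  # [0.125 0.375 0.625 0.875]
```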

neo_jax.pipeline.build_vmec_boozer_neo_jax(vmec_run: Any, *, booz_kwargs: dict | None = None, neo_config: NeoConfig | None = None, jit: bool = True)

Return a callable solve(state) for the JAX-native VMEC→Boozer→NEO path.

This precomputes Boozer constants and surface selections so the returned function is suitable for repeated calls (and optional JIT).

neo_jax.pipeline.run_boozer_to_neo(booz_output: Mapping[str, Any], *, config: NeoConfig | None = None, use_jax: bool = True, progress: bool | None = None) → Any

Run NEO_JAX from a booz_xform_jax output mapping.

neo_jax.pipeline.run_vmec_boozer_neo(vmec_source: Any, *, booz_xform_fn: Callable[..., Mapping[str, Any]] | None = None, booz_kwargs: dict | None = None, vmec_kwargs: dict | None = None, neo_config: NeoConfig | None = None, use_jax: bool = True, progress: bool | None = None, fast_bcovar: bool = True) → Any

Run vmec_jax -> booz_xform_jax -> neo_jax in one workflow.

This requires a JAX-native booz_xform_fn (for example from booz_xform_jax.jax_api). vmec_source may be a vmec_jax.FixedBoundaryRun, a vmec_jax.WoutData object, or a path to a VMEC input file.

neo_jax.pipeline.run_vmec_boozer_neo_jax(vmec_run: Any, *, booz_kwargs: dict | None = None, neo_config: NeoConfig | None = None, jax_surface_scan: bool = True, progress: bool | None = None) → Any

JAX-native VMEC -> Boozer -> NEO pipeline using the JAX surface scan.