API
Driver
High-level driver for NEO_JAX using Boozer data.
High-level API
User-facing API helpers for NEO_JAX.
- neo_jax.api.load_boozmn(boozmn_path: str | Path, *, max_m_mode: int = 0, max_n_mode: int = 0, surfaces: Sequence[int] | None = None) -> BoozerData
Load a boozmn file into BoozerData for custom workflows.
- neo_jax.api.run_booz_xform(booz: object, *, config: NeoConfig | None = None, surfaces: Sequence[int] | None = None, use_jax: bool = True, progress: bool | None = None, max_m_mode: int | None = None, max_n_mode: int | None = None, jax_surface_scan: bool = False) -> NeoResults | NeoOutputs
Run NEO_JAX from a booz_xform_jax-style object or mapping.
- neo_jax.api.run_boozer(booz: BoozerData, *, config: NeoConfig | None = None, surfaces: Sequence[int] | None = None, use_jax: bool = True, progress: bool | None = None, jax_surface_scan: bool = False) -> NeoResults | NeoOutputs
Run NEO_JAX from a BoozerData object (e.g., booz_xform_jax output).
- neo_jax.api.run_boozmn(boozmn_path: str | Path, *, config: NeoConfig | None = None, surfaces: Sequence[int] | None = None, use_jax: bool = True, progress: bool | None = None, jax_surface_scan: bool = False) -> NeoResults | NeoOutputs
Run NEO_JAX from a boozmn file using a simplified configuration.
- neo_jax.api.run_neo(source: BoozerData | str | Path | object, *, config: NeoConfig | None = None, surfaces: Sequence[int | float] | None = None, use_jax: bool = True, progress: bool | None = None, max_m_mode: int | None = None, max_n_mode: int | None = None, jax_surface_scan: bool = False) -> NeoResults | NeoOutputs
Run NEO_JAX from a boozmn path, BoozerData, or booz_xform_jax-like object.
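The three lower-level runners above correspond to the three kinds of source that run_neo accepts. The sketch below illustrates that correspondence with a hypothetical pick_entry_point helper and a stand-in BoozerDataStub class. It is an assumption about how the dispatch could work, not neo_jax's actual implementation.

```python
from pathlib import Path
from typing import Any


class BoozerDataStub:
    """Hypothetical stand-in for neo_jax's BoozerData (not the real class)."""


def pick_entry_point(source: Any) -> str:
    """Return the name of the documented runner that matches `source`.

    Sketch only: run_neo's real dispatch logic is not shown on this page.
    """
    if isinstance(source, (str, Path)):
        # A filesystem path to a boozmn file.
        return "run_boozmn"
    if isinstance(source, BoozerDataStub):
        # Data that has already been loaded (e.g. via load_boozmn).
        return "run_boozer"
    # Anything else is treated as a booz_xform_jax-style object or mapping.
    return "run_booz_xform"


print(pick_entry_point("boozmn_example.nc"))  # run_boozmn (hypothetical filename)
print(pick_entry_point(BoozerDataStub()))     # run_boozer
print(pick_entry_point({"rmnc_b": []}))       # run_booz_xform
```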
Configuration
User-friendly configuration for NEO_JAX runs.
- class neo_jax.config.NeoConfig(surfaces: Sequence[int | float] | None = None, theta_n: int = 64, phi_n: int = 64, max_m_mode: int = 0, max_n_mode: int = 0, npart: int = 40, multra: int = 2, acc_req: float = 0.02, no_bins: int = 50, nstep_per: int = 20, nstep_min: int = 200, nstep_max: int = 500, calc_nstep_max: int = 0, max_rational_field_periods: int | None = 100000, rational_surface_policy: str = 'error', ref_swi: int = 2, write_progress: bool = False, write_diagnostic: bool = False)
High-level configuration for NEO_JAX runs.
Surface selections may be specified as:
- Integers (1-based NEO surface indices), or
- Floats in [0, 1], interpreted as normalized toroidal flux s.
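Because integers and floats carry different meanings under this rule, the literal type matters: 1 selects the first surface, while 1.0 selects the surface nearest s = 1. The following minimal sketch of that rule uses a hypothetical resolve_surfaces helper. It assumes the caller supplies the s value of every surface, and it assumes nearest-surface rounding for floats. neo_jax might instead interpolate or reject off-grid values.

```python
from typing import List, Sequence, Union


def resolve_surfaces(
    surfaces: Sequence[Union[int, float]], s_grid: Sequence[float]
) -> List[int]:
    """Map a mixed surface selection to 1-based surface indices.

    Hypothetical helper: assumes s_grid[k] is the normalized toroidal flux s
    of the surface with 1-based index k + 1.
    """
    n = len(s_grid)
    resolved: List[int] = []
    for sel in surfaces:
        if isinstance(sel, bool):
            # bool is a subclass of int in Python; reject it explicitly.
            raise TypeError("booleans are not valid surface selections")
        if isinstance(sel, int):
            # Integers are taken as 1-based NEO surface indices.
            if not 1 <= sel <= n:
                raise ValueError(f"surface index {sel} is outside 1..{n}")
            resolved.append(sel)
        elif isinstance(sel, float):
            # Floats are normalized toroidal flux; pick the nearest surface.
            if not 0.0 <= sel <= 1.0:
                raise ValueError(f"s = {sel} is outside [0, 1]")
            nearest = min(range(n), key=lambda k: abs(s_grid[k] - sel))
            resolved.append(nearest + 1)
        else:
            raise TypeError(f"unsupported selection type {type(sel).__name__}")
    return resolved


s_grid = [0.1, 0.3, 0.5, 0.7, 0.9]
print(resolve_surfaces([1, 1.0, 0.5], s_grid))  # [1, 5, 3]
```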
max_rational_field_periods is a safeguard for near-zero-iota surfaces. Set it to 0 to disable the guard explicitly. rational_surface_policy controls what happens if the estimated rational-surface correction exceeds that limit:
- "error": fail fast with a detailed diagnostic (default).
- "approximate": skip the expensive rational-surface correction and return the base integration result with explicit diagnostics.
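The decision logic described above can be sketched as a small pure function. The names apply_rational_guard and RationalSurfaceLimitError are hypothetical. The estimated_periods input stands in for whatever estimate neo_jax actually computes. The sketch accepts only integers for the limit because this page does not say what max_rational_field_periods=None means.

```python
class RationalSurfaceLimitError(RuntimeError):
    """Hypothetical error type used only in this sketch."""


def apply_rational_guard(
    estimated_periods: int, max_periods: int, policy: str = "error"
) -> bool:
    """Return True if the rational-surface correction should be skipped."""
    if policy not in ("error", "approximate"):
        raise ValueError(f"unknown rational_surface_policy {policy!r}")
    if max_periods == 0 or estimated_periods <= max_periods:
        # Guard disabled, or the estimate is within the limit:
        # run the full correction.
        return False
    if policy == "error":
        # Fail fast and report both numbers.
        raise RationalSurfaceLimitError(
            f"estimated correction needs {estimated_periods} field periods; "
            f"max_rational_field_periods is {max_periods}"
        )
    # "approximate": keep the base integration result, skip the correction.
    return True


print(apply_rational_guard(5_000, 100_000))                     # False
print(apply_rational_guard(5_000_000, 0))                       # False (guard off)
print(apply_rational_guard(5_000_000, 100_000, "approximate"))  # True
```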
- to_control(*, in_file: str = 'boozmn', out_file: str = 'neo_out') -> ControlParams
Convert to a ControlParams object (for CLI compatibility).
Results
User-friendly result containers for NEO_JAX.
- class neo_jax.results.NeoResults(results: Iterable[NeoSurfaceResult])
Container for multiple surface results with convenience accessors.
- class neo_jax.results.NeoSurfaceResult(flux_index: int, s: float, r_eff: float, iota: float, b_ref: float, r_ref: float, epsilon_effective: float, epsilon_effective_by_class: ndarray, ctrone: float, ctrtot: float, bareph: float, barept: float, yps: float, diagnostics: Mapping[str, object])
Results for a single flux surface.
Plotting
Plotting helpers for NEO_JAX outputs.
- neo_jax.plotting.plot_epsilon_effective(results: NeoResults, *, ax=None, x: str = 's', label: str | None = None) -> Tuple[object, object]
Plot epsilon effective vs a radial coordinate.
- Parameters:
x ({"s", "sqrt_s", "r_eff"}) – Radial coordinate on the x-axis. Default is s.
- Returns:
(fig, ax). matplotlib is imported lazily.
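Choosing x only changes which per-surface quantity goes on the horizontal axis. The following dependency-free sketch shows that selection. It uses a hypothetical SurfaceStub dataclass carrying the s, r_eff, and epsilon_effective fields listed for NeoSurfaceResult. radial_series is not part of neo_jax.

```python
import math
from dataclasses import dataclass
from typing import Iterable, List, Tuple


@dataclass
class SurfaceStub:
    """Stand-in holding three documented NeoSurfaceResult fields."""
    s: float
    r_eff: float
    epsilon_effective: float


def radial_series(
    results: Iterable[SurfaceStub], x: str = "s"
) -> Tuple[List[float], List[float]]:
    """Return (x_values, eps_eff_values) sorted by s, for the chosen x axis."""
    coordinate = {
        "s": lambda r: r.s,
        "sqrt_s": lambda r: math.sqrt(r.s),  # roughly proportional to minor radius
        "r_eff": lambda r: r.r_eff,
    }
    if x not in coordinate:
        raise ValueError(f"x must be one of {sorted(coordinate)}, got {x!r}")
    ordered = sorted(results, key=lambda r: r.s)
    return [coordinate[x](r) for r in ordered], [r.epsilon_effective for r in ordered]


surfaces = [SurfaceStub(0.25, 0.12, 1.5e-3), SurfaceStub(0.04, 0.05, 0.8e-3)]
xs, eps = radial_series(surfaces, x="sqrt_s")
print(xs, eps)
```

The resulting pair can then be drawn with matplotlib's ax.plot(xs, eps).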
I/O
I/O helpers for NEO_JAX.
- neo_jax.io.booz_xform_to_boozerdata(booz: object, *, max_m_mode: int = 0, max_n_mode: int = 0, fluxs_arr: Sequence[int] | None = None, use_jax: bool | None = None) -> BoozerData
Convert booz_xform-style arrays into BoozerData.
The input object can be a mapping or an object with attributes matching the boozmn variable names (e.g., rmnc_b, zmns_b, pmns_b).
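Accepting "a mapping or an object with attributes" usually comes down to a small duck-typed lookup. The following sketch shows one way to do it with a hypothetical get_booz_field helper. neo_jax's own lookup may differ, for example in which error it raises for a missing field.

```python
from collections.abc import Mapping
from types import SimpleNamespace
from typing import Any


def get_booz_field(booz: Any, name: str) -> Any:
    """Read a boozmn-named field from a mapping or an attribute-style object.

    Hypothetical helper; neo_jax's own lookup may differ.
    """
    try:
        if isinstance(booz, Mapping):
            return booz[name]
        return getattr(booz, name)
    except (KeyError, AttributeError) as exc:
        # Normalize both failure modes into one error type.
        raise KeyError(f"boozmn field {name!r} not found") from exc


as_dict = {"rmnc_b": [1.0, 0.1]}
as_obj = SimpleNamespace(rmnc_b=[1.0, 0.1])
print(get_booz_field(as_dict, "rmnc_b") == get_booz_field(as_obj, "rmnc_b"))  # True
```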
- neo_jax.io.booz_xform_to_boozerdata_jax(booz: object, *, max_m_mode: int = 0, max_n_mode: int = 0, fluxs_arr: Sequence[int] | None = None, nfp_override: int | None = None, mode_indices: Sequence[int] | None = None) -> BoozerData
JAX-friendly conversion from booz_xform outputs to BoozerData.
- neo_jax.io.read_boozmn(path: str | Path, *, max_m_mode: int = 0, max_n_mode: int = 0, fluxs_arr: Sequence[int] | None = None, extension: str | None = None) -> BoozerData
Read a boozmn netCDF file and return BoozerData.
This function follows the packing conventions used in STELLOPT’s read_boozer_mod, but only retains the surfaces requested.
- neo_jax.io.read_boozmn_metadata(path: str | Path) -> dict
Read minimal metadata (ns_b, jlist) from a boozmn file.
Control
Control file parsing for NEO_JAX.
- class neo_jax.control.ControlParams(in_file: 'str', out_file: 'str', fluxs_arr: 'Optional[List[int]]', theta_n: 'int', phi_n: 'int', max_m_mode: 'int', max_n_mode: 'int', npart: 'int', multra: 'int', acc_req: 'float', no_bins: 'int', nstep_per: 'int', nstep_min: 'int', nstep_max: 'int', calc_nstep_max: 'int', eout_swi: 'int', lab_swi: 'int', inp_swi: 'int', ref_swi: 'int', write_progress: 'int', write_output_files: 'int', spline_test: 'int', write_integrate: 'int', write_diagnostic: 'int', calc_cur: 'int', cur_file: 'str', npart_cur: 'int', alpha_cur: 'float', write_cur_inte: 'int')
Integration
Integration routines for NEO_JAX.
- class neo_jax.integrate.FlintParams(npart: 'int', multra: 'int', nstep_per: 'int', nstep_min: 'int', nstep_max: 'int', acc_req: 'float', no_bins: 'int', calc_nstep_max: 'int')
- class neo_jax.integrate.RhsEnv(splines: 'dict', grid: 'dict', eta: 'Array', bmod0: 'Array', iota: 'Array', curr_pol: 'Array | None' = None, curr_tor: 'Array | None' = None)
- class neo_jax.integrate.RhsState(isw: 'Array', ipa: 'Array', icount: 'Array', ipmax: 'Array', pard0: 'Array')
- neo_jax.integrate.flint_bo(surface, params: FlintParams, env: RhsEnv, nfp: int, rt0: float | None = None, *, Rmajor: float | None = None, diagnostic: bool = False, diagnostic_trap: bool = False, diagnostic_trap_path: str = 'diagnostic_first_trap.dat', diagnostic_snapshot: tuple[int, int] | None = None, diagnostic_snapshot_path: str = 'diagnostic_snapshot.dat', collect_convergence: bool = False, skip_rational_correction: bool = False)
Python-loop port of flint_bo.f90 (not yet JIT-optimized).
- neo_jax.integrate.flint_bo_jax(surface, params: FlintParams, env: RhsEnv, nfp: int, rt0: float | None = None, *, Rmajor: float | None = None, diagnostic_callback=None, diagnostic_trap_callback=None, diagnostic_snapshot: tuple[int, int] | None = None, diagnostic_snapshot_callback=None, convergence_callback=None, convergence_period_callback=None, convergence_step_callback=None, convergence_reset_callback=None, strict_parity: bool = False, skip_rational_correction: bool = False)
JAX-friendly integration loop with rational-surface correction.
Geometry
Geometry evaluation and root finding for NEO_JAX.
- neo_jax.geometry.neo_bderiv(theta: Array, phi: Array, b_spl: Array, grid: dict) -> Tuple[Array, Array, Array, Array, Array, Array]
Compute first and second derivatives of B using spline coefficients.
Fourier
Fourier summations and derived quantities for Boozer geometry.
- neo_jax.fourier.derived_quantities(fourier: Dict[str, Array], curr_pol: Array, curr_tor: Array, iota: Array) -> Dict[str, Array]
Compute derived quantities from Fourier sums (neo_fourier.f90).
- neo_jax.fourier.fourier_sums(theta_arr: Array, phi_arr: Array, rmnc: Array, zmns: Array, lmns: Array, bmnc: Array, ixm: Array, ixn: Array, nfp: int, max_m_mode: int, max_n_mode: int, *, skip_mask: bool = False, lasym: bool = False, rmns: Array | None = None, zmnc: Array | None = None, lmnc: Array | None = None, bmns: Array | None = None) -> Dict[str, Array]
Compute Fourier sums for a single flux surface.
This mirrors neo_fourier.f90 but uses direct trig evaluation. Set NEO_JAX_FOURIER_MODE=streamed to reduce memory by avoiding theta×phi×mode temporaries.
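The memory trade-off behind NEO_JAX_FOURIER_MODE shows up clearly in a NumPy sketch of the stellarator-symmetric |B| sum, B(theta, phi) = sum over modes of bmnc * cos(m*theta - n*phi). The function names b_dense and b_streamed are hypothetical. The sketch omits the R, Z, and lambda sums and the lasym terms. It also assumes, as in booz_xform output, that ixn already includes the factor nfp.

```python
import numpy as np


def b_dense(theta, phi, bmnc, ixm, ixn):
    """Direct evaluation that builds a (n_theta, n_phi, n_modes) phase array."""
    phase = (ixm[None, None, :] * theta[:, None, None]
             - ixn[None, None, :] * phi[None, :, None])
    return np.sum(bmnc * np.cos(phase), axis=-1)


def b_streamed(theta, phi, bmnc, ixm, ixn):
    """Same sum, accumulated one mode at a time: peak memory is one grid."""
    th, ph = np.meshgrid(theta, phi, indexing="ij")
    total = np.zeros_like(th)
    for b, m, n in zip(bmnc, ixm, ixn):
        total += b * np.cos(m * th - n * ph)
    return total


theta = np.linspace(0.0, 2.0 * np.pi, 16, endpoint=False)
phi = np.linspace(0.0, 2.0 * np.pi / 5, 12, endpoint=False)  # one period, nfp = 5
bmnc = np.array([1.0, 0.08, -0.03])
ixm = np.array([0, 1, 1])
ixn = np.array([0, 0, 5])  # assumed to already include the nfp factor
print(np.allclose(b_dense(theta, phi, bmnc, ixm, ixn),
                  b_streamed(theta, phi, bmnc, ixm, ixn)))  # True
```

Both functions give the same result. The dense version trades memory for a single vectorized expression. The streamed version keeps only one theta-by-phi grid alive at a time.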
Splines
Cubic spline utilities ported from STELLOPT NEO.
- neo_jax.splines.eva2d(spl: Array, ix: int, iy: int, dx: float, dy: float) -> float
Evaluate a 2D spline at a given cell (port of eva2d.f90).
- neo_jax.splines.eva2d_fd(spl: Array, ix: int, iy: int, dx: float, dy: float) -> Array
Evaluate first derivatives of a 2D spline (port of eva2d_fd.f90).
- neo_jax.splines.eva2d_fd_jax(spl: Array, ix: Array, iy: Array, dx: Array, dy: Array) -> Array
JAX-friendly first derivatives of a spline.
- neo_jax.splines.eva2d_jax(spl: Array, ix: Array, iy: Array, dx: Array, dy: Array) -> Array
JAX-friendly spline evaluation.
- neo_jax.splines.eva2d_sd(spl: Array, ix: int, iy: int, dx: float, dy: float) -> Array
Evaluate second derivatives of a 2D spline (port of eva2d_sd.f90).
- neo_jax.splines.eva2d_sd_jax(spl: Array, ix: Array, iy: Array, dx: Array, dy: Array) -> Array
JAX-friendly second derivatives of a spline.
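Given the (4, 4, nx, ny) layout returned by spl2d, a natural reading is that each cell stores the 16 coefficients of a bicubic polynomial in the local offsets (dx, dy) produced by poi2d. The sketch below evaluates such a patch and its first derivatives under that assumption, where coefficient [i, j] multiplies dx**i * dy**j. The exact index and scaling conventions of the STELLOPT-derived routines are not stated here and may differ. eva2d_sketch and eva2d_fd_sketch are hypothetical names.

```python
import numpy as np


def eva2d_sketch(spl, ix, iy, dx, dy):
    """Value of a bicubic patch: sum_ij spl[i, j, ix, iy] * dx**i * dy**j."""
    c = spl[:, :, ix, iy]  # (4, 4) coefficients of one cell
    px = np.array([1.0, dx, dx**2, dx**3])
    py = np.array([1.0, dy, dy**2, dy**3])
    return px @ c @ py


def eva2d_fd_sketch(spl, ix, iy, dx, dy):
    """First derivatives (d/dx, d/dy) of the same patch."""
    c = spl[:, :, ix, iy]
    px = np.array([1.0, dx, dx**2, dx**3])
    py = np.array([1.0, dy, dy**2, dy**3])
    dpx = np.array([0.0, 1.0, 2.0 * dx, 3.0 * dx**2])
    dpy = np.array([0.0, 1.0, 2.0 * dy, 3.0 * dy**2])
    return np.array([dpx @ c @ py, px @ c @ dpy])


# One cell whose only nonzero coefficient encodes p(dx, dy) = 3 * dx * dy**2.
spl = np.zeros((4, 4, 1, 1))
spl[1, 2, 0, 0] = 3.0
print(eva2d_sketch(spl, 0, 0, 0.5, 2.0))     # 6.0
print(eva2d_fd_sketch(spl, 0, 0, 0.5, 2.0))  # [12.  6.]
```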
- neo_jax.splines.poi2d(hx: float, hy: float, mx: int, my: int, xmin: float, xmax: float, ymin: float, ymax: float, x: float, y: float)
Pointer calculation for spline evaluation (port of poi2d.f90).
Returns (ix, iy, dx, dy, ierr), where ix and iy are 0-based indices.
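The documented return value (ix, iy, dx, dy, ierr) amounts to "which cell, and how far into it." The following minimal sketch shows that bookkeeping with hypothetical helpers locate_1d and locate_2d. It derives the cell count from the bounds instead of using mx and my. It clamps the upper boundary into the last cell and signals out-of-range points with ierr = 1. How poi2d itself treats boundaries and periodic wrap-around is an assumption here.

```python
import math


def locate_1d(h, lo, hi, x):
    """Return (cell_index, offset, ierr) for x on a uniform grid over [lo, hi]."""
    n_cells = round((hi - lo) / h)
    ierr = 0 if lo <= x <= hi else 1
    i = math.floor((x - lo) / h)
    i = min(max(i, 0), n_cells - 1)  # clamp so x == hi falls in the last cell
    return i, x - (lo + i * h), ierr


def locate_2d(hx, hy, xmin, xmax, ymin, ymax, x, y):
    """poi2d-like result (ix, iy, dx, dy, ierr) with 0-based indices."""
    ix, dx, ex = locate_1d(hx, xmin, xmax, x)
    iy, dy, ey = locate_1d(hy, ymin, ymax, y)
    return ix, iy, dx, dy, max(ex, ey)


print(locate_2d(0.5, 0.25, 0.0, 2.0, 0.0, 1.0, 1.2, 0.6))
```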
- neo_jax.splines.poi2d_jax(hx: float, hy: float, mx: int, my: int, xmin: float, xmax: float, ymin: float, ymax: float, x: Array, y: Array)
JAX-friendly pointer calculation for spline evaluation.
- neo_jax.splines.spfper(np1: int, dtype=jax.numpy.float64) -> Tuple[Array, Array, Array]
Helper routine for periodic splines (port of spfper.f90).
- neo_jax.splines.spl2d(f: Array, hx: float, hy: float, mx: int, my: int) -> Array
Two-dimensional cubic spline coefficients.
Returns an array of shape (4, 4, nx, ny).
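A one-dimensional periodic cubic spline shows where a leading dimension of 4 comes from. On each cell, the spline is a cubic a + b*t + c*t**2 + d*t**3 in the local offset t. For clarity, the sketch solves the cyclic tridiagonal system for the second derivatives with a dense solve. It is not a port of spfper or spl2d. Whether spl2d builds its 2D coefficients by applying such a construction along each axis is an assumption. periodic_cubic_coeffs and eval_cell are hypothetical names.

```python
import numpy as np


def periodic_cubic_coeffs(y, h):
    """Coefficients (4, n) of the periodic cubic spline through y, spacing h.

    Row k holds the coefficient of t**k on each cell [x_i, x_i + h].
    """
    y = np.asarray(y, dtype=float)
    n = y.size
    eye = np.eye(n)
    # Cyclic system for second derivatives M_i (uniform spacing):
    # M_{i-1} + 4 M_i + M_{i+1} = 6 (y_{i+1} - 2 y_i + y_{i-1}) / h**2
    a_mat = 4.0 * eye + np.roll(eye, 1, axis=1) + np.roll(eye, -1, axis=1)
    rhs = 6.0 * (np.roll(y, -1) - 2.0 * y + np.roll(y, 1)) / h**2
    m = np.linalg.solve(a_mat, rhs)
    m_next, y_next = np.roll(m, -1), np.roll(y, -1)
    b = (y_next - y) / h - h * (2.0 * m + m_next) / 6.0
    c = m / 2.0
    d = (m_next - m) / (6.0 * h)
    return np.stack([y, b, c, d])


def eval_cell(coeffs, i, t):
    """Evaluate the cubic on cell i at local offset t in [0, h] (Horner form)."""
    a, b, c, d = coeffs[:, i]
    return a + t * (b + t * (c + t * d))


n = 16
h = 2.0 * np.pi / n
x = h * np.arange(n)
coeffs = periodic_cubic_coeffs(np.sin(x), h)
print(coeffs.shape)  # (4, 16)
print(abs(eval_cell(coeffs, 3, h / 2) - np.sin(x[3] + h / 2)) < 1e-3)  # True
```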
Grids
Grid preparation utilities.
Surface
Surface initialization and spline construction.
- class neo_jax.surface.SurfaceData(b_min: 'Array', b_max: 'Array', theta_bmin: 'Array', phi_bmin: 'Array', theta_bmax: 'Array', phi_bmax: 'Array', bmref: 'Array', fields: 'Dict[str, Array]', splines: 'Dict[str, Array]')
- neo_jax.surface.init_surface(theta_arr: Array, phi_arr: Array, coeffs: Dict[str, Array], ixm: Array, ixn: Array, nfp: int, max_m_mode: int, max_n_mode: int, curr_pol: Array, curr_tor: Array, iota: Array, grid: Dict[str, float], calc_cur: bool = False, use_jax: bool = False, skip_mask: bool = False) -> SurfaceData
Initialize a single flux surface: Fourier sums, derived fields, splines, B min/max.
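SurfaceData records B_min and B_max together with their (theta, phi) locations. A grid-only version of that search is just a pair of argmin/argmax calls, which the hypothetical grid_extrema below performs. Its result is accurate only to the grid spacing. This page does not say whether init_surface refines the result, for example with the root finding in the Geometry module.

```python
import numpy as np


def grid_extrema(bmod, theta, phi):
    """Grid-level B_min/B_max and their (theta, phi) locations.

    bmod has shape (len(theta), len(phi)).
    """
    k_min = np.unravel_index(np.argmin(bmod), bmod.shape)
    k_max = np.unravel_index(np.argmax(bmod), bmod.shape)
    return {
        "b_min": bmod[k_min], "theta_bmin": theta[k_min[0]], "phi_bmin": phi[k_min[1]],
        "b_max": bmod[k_max], "theta_bmax": theta[k_max[0]], "phi_bmax": phi[k_max[1]],
    }


theta = np.linspace(0.0, 2.0 * np.pi, 32, endpoint=False)
phi = np.linspace(0.0, 2.0 * np.pi / 5, 16, endpoint=False)  # one period, nfp = 5
# Toy field 1 - 0.1 cos(theta) + 0.02 cos(5 phi) on the (theta, phi) grid.
bmod = 1.0 - 0.1 * np.cos(theta)[:, None] + 0.02 * np.cos(5.0 * phi)[None, :]
ext = grid_extrema(bmod, theta, phi)
print(ext["b_min"], ext["b_max"])
```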
Pipeline
Pipeline helpers for vmec_jax -> booz_xform_jax -> neo_jax.
- neo_jax.pipeline.booz_xform_from_vmec_state_jax(*, vmec_run: Any, mboz: int | None = None, nboz: int | None = None, surfaces: Sequence[int | float] | None = None, jit: bool = True) -> Mapping[str, Any]
JAX-native VMEC state -> Boozer transform using booz_xform_jax.
- neo_jax.pipeline.booz_xform_from_vmec_wout(wout: Any, *, mboz: int | None = None, nboz: int | None = None, surfaces: Sequence[int | float] | None = None, flux: bool = False, jit: bool = True) -> Mapping[str, Any]
Run booz_xform_jax on an in-memory VMEC wout object.
- Parameters:
wout – VMEC wout-like object (for example, vmec_jax.WoutData).
mboz – Poloidal Boozer resolution. If None, defaults to VMEC mpol/ntor values.
nboz – Toroidal Boozer resolution. If None, defaults to VMEC mpol/ntor values.
surfaces – Optional surface indices or s values in [0, 1]. If omitted, all VMEC half-grid surfaces are used.
flux – If True, attempt to load flux profile arrays from wout.
jit – If True, jit-compile the Boozer transform.
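For reference, the half-grid s values behind "all VMEC half-grid surfaces" follow from VMEC's uniform full grid in s. The sketch computes them and picks the one nearest a requested s. Nearest-surface selection and the 0-based position it returns are assumptions, since this page does not say how booz_xform_from_vmec_wout rounds s values or numbers integer selections. Both helper names are hypothetical.

```python
def vmec_half_grid_s(ns):
    """Normalized toroidal flux on the VMEC half grid (ns - 1 surfaces).

    Full-grid points are s_j = (j - 1) / (ns - 1) for j = 1..ns; half-grid
    surfaces sit midway between them, at s = (j - 1.5) / (ns - 1), j = 2..ns.
    """
    return [(j - 1.5) / (ns - 1) for j in range(2, ns + 1)]


def nearest_half_grid_position(s, ns):
    """0-based position in vmec_half_grid_s(ns) closest to s (hypothetical)."""
    grid = vmec_half_grid_s(ns)
    return min(range(len(grid)), key=lambda k: abs(grid[k] - s))


print(vmec_half_grid_s(5))                 # [0.125, 0.375, 0.625, 0.875]
print(nearest_half_grid_position(0.6, 5))  # 2
```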
- neo_jax.pipeline.build_vmec_boozer_neo_jax(vmec_run: Any, *, booz_kwargs: dict | None = None, neo_config: NeoConfig | None = None, jit: bool = True)
Return a callable solve(state) for the JAX-native VMEC→Boozer→NEO path.
This precomputes Boozer constants and surface selections so the returned function is suitable for repeated calls (and optional JIT).
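In the factory shape described here, setup runs once and the returned closure does only per-state work. This is a common pattern for repeated calls and for JIT compilation, because traced code can treat captured values as fixed constants. The following dependency-free sketch uses a hypothetical make_solver. In the real pipeline, the precomputed pieces are the Boozer constants and surface selections, and solve would accept a VMEC state.

```python
from typing import Any, Callable, Dict


def make_solver(
    precompute: Callable[[], Dict[str, Any]],
    step: Callable[[Dict[str, Any], Any], Any],
) -> Callable[[Any], Any]:
    """Hypothetical factory: do setup once, return a cheap solve(state)."""
    constants = precompute()  # runs exactly once, when the solver is built

    def solve(state):
        # Only state-dependent work happens per call; with JAX this inner
        # function is what one would wrap in jax.jit.
        return step(constants, state)

    return solve


calls = {"precompute": 0}


def precompute():
    calls["precompute"] += 1
    return {"scale": 2.0}


solve = make_solver(precompute, lambda consts, state: consts["scale"] * state)
print([solve(s) for s in (1.0, 2.0, 3.0)], calls["precompute"])  # [2.0, 4.0, 6.0] 1
```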
- neo_jax.pipeline.run_boozer_to_neo(booz_output: Mapping[str, Any], *, config: NeoConfig | None = None, use_jax: bool = True, progress: bool | None = None) -> Any
Run NEO_JAX from a booz_xform_jax output mapping.
- neo_jax.pipeline.run_vmec_boozer_neo(vmec_source: Any, *, booz_xform_fn: Callable[[...], Mapping[str, Any]] | None = None, booz_kwargs: dict | None = None, vmec_kwargs: dict | None = None, neo_config: NeoConfig | None = None, use_jax: bool = True, progress: bool | None = None, fast_bcovar: bool = True) -> Any
Run vmec_jax -> booz_xform_jax -> neo_jax in one workflow.
This requires a JAX-native booz_xform_fn (for example, from booz_xform_jax.jax_api). vmec_source may be a vmec_jax.FixedBoundaryRun, a vmec_jax.WoutData object, or a path to a VMEC input file.