Validation¶
NEO_JAX validation is organized into three layers:
- Geometry parity: Fourier reconstruction and spline evaluation match the Fortran outputs for R, Z, B, and derived quantities.
- Integration parity: field-line integrals and trapped-particle sums reproduce the reference neo_out results on curated fixtures.
- End-to-end parity: the CLI output matches xneo file-by-file on the supported legacy cases, using committed reference fixtures generated with xneo.
Current reference cases include:
- ORBITS (tests/fixtures/orbits)
- NCSX tutorial example (tests/fixtures/ncsx)
- LandremanPaul2021_QA_lowres (tests/fixtures/landreman_qa_lowres)
- constellaration low-|iota| fixtures (tests/fixtures/constellaration)
- a synthetic one-surface ORBITS legacy case used to validate neo_out.*, neolog.*, diagnostic*.dat, conver.dat, and the legacy *_arr.dat dumps against stored xneo reference outputs
- a synthetic one-surface ORBITS calc_cur = 1 case used to validate neo_cur.* and current.dat against stored xneo reference outputs
- a synthetic one-surface NCSX case used to validate neo_out.* and neolog.* against stored xneo reference outputs
The large NCSX Boozer input is distributed as an external release asset rather
than being committed directly in the repository. Default CI and slim checkouts
therefore skip NCSX fixture consumers unless
NEO_JAX_FETCH_EXTERNAL_FIXTURES=1 is set or the file is already present in
the local fixture cache.
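The skip rule above can be sketched as a small predicate. This is a minimal illustration, not the project's actual test helper: the function name should_run_external_fixture and its explicit env argument are assumptions made for the example. Only the NEO_JAX_FETCH_EXTERNAL_FIXTURES variable and the "already present in the local fixture cache" condition come from the text.

```python
import os
from pathlib import Path


def should_run_external_fixture(cached_file, env=None):
    """Return True when a test that needs an external fixture should run.

    Hypothetical helper: it mirrors the rule in the text, where the test
    runs if fetching is explicitly enabled or the file is already cached.
    """
    env = os.environ if env is None else env
    fetch_enabled = env.get("NEO_JAX_FETCH_EXTERNAL_FIXTURES") == "1"
    return fetch_enabled or Path(cached_file).is_file()
```

A predicate like this would typically feed pytest.mark.skipif(not should_run_external_fixture(path), reason="external fixture not available") so that slim checkouts report a skip rather than a failure.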
Legacy CLI parity¶
The CLI regression coverage in tests/regression/test_cli_legacy.py runs the
JAX CLI against frozen reference outputs committed under
tests/fixtures/cli_legacy/. Those files were generated once with the
STELLOPT xneo executable and then checked into the repository so the tests
do not require an external binary.
Current checks:
- exact text equality for:
  - neo_out.*
  - neo_cur.*
  - neolog.*
  - diagnostic.dat
  - diagnostic_add.dat
  - diagnostic_bigint.dat
  - conver.dat on the synthetic ORBITS parity case
- numerical equality, up to floating-point roundoff, for:
  - dimension.dat
  - es_arr.dat
  - all legacy *_arr.dat geometry dumps
  - current.dat token streams, including matching NaN/Infinity masks and tight tolerances on finite values
- control-file search-order parity for:
  - neo_param.<extension>
  - neo_param.in
  - neo_in.<extension>
- slow-fixture CLI parity:
  - exact neo_out.* / neolog.* parity on ORBITS_FAST
  - exact conver.dat columns 1-4 parity on ORBITS_FAST
  - approximately rtol=5e-3 parity on ncsx_c09r00_free_fast
- optional GPU smoke parity:
  - CLI CPU-vs-GPU agreement on a one-surface ORBITS case
  - Python API CPU-vs-GPU agreement through run_neo(...)
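The token-stream comparison mentioned in the list above (matching NaN/Infinity masks plus tolerances on finite values) could look roughly like the following. This is a sketch under stated assumptions, not the code in tests/regression/test_cli_legacy.py: the function name tokens_match, the default tolerances, and the assumption that every token parses with Python's float() are all illustrative.

```python
import math


def tokens_match(actual, expected, rtol=1e-12, atol=0.0):
    """Compare two whitespace-separated numeric token streams.

    NaN must line up with NaN, infinities must match in sign, and finite
    values must agree within the given tolerances. Assumes every token is
    parseable by float(); a non-numeric token raises ValueError.
    """
    a_tokens, e_tokens = actual.split(), expected.split()
    if len(a_tokens) != len(e_tokens):
        return False
    for a_tok, e_tok in zip(a_tokens, e_tokens):
        a, e = float(a_tok), float(e_tok)
        if math.isnan(a) or math.isnan(e):
            # NaN never compares equal, so the mask is checked explicitly.
            if not (math.isnan(a) and math.isnan(e)):
                return False
        elif math.isinf(a) or math.isinf(e):
            if a != e:  # both must be the same signed infinity
                return False
        elif not math.isclose(a, e, rel_tol=rtol, abs_tol=atol):
            return False
    return True
```

Checking the non-finite mask before applying tolerances matters because a plain tolerance test would reject NaN against NaN even when both outputs agree.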
For the dense ORBITS_FAST case, NEO_JAX also exposes
NEO_JAX_WRITE_IPMAX_DEBUG=1 to emit diagnostic_ipmax_jax.dat. That
debug trace is used to compare the per-step trapped-amplitude history against
the STELLOPT solver when investigating the remaining conver.dat fifth-column
discrepancy.
Supported legacy scope:
- calc_cur = 0 parity is tested and supported
- calc_cur = 1 parity is tested and supported
Low-|iota| regression coverage¶
The constellaration fixtures contain surfaces with very small rotational
transform. For these cases, the legacy rational-surface correction implies
nfp_rat ~= ceil(1 / acc_req / |iota|)
so the required field-period count can become enormous. Physically, this is the
near-zero-|iota| regime. Numerically, it can make the legacy correction
look hung simply because the requested work is very large.
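To make the scaling concrete, the estimate above can be evaluated for a few values of |iota|. The helper name estimate_nfp_rat and the numbers used here are illustrative assumptions; they are not taken from any fixture or from the NEO_JAX source.

```python
import math


def estimate_nfp_rat(acc_req, iota):
    """Evaluate ceil(1 / acc_req / |iota|) from the relation above."""
    return math.ceil(1.0 / acc_req / abs(iota))


# Illustrative values only; they are not taken from any fixture.
for iota in (0.5, 1e-2, 1e-5):
    print(f"|iota| = {iota:g}: nfp_rat ~ {estimate_nfp_rat(0.01, iota):,}")
```

With acc_req = 0.01, the estimate grows from a few hundred field periods at |iota| = 0.5 to roughly ten million at |iota| = 1e-5, which is why the legacy correction can appear to hang.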
tests/regression/test_constellaration_guard.py now verifies that:
- both reported Boozer files fail fast instead of hanging
- the exception message explains that the requested rational correction is too large
- progress=True prints the detailed preflight diagnostic before aborting
- the same safeguard applies to the JAX surface-scan backend
- rational_surface_policy="approximate" returns a finite controlled result with explicit approximation diagnostics
- approximate-mode JAX and Python per-surface backends stay close on the same low-resolution pathological case
- requesting jax_surface_scan=True with approximate mode falls back to the per-surface path instead of trying to support approximation inside the scan backend
The broader parity suites for ORBITS, NCSX, LandremanPaul2021_QA_lowres, the legacy CLI outputs, and the public API continue to exercise the unchanged paths. Approximate mode is opt-in, and the default safeguard does not alter those existing cases unless the preflight rational-work estimate actually exceeds the configured limit.
Operational guidance:
- keep the default safeguard for validation and parity work
- use rational_surface_policy="approximate" when a controlled fallback is more useful than an early error
- set max_rational_field_periods=0 only when you explicitly want the full exact legacy correction, even if it may take hours
- avoid surfaces with |iota| near zero, or loosen acc_req, when scanning new equilibria
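The guidance above amounts to a three-way decision made before any expensive work starts. The sketch below shows one way such a preflight check could be structured. It is not the NEO_JAX implementation: the function name preflight_decision, the placeholder default policy "error", the returned labels, and the message wording are all assumptions. Only rational_surface_policy="approximate" and the meaning of max_rational_field_periods=0 are taken from the text.

```python
import math


def preflight_decision(acc_req, iota, max_rational_field_periods,
                       rational_surface_policy="error"):
    """Decide how to handle the rational correction before running it.

    Returns "exact" or "approximate", or raises RuntimeError to fail fast.
    Hypothetical sketch; "error" is a placeholder name for the default policy.
    """
    if iota == 0:
        estimate = math.inf  # the estimate diverges as |iota| goes to zero
    else:
        estimate = math.ceil(1.0 / acc_req / abs(iota))

    # Per the guidance above, a limit of 0 requests the full exact correction.
    if max_rational_field_periods == 0 or estimate <= max_rational_field_periods:
        return "exact"
    if rational_surface_policy == "approximate":
        return "approximate"
    raise RuntimeError(
        f"rational correction too large: estimated {estimate} field periods "
        f"exceeds the limit of {max_rational_field_periods}"
    )
```

Making the decision up front, from the estimate alone, is what lets the default path abort with a clear message instead of starting a computation that may run for hours.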
Precision¶
NEO_JAX enables 64-bit JAX precision by default to preserve parity with the Fortran reference outputs. You can override this behavior by setting either:
- NEO_JAX_ENABLE_X64=0 (NEO_JAX-specific)
- JAX_ENABLE_X64=0 (global JAX setting)
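The effect of dropping to 32-bit precision on roundoff-level parity checks can be illustrated without JAX. The snippet below is a standalone sketch using only the Python standard library; the helper name to_float32 and the sample value are illustrative and unrelated to the NEO_JAX code.

```python
import struct


def to_float32(x):
    """Round a Python float (64-bit) to the nearest 32-bit float."""
    return struct.unpack("f", struct.pack("f", x))[0]


value = 0.1
rel_err = abs(to_float32(value) - value) / value
print(f"relative error after float32 round-trip: {rel_err:.2e}")
```

The round-trip error is on the order of 1e-8, many orders of magnitude above double-precision roundoff (about 1e-16), so exact-text and tight-tolerance comparisons against the reference outputs should not be expected to pass with 64-bit mode disabled.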
Fast vs. full ORBITS parity:
- The default CLI regression suite runs the dense Landreman fixture plus reduced ORBITS / NCSX mini cases that finish quickly in CI.
- Full solver parity checks for dense ORBITS and NCSX fixtures remain available behind NEO_JAX_RUN_SLOW=1 in the public-API regression tests.
Planned metrics:
- Relative error on epstot and epspar vs reference output.
- Consistency of derived quantities (kg, pard, sqrt(g^{11})).
- Regression of rational-surface handling and bin-averaged convergence.
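The first planned metric, relative error against reference output, could be computed along these lines. This is a hedged sketch: the function name max_relative_error and the floor parameter are assumptions for illustration, and the inputs stand in for per-surface values such as epstot or epspar.

```python
def max_relative_error(computed, reference, floor=1e-300):
    """Largest |computed - reference| / max(|reference|, floor) over pairs.

    The floor avoids division by zero when a reference entry is exactly 0.
    """
    if len(computed) != len(reference):
        raise ValueError("sequences must have the same length")
    return max(
        (abs(c - r) / max(abs(r), floor) for c, r in zip(computed, reference)),
        default=0.0,
    )
```

Reporting the maximum rather than the mean keeps a single badly converged surface from being hidden by many well-converged ones.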
Benchmarking¶
The benchmarks/benchmark_orbits.py script measures runtime on the ORBITS
fixture using either the Python or JAX backend.
For JAX performance runs, the driver uses a JIT-compiled kernel by default.
Set NEO_JAX_DISABLE_JIT=1 to force eager execution when debugging.
Both benchmark scripts (benchmarks/benchmark_orbits.py and benchmarks/benchmark_ncsx.py) accept --warmup to separate compile time from
steady-state runtime.
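The idea behind a warmup option can be sketched generically: time the first calls separately, then time the remaining calls once any one-off cost has been paid. The helper below is an illustration only and is not taken from the benchmark scripts; the name time_with_warmup and its parameters are assumptions.

```python
import time


def time_with_warmup(fn, warmup=1, repeats=5):
    """Time `fn`, reporting warmup calls separately from steady-state calls.

    For a JIT-compiled function the first call typically includes
    compilation, so the warmup timings approximate compile plus run cost.
    """
    def timed_call():
        start = time.perf_counter()
        fn()
        return time.perf_counter() - start

    warmup_times = [timed_call() for _ in range(warmup)]
    steady_times = [timed_call() for _ in range(repeats)]
    return warmup_times, steady_times
```

When timing JAX code this way, fn should wait for the result (for example by calling block_until_ready() on the returned array), because JAX dispatches work asynchronously and the call can otherwise return before the computation finishes.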
Profiling¶
Use benchmarks/profile_run.py to capture JAX traces and XLA dumps for
performance and memory analysis:
# Fast ORBITS trace with X64 parity settings
PYTHONPATH=. python benchmarks/profile_run.py --jax --enable-x64 \
--trace-dir profiles/orbits_fast
# Dump XLA HLO/LLVM artifacts (CPU or GPU)
PYTHONPATH=. python benchmarks/profile_run.py --jax --enable-x64 \
--xla-dump-dir profiles/xla_orbits_fast
The trace directory can be opened with TensorBoard:
tensorboard --logdir profiles
GPU Run Guide¶
To benchmark on GPU, ensure a CUDA-enabled JAX build is installed and set the runtime environment before running the benchmark scripts:
export JAX_PLATFORM_NAME=gpu
export JAX_ENABLE_X64=1
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8
python benchmarks/benchmark_orbits.py --jax
python benchmarks/benchmark_ncsx.py --jax
Tips:
- Use JAX_ENABLE_X64=1 to match Fortran parity expectations.
- If you see out-of-memory errors, lower XLA_PYTHON_CLIENT_MEM_FRACTION or set XLA_PYTHON_CLIENT_PREALLOCATE=false.
GPU validation on office¶
NEO_JAX was revalidated on March 10, 2026 on the office workstation
(pop-os) with 2x NVIDIA RTX A4000 GPUs and JAX 0.6.2 in
/home/rjorge/venvs/vmec_jax_gpu_bench.
The GPU smoke suite is:
env NEO_JAX_RUN_GPU=1 JAX_PLATFORM_NAME=gpu python -m pytest -q \
tests/regression/test_gpu_smoke.py
That test file verifies:
- the legacy CLI produces the same one-surface ORBITS neo_out.* values on CPU and GPU
- the Python API produces the same ORBITS effective-ripple result on CPU and GPU
- the default CLI progress log reports the active JAX runtime
In addition, the user-facing examples/ncsx_epsilon_effective_plot.py script
was run on the same GPU host with MPLBACKEND=Agg and produced
examples/ncsx_eps_eff_vs_s.png successfully.
Measured cold-run snapshots on office:
| Path | Case | CPU | GPU | CPU RSS MiB | GPU RSS MiB |
|---|---|---|---|---|---|
| Legacy CLI | | | | | |
| Python API | ORBITS single-surface smoke | | | n/a | n/a |
At the current problem sizes, the GPU path is functional and parity-checked but still compile-bound. For the legacy CLI and the small ORBITS API smoke, the GPU is slower than CPU because compile and launch overhead dominate the solve.