Performance
Performance in NEO_JAX is dominated by three costs:
- Fourier reconstruction and spline construction on each surface
- field-line integration and trapped-particle accumulation
- JAX compile time whenever a new static configuration is seen
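
The third cost can be seen in miniature with plain JAX. A jitted function is traced and compiled once for each distinct static configuration, and the compiled code is reused otherwise. The sketch below uses a hypothetical `fourier_sum` with a static `n_modes` argument. The Python-side `trace_log` list is appended to only while JAX traces, so it records each (re)compilation. This illustrates JAX's caching behaviour and is not NEO_JAX code.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_log = []  # filled only during tracing, i.e. once per compilation

@partial(jax.jit, static_argnames="n_modes")
def fourier_sum(theta, n_modes):
    # Hypothetical stand-in: n_modes is static, so it fixes array shapes.
    trace_log.append(n_modes)  # Python side effect: runs at trace time only
    m = jnp.arange(n_modes)
    return jnp.cos(m[:, None] * theta[None, :]).sum(axis=0)

theta = jnp.linspace(0.0, 2.0 * jnp.pi, 32)
fourier_sum(theta, n_modes=8)   # trace + compile
fourier_sum(theta, n_modes=8)   # cache hit, no recompilation
fourier_sum(theta, n_modes=16)  # new static value: trace + compile again
print(trace_log)  # [8, 16]
```

Changing the shape or dtype of `theta` would also trigger a retrace, because shapes are part of the cache key.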
Fourier Summation Modes
NEO_JAX offers two Fourier evaluation paths:
- `NEO_JAX_FOURIER_MODE=vectorized` (default): fastest, but allocates `theta × phi × mode` temporaries.
- `NEO_JAX_FOURIER_MODE=streamed`: avoids the 3D temporaries by streaming over Fourier modes, reducing memory at the cost of additional runtime.
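
The difference between the two paths can be sketched in NumPy on a simplified cosine series f(θ, φ) = Σ_k c_k cos(m_k θ − n_k φ). The vectorized version materializes full `(theta, phi, mode)` arrays. The streamed version keeps only a `(theta, phi)` accumulator and adds one mode at a time. The series form and the function names are assumptions for illustration. NEO_JAX's own Fourier representation (for example sine terms or field-period factors) may differ.

```python
import numpy as np

def fourier_vectorized(coef, m, n, theta, phi):
    """Sum c_k cos(m_k theta - n_k phi) via (n_theta, n_phi, n_modes) temporaries."""
    angle = (m[None, None, :] * theta[:, None, None]
             - n[None, None, :] * phi[None, :, None])  # 3D temporary
    return np.sum(coef * np.cos(angle), axis=-1)       # cos() makes another one

def fourier_streamed(coef, m, n, theta, phi):
    """Same sum, accumulating one mode at a time into a (n_theta, n_phi) buffer."""
    out = np.zeros((theta.size, phi.size))
    for c, mk, nk in zip(coef, m, n):
        out += c * np.cos(mk * theta[:, None] - nk * phi[None, :])
    return out

rng = np.random.default_rng(0)
n_modes = 50
m = rng.integers(0, 10, n_modes)
n = rng.integers(-5, 6, n_modes)
coef = rng.normal(size=n_modes)
theta = np.linspace(0.0, 2.0 * np.pi, 64, endpoint=False)
phi = np.linspace(0.0, 2.0 * np.pi, 48, endpoint=False)

a = fourier_vectorized(coef, m, n, theta, phi)
b = fourier_streamed(coef, m, n, theta, phi)
print(np.allclose(a, b))
# Size of one float64 (theta, phi, mode) temporary: about 1.2 MB for this small grid.
print(theta.size * phi.size * n_modes * 8 / 1e6, "MB")
```

The temporary grows as `theta × phi × mode`, while the streamed accumulator grows only as `theta × phi`. In exchange, the streamed path runs a sequential loop over modes.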
NCSX benchmark comparison (10 surfaces, CPU warmup run, measured with `/usr/bin/time -l`):

| Mode | Total time | Max RSS |
|---|---|---|
| Vectorized | 51.37 s | 4.45 GB |
| Streamed | 58.78 s | 2.55 GB |

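
Expressed as ratios, the table shows that on this case the streamed path trades roughly 14% more wall time for roughly 43% less peak memory:

```math
\frac{t_\text{streamed}}{t_\text{vectorized}} = \frac{58.78\ \text{s}}{51.37\ \text{s}} \approx 1.144,
\qquad
\frac{\text{RSS}_\text{streamed}}{\text{RSS}_\text{vectorized}} = \frac{2.55\ \text{GB}}{4.45\ \text{GB}} \approx 0.573
```
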
Kernel Fusion Notes
The JAX backend inlines the RHS evaluation, RK4 staging, and trapped-particle
updates inside the scan body. This keeps `neo_eval` and the RK4 stages in a
single fused region and reduces the number of separate kernels XLA emits.
The implementation lives in `neo_jax.integrate.flint_bo_jax` and, on the JIT path,
replaces the previous call boundary between `rk4_step` and `_process_trapped`.
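
The structure described above can be sketched with `jax.lax.scan`. The RHS, all four RK4 stages and an accumulation step live in one body function, so XLA compiles them as a single loop body rather than as separate calls. The ODE dy/dt = −y, `rhs`, and the running-integral accumulator are hypothetical stand-ins for `neo_eval` and the trapped-particle updates, not the NEO_JAX implementation.

```python
import jax
import jax.numpy as jnp

N_STEPS = 100  # 100 steps of h = 0.01 integrate from t = 0 to t = 1

def rhs(y):
    # Hypothetical stand-in for the field-line RHS (neo_eval in NEO_JAX).
    return -y

def scan_body(carry, _):
    y, acc = carry
    h = 0.01
    # RK4 stages written inline, so they sit in the same scan body as the RHS.
    k1 = rhs(y)
    k2 = rhs(y + 0.5 * h * k1)
    k3 = rhs(y + 0.5 * h * k2)
    k4 = rhs(y + h * k3)
    y_new = y + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
    # Hypothetical accumulation standing in for trapped-particle updates:
    # updated in the same step, not behind a separate function boundary.
    acc = acc + h * y_new
    return (y_new, acc), None

@jax.jit
def integrate(y0):
    init = (y0, jnp.zeros_like(y0))
    (y_end, acc), _ = jax.lax.scan(scan_body, init, None, length=N_STEPS)
    return y_end, acc

y_end, acc = integrate(jnp.ones(4))
print(y_end)  # each entry close to exp(-1) ~ 0.3679
```

The optimized HLO for `integrate` (for example via `integrate.lower(jnp.ones(4)).compile().as_text()`) can be inspected to see how XLA fused the loop body.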
XLA Memory Hotspots (ORBITS_FAST)
XLA memory reports for the ORBITS_FAST profile show that the largest allocations in
`jit_flint_bo_jax` are the spline coefficient arrays `b_spl`, `g_spl`, `k_spl`, and
`p_spl`, each with shape `[4, 4, 25, 25]`.
In the scan body, the largest allocations are the per-step state and temporary
arrays with shape `[25, 196]` (theta × particle grid). These correspond to
RK4 staging and trapped-particle updates.
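
A quick size check puts these shapes in perspective. It assumes float64 elements, as when x64 is enabled with `--enable-x64`. The helper name `nbytes` is hypothetical.

```python
import math

def nbytes(shape, itemsize=8):
    """Bytes for a dense array; itemsize=8 assumes float64 (x64 enabled)."""
    return math.prod(shape) * itemsize

spline_bytes = nbytes((4, 4, 25, 25))  # one of b_spl, g_spl, k_spl, p_spl
all_splines = 4 * spline_bytes         # all four spline arrays together
step_bytes = nbytes((25, 196))         # one per-step [theta, particle] array

print(spline_bytes, all_splines, step_bytes)  # 80000 320000 39200
```

Each per-step `[25, 196]` array is therefore about 39 kB in float64, and its size scales with the theta and particle grid sizes.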
Profiling Workflow
Use `benchmarks/profile_run.py` to generate traces and XLA dumps:

```bash
# Streamed Fourier profile
NEO_JAX_FOURIER_MODE=streamed PYTHONPATH=. \
  python benchmarks/profile_run.py --jax --enable-x64 \
  --trace-dir profiles/trace_orbits_fast_streamed \
  --xla-dump-dir profiles/xla_orbits_fast_streamed
```

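To capture a trace outside the benchmark script, plain JAX provides `jax.profiler.trace`, which writes a profile that TensorBoard (with its profile plugin) can open. The sketch below is generic and does not reproduce what `profile_run.py` does internally; `step` is a hypothetical workload. XLA compiler dumps are usually requested separately through `XLA_FLAGS=--xla_dump_to=<dir>`, which must be set before JAX initializes its backend.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # Hypothetical workload standing in for the NEO scan.
    return jnp.sin(x) @ jnp.cos(x).T

x = jnp.ones((256, 256))
step(x).block_until_ready()  # compile before tracing so the trace shows execution

trace_dir = tempfile.mkdtemp(prefix="trace_")
with jax.profiler.trace(trace_dir):
    out = step(x)
    out.block_until_ready()  # ensure device work finishes inside the trace window

written = [f for _, _, files in os.walk(trace_dir) for f in files]
print(written)  # profile files written under trace_dir
```
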
Open the trace directory with TensorBoard to inspect kernel-level hotspots.
For end-to-end VMEC→Boozer→NEO profiling (including the Boozer transform),
use `benchmarks/profile_vmec_boozer_pipeline.py`. It records a JAX trace and
emits an HLO text dump for kernel-level inspection:

```bash
python benchmarks/profile_vmec_boozer_pipeline.py \
  --case circular_tokamak \
  --trace-dir profiles/vmec_boozer_neo_trace \
  --hlo-out profiles/vmec_boozer_neo.hlo.txt
```

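One way to produce HLO text for an arbitrary jitted function is JAX's ahead-of-time lowering API. The sketch is generic: `pipeline` is a hypothetical function, and the script above may obtain its dump differently.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

def pipeline(x):
    # Hypothetical two-stage computation standing in for VMEC -> Boozer -> NEO.
    y = jnp.tanh(x)
    return jnp.sum(y * y)

x = jnp.linspace(0.0, 1.0, 128)
lowered = jax.jit(pipeline).lower(x)
stablehlo_text = lowered.as_text()      # module before XLA optimization
hlo_text = lowered.compile().as_text()  # optimized HLO for the current backend

out_path = os.path.join(tempfile.mkdtemp(), "pipeline.hlo.txt")
with open(out_path, "w") as fh:
    fh.write(hlo_text)
```

The pre-optimization text shows what JAX handed to XLA. The compiled text shows the fusions XLA actually chose, which is usually the more useful view for kernel-level hotspots.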
JIT Pipeline Reuse
For end-to-end VMEC→Boozer→NEO workflows, prefer
`neo_jax.build_vmec_boozer_neo_jax()` to precompute the Boozer constants once and
reuse a single compiled callable. This avoids recompiling the Boozer transform
and NEO scan inside optimization loops.
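
The reuse pattern can be illustrated with a hypothetical builder that precomputes constants once and returns one jitted callable. It mirrors the idea behind `build_vmec_boozer_neo_jax()` but is not its API. The `traces` list is appended to only when JAX traces, so it counts compilations.

```python
import jax
import jax.numpy as jnp

traces = []

def build_objective(n_theta):
    """Hypothetical builder: precompute constants once, return one jitted callable."""
    theta = jnp.linspace(0.0, 2.0 * jnp.pi, n_theta, endpoint=False)
    basis = jnp.cos(jnp.arange(4)[:, None] * theta[None, :])  # (4, n_theta) constant

    @jax.jit
    def objective(coeffs):
        traces.append("trace")  # runs only when JAX (re)traces
        return jnp.sum((coeffs @ basis) ** 2)

    return objective

# Reuse: build once, call many times with same-shaped inputs.
objective = build_objective(64)
for step in range(3):
    objective(jnp.full(4, 0.1 * step))
n_reuse = len(traces)

# Anti-pattern: rebuilding inside the loop creates a new jitted function each time.
traces.clear()
for step in range(3):
    build_objective(64)(jnp.full(4, 0.1 * step))
n_rebuild = len(traces)

print(n_reuse, n_rebuild)  # 1 3
```

JAX keys its compilation cache on the function object as well as on argument shapes and dtypes. A freshly built closure therefore always recompiles, even when its code is identical.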
Benchmark JIT reuse with:

```bash
python benchmarks/benchmark_vmec_boozer_pipeline.py --case circular_tokamak --repeats 3
```

For CI, `benchmarks/ci_perf_check.py` provides a small regression guardrail
using a tiny pipeline case and configurable thresholds.
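
A threshold-based guardrail of this kind can be sketched as follows: warm up once so compilation is excluded, time a few repeats, take the median, and compare it against a baseline with a relative tolerance. `PerfThreshold`, `median_runtime`, `check`, and the numbers are hypothetical and do not describe `ci_perf_check.py` itself.

```python
import statistics
import time
from dataclasses import dataclass

@dataclass(frozen=True)
class PerfThreshold:
    """Hypothetical threshold: fail if the median time exceeds baseline * (1 + rel_tol)."""
    name: str
    baseline_s: float
    rel_tol: float = 0.25

def median_runtime(fn, repeats=3):
    """Call fn once to warm up (absorbs JIT compilation), then time the repeats."""
    fn()
    samples = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - t0)
    return statistics.median(samples)

def check(measured_s, threshold):
    limit = threshold.baseline_s * (1.0 + threshold.rel_tol)
    return measured_s <= limit, limit

tiny = PerfThreshold("tiny_pipeline", baseline_s=0.5)
measured = median_runtime(lambda: sum(i * i for i in range(10_000)))
ok, limit = check(measured, tiny)
print(f"{tiny.name}: {measured:.4f}s (limit {limit:.3f}s) -> {'PASS' if ok else 'FAIL'}")
```

Using the median over a few repeats keeps a single noisy CI run from tripping the guardrail. A relative tolerance lets the same check work across machines with different baselines.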