# Source code for neo_jax.driver

"""High-level driver for NEO_JAX using Boozer data."""

from __future__ import annotations

import os
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import jax
import jax.numpy as jnp

from .control import ControlParams
from .current import CurrentParams, flint_cur_jax
from .data_models import BoozerData
from .grids import prepare_grids
from .integrate import FlintParams, RhsEnv, flint_bo, flint_bo_jax
from .io import read_boozmn
from .legacy import LegacyNeoWriter, build_fortran_line
from .results import NeoResults, NeoSurfaceResult
from .surface import init_surface
from .data_models import NeoOutputs

# Default cap on the estimated number of field periods the rational-surface
# correction may require; it is compared against
# _estimate_rational_work(...)["field_periods"] before integration starts.
# Passing max_rational_field_periods <= 0 disables the cap.
DEFAULT_MAX_RATIONAL_FIELD_PERIODS = 100_000


def compute_reference(booz: BoozerData) -> Dict[str, float]:
    """Return the reference major radius and field from the (m, n) = (0, 0) mode.

    Values are read from the first stored surface (row 0) of ``booz``.

    Args:
        booz: Boozer spectrum exposing ``ixm``/``ixn`` mode numbers and
            ``rmnc``/``bmnc`` coefficient tables indexed ``[surface, mode]``.

    Returns:
        ``{"rt0": R00, "Rmajor": R00, "bmref_g": B00}`` as Python floats.

    Raises:
        ValueError: If the spectrum contains no (m, n) = (0, 0) mode.
    """
    (axis_hits,) = np.nonzero((np.asarray(booz.ixm) == 0) & (np.asarray(booz.ixn) == 0))
    if axis_hits.size == 0:
        # Without the axisymmetric mode there is no meaningful R00/B00 reference.
        raise ValueError(
            "Boozer spectrum has no (m, n) = (0, 0) mode; cannot determine the "
            "reference major radius and magnetic field."
        )
    m0_idx = int(axis_hits[0])
    rt0 = float(booz.rmnc[0, m0_idx])
    bmref_g = float(booz.bmnc[0, m0_idx])
    return {"rt0": rt0, "Rmajor": rt0, "bmref_g": bmref_g}


def compute_reference_jax(booz: BoozerData):
    """JAX-friendly reference values.

    Traceable counterpart of ``compute_reference``: finds the first
    (m, n) = (0, 0) mode without data-dependent shapes and returns its
    row-0 ``rmnc`` and ``bmnc`` coefficients as 0-d JAX arrays.

    Returns:
        ``(rt0, bmref_g)``. If the spectrum has no (0, 0) mode, mode index 0
        is used silently; no error is raised.
    """
    is_axis_mode = jnp.logical_and(booz.ixm == 0, booz.ixn == 0)
    # argmax over a boolean mask yields the first True position, or 0 when
    # no entry is True.
    axis_idx = jnp.argmax(is_axis_mode)
    major_radius = jnp.squeeze(jnp.asarray(booz.rmnc)[0, axis_idx])
    field_ref = jnp.squeeze(jnp.asarray(booz.bmnc)[0, axis_idx])
    return major_radius, field_ref
def run_neo_from_boozer_jax(
    booz: BoozerData,
    control: ControlParams,
    *,
    skip_fourier_mask: bool = False,
    max_rational_field_periods: int | None = DEFAULT_MAX_RATIONAL_FIELD_PERIODS,
    rational_surface_policy: str | None = None,
) -> NeoOutputs:
    """JAX surface scan over all requested surfaces (no Python loop).

    The per-surface solve is a single function mapped over the requested
    surface indices with ``jax.vmap``.

    Args:
        booz: Boozer spectrum; every array field is converted to a JAX array.
        control: Run controls (resolution, surface selection, ``ref_swi``, ...).
        skip_fourier_mask: Forwarded to ``init_surface`` as ``skip_mask``.
        max_rational_field_periods: Preflight cap on the estimated number of
            field periods of the rational-surface correction. ``None`` defers
            to ``NEO_JAX_MAX_RATIONAL_FIELD_PERIODS``; values ``<= 0`` disable
            the cap.
        rational_surface_policy: ``"error"`` or ``"approximate"``. Only
            ``"error"`` is supported by this vectorized path.

    Returns:
        ``NeoOutputs`` holding one entry per requested surface plus a
        ``diagnostics`` dict.

    Raises:
        ValueError: If ``control.ref_swi`` is not 1 or 2.
        RuntimeError: If a surface exceeds the rational-work limit, or the
            limit is exceeded while ``rational_surface_policy`` is
            ``"approximate"``.
    """
    if control.ref_swi not in (1, 2):
        # Fail fast instead of at trace time inside the vmapped solve.
        raise ValueError(f"Unsupported ref_swi: {control.ref_swi}")

    booz = BoozerData(
        rmnc=jnp.asarray(booz.rmnc),
        zmns=jnp.asarray(booz.zmns),
        lmns=jnp.asarray(booz.lmns),
        bmnc=jnp.asarray(booz.bmnc),
        ixm=jnp.asarray(booz.ixm),
        ixn=jnp.asarray(booz.ixn),
        es=jnp.asarray(booz.es),
        iota=jnp.asarray(booz.iota),
        curr_pol=jnp.asarray(booz.curr_pol),
        curr_tor=jnp.asarray(booz.curr_tor),
        nfp=int(booz.nfp),
    )
    grid = prepare_grids(control.theta_n, control.phi_n, booz.nfp)

    # Mode cut-offs are passed to init_surface as plain Python ints (matching
    # run_neo_from_boozer). The booz arrays are jax.Array at this point, so
    # convert explicitly rather than returning a JAX scalar.
    max_m_mode = (
        control.max_m_mode
        if control.max_m_mode > 0
        else int(np.max(np.abs(np.asarray(booz.ixm))))
    )
    max_n_mode = (
        control.max_n_mode
        if control.max_n_mode > 0
        else int(np.max(np.abs(np.asarray(booz.ixn))))
    )

    if control.fluxs_arr:
        if booz.rmnc.shape[0] == len(control.fluxs_arr):
            # booz presumably already holds exactly the requested surfaces.
            surf_indices = list(range(booz.rmnc.shape[0]))
        else:
            # fluxs_arr entries are 1-based surface indices.
            surf_indices = [i - 1 for i in control.fluxs_arr]
        flux_indices = list(control.fluxs_arr)
    else:
        surf_indices = list(range(booz.rmnc.shape[0]))
        flux_indices = [idx + 1 for idx in surf_indices]

    work_limit = _resolve_max_rational_field_periods(max_rational_field_periods)
    policy = _resolve_rational_surface_policy(rational_surface_policy)
    if work_limit is not None:
        # Host-side preflight: refuse surfaces whose rational correction would
        # need more field periods than allowed, before any tracing happens.
        for local_idx, surf_idx in enumerate(surf_indices):
            iota_val = float(booz.iota[surf_idx])
            work = _estimate_rational_work(
                iota_val,
                control.acc_req,
                nstep_per=control.nstep_per,
                npart=control.npart,
                multra=control.multra,
            )
            if work["field_periods"] <= work_limit:
                continue
            if policy == "approximate":
                raise RuntimeError(
                    "JAX surface scan does not support rational_surface_policy='approximate'. "
                    "Call run_neo(..., jax_surface_scan=False, rational_surface_policy='approximate') instead."
                )
            _raise_rational_work_limit(
                flux_index=flux_indices[local_idx],
                s_val=float(booz.es[surf_idx]),
                iota=iota_val,
                work=work,
                work_limit=work_limit,
            )

    surf_indices_j = jnp.asarray(surf_indices, dtype=jnp.int32)
    flux_indices_j = jnp.asarray(flux_indices, dtype=jnp.int32)
    rt0, bmref_g = compute_reference_jax(booz)
    params = FlintParams(
        npart=control.npart,
        multra=control.multra,
        nstep_per=control.nstep_per,
        nstep_min=control.nstep_min,
        nstep_max=control.nstep_max,
        acc_req=control.acc_req,
        no_bins=control.no_bins,
        calc_nstep_max=control.calc_nstep_max,
    )

    def _solve_surface(surf_idx):
        """Solve one surface; vmapped over ``surf_indices_j``."""
        coeffs = {
            "rmnc": booz.rmnc[surf_idx],
            "zmns": booz.zmns[surf_idx],
            "lmns": booz.lmns[surf_idx],
            "bmnc": booz.bmnc[surf_idx],
        }
        surface = init_surface(
            grid["theta_arr"],
            grid["phi_arr"],
            coeffs,
            booz.ixm,
            booz.ixn,
            nfp=booz.nfp,
            max_m_mode=max_m_mode,
            max_n_mode=max_n_mode,
            curr_pol=booz.curr_pol[surf_idx],
            curr_tor=booz.curr_tor[surf_idx],
            iota=booz.iota[surf_idx],
            grid=grid,
            use_jax=True,
            skip_mask=skip_fourier_mask,
        )
        env = RhsEnv(
            splines=surface.splines,
            grid=grid,
            eta=jnp.array([0.0]),
            bmod0=surface.bmref,
            iota=booz.iota[surf_idx],
            curr_pol=booz.curr_pol[surf_idx],
            curr_tor=booz.curr_tor[surf_idx],
        )
        out = flint_bo_jax(surface, params, env, nfp=booz.nfp, rt0=rt0)
        # ref_swi == 1: B00 of the first stored surface; ref_swi == 2: this
        # surface's own bmref. ref_swi was validated on entry.
        b_ref = bmref_g if control.ref_swi == 1 else surface.bmref
        r_ref = rt0
        scale = (b_ref / surface.bmref) ** 2 * (r_ref / rt0) ** 2
        epstot = out["epstot"] * scale
        epspar = out["epspar"] * scale
        return (
            epstot,
            epspar,
            out["ctrone"],
            out["ctrtot"],
            out["bareph"],
            out["barept"],
            out["yps"],
            out["drdpsi"],
            surface.bmref,
            booz.es[surf_idx],
            booz.iota[surf_idx],
            b_ref,
            r_ref,
        )

    (
        epstot,
        epspar,
        ctrone,
        ctrtot,
        bareph,
        barept,
        yps,
        drdpsi,
        bmref,
        s_vals,
        iota_vals,
        b_ref,
        r_ref,
    ) = jax.vmap(_solve_surface)(surf_indices_j)

    # Same accumulation as the Python driver: first surface uses s itself,
    # later ones the difference to the previous requested surface.
    dpsi = jnp.concatenate([s_vals[:1], s_vals[1:] - s_vals[:-1]], axis=0)
    r_eff = jnp.cumsum(drdpsi * dpsi)
    diagnostics = {
        "s": s_vals,
        "r_eff": r_eff,
        "iota": iota_vals,
        "b_ref": b_ref,
        "r_ref": r_ref,
        "bareph": bareph,
        "barept": barept,
        "yps": yps,
        "flux_index": flux_indices_j,
        "rational_surface_policy": policy,
        "max_rational_field_periods": work_limit if work_limit is not None else 0,
        "approximation_used": jnp.zeros_like(s_vals, dtype=bool),
    }
    return NeoOutputs(
        eps_eff=epstot,
        eps_par=epspar,
        eps_tot=epstot,
        ctr_one=ctrone,
        ctr_tot=ctrtot,
        diagnostics=diagnostics,
    )
def _env_flag(name: str) -> bool:
    """Return True when environment variable *name* holds a truthy flag.

    Accepted values (case-insensitive, surrounding whitespace ignored):
    ``1``, ``true``, ``yes``, ``on``.
    """
    value = os.getenv(name, "").strip().lower()
    return value in {"1", "true", "yes", "on"}


def _env_optional_int(name: str) -> int | None:
    """Parse environment variable *name* as an int.

    Returns ``None`` when the variable is unset, blank, or ``<= 0``.

    Raises:
        ValueError: If the variable is set to a non-integer string.
    """
    raw = os.getenv(name)
    if raw is None or raw.strip() == "":
        return None
    value = int(raw)
    return None if value <= 0 else value


def _resolve_max_rational_field_periods(value: int | None) -> int | None:
    """Resolve the rational-work cap; ``None`` means "no cap".

    An explicit *value* wins (``<= 0`` disables the cap). Otherwise
    ``NEO_JAX_MAX_RATIONAL_FIELD_PERIODS`` is consulted with the same
    convention, and an unset or blank variable gives
    ``DEFAULT_MAX_RATIONAL_FIELD_PERIODS``.

    Raises:
        ValueError: If the environment variable is not an integer.
    """
    if value is not None:
        return None if value <= 0 else int(value)
    raw = os.getenv("NEO_JAX_MAX_RATIONAL_FIELD_PERIODS")
    if raw is None or raw.strip() == "":
        return DEFAULT_MAX_RATIONAL_FIELD_PERIODS
    env_value = int(raw)
    # Any non-positive value disables the cap, exactly like the argument.
    return None if env_value <= 0 else env_value


def _resolve_rational_surface_policy(value: str | None) -> str:
    """Normalise the rational-surface policy to ``"error"`` or ``"approximate"``.

    Falls back to ``NEO_JAX_RATIONAL_SURFACE_POLICY`` (default ``"error"``)
    when *value* is empty. ``"approx"`` and ``"loosen_acc_req"`` are aliases
    of ``"approximate"``.

    Raises:
        ValueError: For any other policy string; the message quotes the
            offending value, whether it came from the argument or the
            environment.
    """
    raw = value or os.getenv("NEO_JAX_RATIONAL_SURFACE_POLICY", "error")
    policy = raw.strip().lower()
    if policy in {"", "error"}:
        return "error"
    if policy in {"approximate", "approx", "loosen_acc_req"}:
        return "approximate"
    raise ValueError(
        "Unsupported rational_surface_policy="
        f"{raw!r}. Expected 'error' or 'approximate'."
    )


def _estimate_rational_work(iota: float, acc_req: float, *, nstep_per: int, npart: int, multra: int) -> Dict[str, float]:
    """Roughly estimate the cost of the legacy rational-surface correction.

    ``field_periods = ceil(1 / (acc_req * |iota|))``; ``substeps`` and
    ``eta_paths`` scale that by ``nstep_per`` and by ``npart * multra``.
    Both ``|iota|`` and ``acc_req`` are floored at 1e-16 to avoid division
    by zero.
    """
    abs_iota = abs(float(iota))
    safe_iota = max(abs_iota, 1.0e-16)
    safe_acc = max(float(acc_req), 1.0e-16)
    field_periods = int(np.ceil(1.0 / safe_acc / safe_iota))
    substeps = int(field_periods * max(int(nstep_per), 1))
    eta_paths = int(substeps * max(int(npart), 1) * max(int(multra), 1))
    return {
        "abs_iota": abs_iota,
        "field_periods": field_periods,
        "substeps": substeps,
        "eta_paths": eta_paths,
    }


def _surface_preflight_message(
    *,
    local_idx: int,
    total: int,
    flux_index: int,
    s_val: float,
    iota: float,
    nmodes: int,
    booz_nfp: int,
    control: ControlParams,
    bmref: float,
    b_min: float,
    b_max: float,
    work: Dict[str, float],
    work_limit: int | None,
    policy: str,
    approximation_note: str | None = None,
) -> list[str]:
    """Build the per-surface progress report printed before integration."""
    lines = [
        (
            f"NEO_JAX: surface {local_idx + 1}/{total} index={flux_index} "
            f"s={s_val:.6f} sqrt(s)={np.sqrt(max(s_val, 0.0)):.6f} iota={iota:.6e}"
        ),
        (
            "NEO_JAX: resolution "
            f"theta_n={control.theta_n} phi_n={control.phi_n} npart={control.npart} "
            f"multra={control.multra} nstep_per={control.nstep_per} "
            f"nstep_min={control.nstep_min} nstep_max={control.nstep_max}"
        ),
        (
            "NEO_JAX: geometry "
            f"nfp={booz_nfp} nmodes={nmodes} "
            f"B00={bmref:.6e} Bmin={b_min:.6e} Bmax={b_max:.6e}"
        ),
        (
            "NEO_JAX: preflight "
            f"approx_rational_field_periods={int(work['field_periods'])} "
            f"approx_substeps={int(work['substeps'])} "
            f"approx_eta_paths={int(work['eta_paths'])} "
            f"limit={work_limit if work_limit is not None else 'disabled'} "
            f"policy={policy}"
        ),
    ]
    if approximation_note is not None:
        lines.append(f"NEO_JAX: approximation {approximation_note}")
    return lines


def _raise_rational_work_limit(
    *,
    flux_index: int,
    s_val: float,
    iota: float,
    work: Dict[str, float],
    work_limit: int,
) -> None:
    """Raise a ``RuntimeError`` explaining the rational-work safeguard and how to override it."""
    raise RuntimeError(
        "NEO_JAX aborted before integration because the estimated rational-surface "
        f"correction is too large on surface index {flux_index} (s={s_val:.6f}, "
        f"iota={iota:.6e}). Estimated field periods={int(work['field_periods'])}, "
        f"substeps={int(work['substeps'])}, eta-path evaluations~={int(work['eta_paths'])}, "
        f"which exceeds max_rational_field_periods={work_limit}. "
        "This is not an infinite loop: the Boozer surface has very small |iota|, "
        "so the legacy NEO rational correction would require an enormous number of "
        "field periods. To override the safeguard entirely, set "
        "NeoConfig(max_rational_field_periods=0) or "
        "NEO_JAX_MAX_RATIONAL_FIELD_PERIODS=0. For a controlled fallback that "
        "skips the rational correction after the base integration, set "
        "NeoConfig(rational_surface_policy='approximate'). To reduce the "
        "workload, loosen acc_req, reduce the surface set, or avoid near-zero-iota "
        "surfaces."
    )


def _write_diagnostic_files(
    *,
    events: Sequence[Tuple[int, int, int, float]],
    meta: Dict[str, float],
    psi_ind: int,
    path_prefix: str = "",
) -> None:
    """Write ``diagnostic.dat`` (one row per event) and ``diagnostic_add.dat`` (run metadata).

    Both files are truncated (opened with ``"w"``).
    """
    diag_path = f"{path_prefix}diagnostic.dat"
    with open(diag_path, "w", encoding="utf-8") as handle:
        for i_idx, icount, ipa, add_on in events:
            handle.write(
                build_fortran_line(
                    (int(i_idx), int(icount), int(ipa)),
                    int_width=8,
                    reals=(float(add_on),),
                    real_width=20,
                    real_digits=10,
                    real_letter="E",
                )
                + "\n"
            )
    add_path = f"{path_prefix}diagnostic_add.dat"
    with open(add_path, "w", encoding="utf-8") as handle:
        handle.write(
            build_fortran_line(
                (int(psi_ind), int(meta["istepc"]), int(meta["npart"]), int(meta["max_class"])),
                int_width=8,
                reals=(
                    float(meta["b_min"]),
                    float(meta["b_max"]),
                    float(meta["bmref"]),
                    float(meta["coeps"]),
                    float(meta["y2"]),
                    float(meta["y3"]),
                ),
                real_width=20,
                real_digits=10,
                real_letter="E",
            )
            + "\n"
        )


def _write_diagnostic_bigint(
    *,
    bigint: Sequence[float],
    multra: int,
    hit_rat: int,
    nintfp: int,
    y2: float,
    y3: float,
    coeps: float,
    psi_ind: int,
    path_prefix: str = "",
) -> None:
    """Append one row to ``diagnostic_bigint.dat``.

    The row holds ``psi_ind``, ``multra``, ``hit_rat``, ``nintfp``, then
    ``y2``, ``y3``, ``coeps`` and every ``bigint`` entry.
    """
    diag_path = f"{path_prefix}diagnostic_bigint.dat"
    with open(diag_path, "a", encoding="utf-8") as handle:
        handle.write(
            build_fortran_line(
                (int(psi_ind), int(multra), int(hit_rat), int(nintfp)),
                int_width=8,
                reals=(float(y2), float(y3), float(coeps), *[float(val) for val in bigint]),
                real_width=20,
                real_digits=10,
                real_letter="E",
            )
            + "\n"
        )


class DiagnosticLogger:
    """Host-side sink for the JAX backend's diagnostic callbacks.

    Writes the same ``diagnostic*.dat`` files as the Python-loop backend:
    ``diagnostic.dat`` (events), ``diagnostic_add.dat``,
    ``diagnostic_bigint.dat``, ``diagnostic_first_trap.dat`` and
    ``diagnostic_snapshot.dat``, each prefixed by ``path_prefix``.
    """

    def __init__(self, *, path_prefix: str = "") -> None:
        self.path_prefix = path_prefix
        self.diagnostic_path = f"{path_prefix}diagnostic.dat"
        self.trap_path = f"{path_prefix}diagnostic_first_trap.dat"
        self.istepc = 0  # number of event rows written so far
        self.max_class = 0  # largest ipa value seen in an event
        self.first_trap_written = False
        self.snapshot_written = False
        # callback() appends, so start from an empty diagnostic.dat.
        with open(self.diagnostic_path, "w", encoding="utf-8") as handle:
            handle.write("")

    def callback(self, event_mask, icount, ipa, add_on) -> None:
        """Append one row per flagged entry of *event_mask* (1-based index)."""
        mask = np.asarray(event_mask)
        if not mask.any():
            return
        icount_np = np.asarray(icount)
        ipa_np = np.asarray(ipa)
        add_on_np = np.asarray(add_on)
        idxs = np.nonzero(mask)[0]
        with open(self.diagnostic_path, "a", encoding="utf-8") as handle:
            for idx in idxs:
                self.istepc += 1
                ipa_val = int(ipa_np[idx])
                self.max_class = max(self.max_class, ipa_val)
                handle.write(
                    build_fortran_line(
                        (int(idx + 1), int(icount_np[idx]), ipa_val),
                        int_width=8,
                        reals=(float(add_on_np[idx]),),
                        real_width=20,
                        real_digits=10,
                        real_letter="E",
                    )
                    + "\n"
                )

    def write_add(self, *, psi_ind: int, npart: int, meta: Dict[str, float]) -> None:
        """Write ``diagnostic_add.dat`` using the counters accumulated by :meth:`callback`."""
        add_path = f"{self.path_prefix}diagnostic_add.dat"
        with open(add_path, "w", encoding="utf-8") as handle:
            handle.write(
                build_fortran_line(
                    (int(psi_ind), int(self.istepc), int(npart), int(self.max_class)),
                    int_width=8,
                    reals=(
                        float(meta["b_min"]),
                        float(meta["b_max"]),
                        float(meta["bmref"]),
                        float(meta["coeps"]),
                        float(meta["y2"]),
                        float(meta["y3"]),
                    ),
                    real_width=20,
                    real_digits=10,
                    real_letter="E",
                )
                + "\n"
            )

    def write_bigint(
        self,
        *,
        psi_ind: int,
        multra: int,
        hit_rat: int,
        nintfp: int,
        y2: float,
        y3: float,
        coeps: float,
        bigint: Sequence[float],
    ) -> None:
        """Append a ``diagnostic_bigint.dat`` row under this logger's prefix."""
        _write_diagnostic_bigint(
            bigint=bigint,
            multra=multra,
            hit_rat=hit_rat,
            nintfp=nintfp,
            y2=y2,
            y3=y3,
            coeps=coeps,
            psi_ind=psi_ind,
            path_prefix=self.path_prefix,
        )

    def trap_callback(self, event_mask, isw, iswst, p_i, p_h, icount, ipa, phi, j, step_index) -> None:
        """Dump the full particle state the first time *event_mask* has any entry set."""
        if self.first_trap_written:
            return
        mask = np.asarray(event_mask)
        if not mask.any():
            return
        idxs = np.nonzero(mask)[0]
        first_idx = int(idxs[0])
        isw_np = np.asarray(isw)
        iswst_np = np.asarray(iswst)
        p_i_np = np.asarray(p_i)
        p_h_np = np.asarray(p_h)
        icount_np = np.asarray(icount)
        ipa_np = np.asarray(ipa)
        with open(self.trap_path, "w", encoding="utf-8") as handle:
            handle.write(f"# first_event_index={first_idx + 1}\n")
            handle.write(f"# phi={float(phi):.16e} n={int(step_index)} j1={int(j)}\n")
            handle.write("# columns: idx isw iswst icount ipa p_i p_h event_mask\n")
            for ii in range(p_i_np.shape[0]):
                handle.write(
                    f"{ii + 1:8d} {int(isw_np[ii]):8d} {int(iswst_np[ii]):8d}"
                    f" {int(icount_np[ii]):8d} {int(ipa_np[ii]):8d}"
                    f" {float(p_i_np[ii]):20.10e} {float(p_h_np[ii]):20.10e}"
                    f" {int(mask[ii]):8d}\n"
                )
        self.first_trap_written = True

    def snapshot_callback(self, isw, iswst, p_i, p_h, icount, ipa, phi, j, step_index) -> None:
        """Dump the particle state once; ``event_mask`` marks ``isw == 2`` and ``iswst == 1``."""
        if self.snapshot_written:
            return
        isw_np = np.asarray(isw)
        iswst_np = np.asarray(iswst)
        p_i_np = np.asarray(p_i)
        p_h_np = np.asarray(p_h)
        icount_np = np.asarray(icount)
        ipa_np = np.asarray(ipa)
        event_mask = (isw_np == 2) & (iswst_np == 1)
        # Build the path from the prefix directly so a prefix containing
        # "diagnostic_first_trap" is not rewritten as well.
        snapshot_path = f"{self.path_prefix}diagnostic_snapshot.dat"
        with open(snapshot_path, "w", encoding="utf-8") as handle:
            handle.write(f"# phi={float(phi):.16e} n={int(step_index)} j1={int(j)}\n")
            handle.write("# columns: idx isw iswst icount ipa p_i p_h event_mask\n")
            for ii in range(p_i_np.shape[0]):
                handle.write(
                    f"{ii + 1:8d} {int(isw_np[ii]):8d} {int(iswst_np[ii]):8d}"
                    f" {int(icount_np[ii]):8d} {int(ipa_np[ii]):8d}"
                    f" {float(p_i_np[ii]):20.10e} {float(p_h_np[ii]):20.10e}"
                    f" {int(event_mask[ii]):8d}\n"
                )
        self.snapshot_written = True


class ConvergenceLogger:
    """Host-side collector for convergence-history rows.

    ``rows`` is what run_neo_from_boozer stores as ``convergence_history``.
    When ``NEO_JAX_WRITE_IPMAX_DEBUG`` is set to a truthy flag
    (1/true/yes/on, any case), step traces go to ``diagnostic_ipmax_jax.dat``.
    Any other value that is not a false flag (0/false/no/off) is used as the
    trace file path.
    """

    def __init__(self) -> None:
        self._adimax = 0.0
        self._aditot = 0.0
        self._step_log_path: str | None = None
        step_log = os.getenv("NEO_JAX_WRITE_IPMAX_DEBUG", "").strip()
        step_log_flag = step_log.lower()
        if step_log and step_log_flag not in {"0", "false", "no", "off"}:
            # Compare case-insensitively so e.g. "TRUE" is a flag, not a path.
            self._step_log_path = (
                "diagnostic_ipmax_jax.dat" if step_log_flag in {"1", "true", "yes", "on"} else step_log
            )
            Path(self._step_log_path).write_text("", encoding="utf-8")
        self.rows: list[tuple[float, float, float, float, float]] = []

    def callback(self, row) -> None:
        """Record a device-computed convergence row verbatim (flattened to floats)."""
        vals = np.asarray(row, dtype=float).reshape(-1)
        self.rows.append(tuple(float(v) for v in vals))

    def period_callback(self, row) -> None:
        """Record a per-period row, replacing its last column with host-side ``aditot / y2``."""
        vals = np.asarray(row, dtype=float).reshape(-1)
        n_val, epstot, y3, y3npart, y2 = (float(v) for v in vals)
        ctrone = 0.0 if y2 == 0.0 else self._aditot / y2
        self.rows.append((n_val, epstot, y3, y3npart, ctrone))

    def step_callback(self, step_index, substep_index, ipmax, isw, ipa, p_i) -> None:
        """Update the adimax/aditot accumulators for one integration step."""
        ipmax_val = int(np.asarray(ipmax).reshape(()))
        isw_np = np.asarray(isw, dtype=int).reshape(-1)
        ipa_np = np.asarray(ipa, dtype=int).reshape(-1)
        p_i_np = np.asarray(p_i, dtype=float).reshape(-1)
        adimax_idx = 0
        # The last entry with isw == 2 and ipa == 1 wins.
        for idx, isw_val in enumerate(isw_np):
            if isw_val == 2 and ipa_np[idx] == 1:
                self._adimax = float(p_i_np[idx])
                adimax_idx = idx + 1
        if ipmax_val == 1:
            self._aditot += self._adimax
        if self._step_log_path is not None:
            with Path(self._step_log_path).open("a", encoding="utf-8") as handle:
                handle.write(
                    f"{int(np.asarray(step_index).reshape(())):8d} {int(np.asarray(substep_index).reshape(())):8d}"
                    f" {adimax_idx:8d} {self._adimax:20.10e} {self._aditot:20.10e}"
                    f" {float(p_i_np[-1]):20.10e}\n"
                )

    def reset(self) -> None:
        """Clear accumulators and collected rows."""
        self._adimax = 0.0
        self._aditot = 0.0
        self.rows = []


def run_neo_from_boozer(
    booz: BoozerData,
    control: ControlParams,
    *,
    use_jax: bool = True,
    progress: bool = False,
    extension: str | None = None,
    legacy_mode: bool = False,
    max_rational_field_periods: int | None = DEFAULT_MAX_RATIONAL_FIELD_PERIODS,
    rational_surface_policy: str | None = None,
) -> NeoResults:
    """Compute epsilon-effective (and optionally parallel current) surface by surface.

    Args:
        booz: Boozer spectrum.
        control: Run controls.
        use_jax: Use the JAX integrator (jitted unless diagnostics or
            ``NEO_JAX_DISABLE_JIT`` are active); otherwise use ``flint_bo``.
        progress: Print per-surface progress reports.
        extension: Forwarded to ``LegacyNeoWriter``; ``neolog`` rows are
            appended only when it is not ``None``.
        legacy_mode: Write legacy NEO output files and request
            ``strict_parity`` from the jitted integrator.
        max_rational_field_periods: Rational-work cap (see
            ``_resolve_max_rational_field_periods``).
        rational_surface_policy: ``"error"`` aborts on surfaces over the cap;
            ``"approximate"`` skips the rational correction there.

    Returns:
        ``NeoResults`` with one ``NeoSurfaceResult`` per requested surface.

    Raises:
        RuntimeError: If a surface exceeds the cap under the ``"error"`` policy.
        ValueError: For an unsupported ``control.ref_swi`` or policy string.
    """
    grid = prepare_grids(control.theta_n, control.phi_n, booz.nfp)
    legacy_writer = LegacyNeoWriter(extension=extension, progress=progress) if legacy_mode else None
    if legacy_writer is not None:
        legacy_writer.prepare_run()
        if control.write_output_files:
            legacy_writer.write_static_files(booz=booz, grid=grid)
        if control.calc_cur:
            legacy_writer.prepare_current_run(control.cur_file)

    max_m_mode = control.max_m_mode if control.max_m_mode > 0 else int(np.max(np.abs(booz.ixm)))
    max_n_mode = control.max_n_mode if control.max_n_mode > 0 else int(np.max(np.abs(booz.ixn)))
    if control.fluxs_arr:
        if booz.rmnc.shape[0] == len(control.fluxs_arr):
            # booz presumably already holds exactly the requested surfaces.
            surf_indices = list(range(booz.rmnc.shape[0]))
        else:
            # fluxs_arr entries are 1-based surface indices.
            surf_indices = [i - 1 for i in control.fluxs_arr]
    else:
        surf_indices = list(range(booz.rmnc.shape[0]))

    ref = compute_reference(booz)
    rt0 = ref["rt0"]
    bmref_g = ref["bmref_g"]
    params = FlintParams(
        npart=control.npart,
        multra=control.multra,
        nstep_per=control.nstep_per,
        nstep_min=control.nstep_min,
        nstep_max=control.nstep_max,
        acc_req=control.acc_req,
        no_bins=control.no_bins,
        calc_nstep_max=control.calc_nstep_max,
    )
    work_limit = _resolve_max_rational_field_periods(max_rational_field_periods)
    policy = _resolve_rational_surface_policy(rational_surface_policy)
    results: List[NeoSurfaceResult] = []
    r_eff = 0.0

    # Debug/diagnostic switches (environment variables override control flags).
    write_diagnostic = bool(control.write_diagnostic) or _env_flag("NEO_JAX_WRITE_DIAGNOSTIC")
    write_trap_debug = write_diagnostic or _env_flag("NEO_JAX_WRITE_TRAP_DEBUG")
    diag_backend = os.getenv("NEO_JAX_DIAGNOSTIC_BACKEND", "python").strip().lower()
    disable_jit = _env_flag("NEO_JAX_DISABLE_JIT")
    force_psi1 = _env_flag("NEO_JAX_DIAGNOSTIC_FORCE_PSI1")
    snapshot_n = os.getenv("NEO_JAX_SNAPSHOT_N")
    snapshot_j1 = os.getenv("NEO_JAX_SNAPSHOT_J1")
    diagnostic_snapshot = None
    if snapshot_n and snapshot_j1:
        try:
            diagnostic_snapshot = (int(snapshot_n), int(snapshot_j1))
        except ValueError:
            diagnostic_snapshot = None

    flint_bo_jax_fn = flint_bo_jax
    if use_jax and not write_diagnostic and not disable_jit:
        flint_bo_jax_fn = jax.jit(
            flint_bo_jax,
            static_argnames=(
                "params",
                "diagnostic_callback",
                "diagnostic_trap_callback",
                "diagnostic_snapshot",
                "diagnostic_snapshot_callback",
                "convergence_callback",
                "convergence_period_callback",
                "convergence_step_callback",
                "convergence_reset_callback",
                "strict_parity",
                "skip_rational_correction",
            ),
        )
    # Build the current-solver wrapper once instead of once per surface.
    flint_cur_fn = flint_cur_jax
    if use_jax and not disable_jit:
        flint_cur_fn = jax.jit(flint_cur_jax, static_argnames=("params",))

    for local_idx, surf_idx in enumerate(surf_indices):
        flux_index = control.fluxs_arr[local_idx] if control.fluxs_arr else surf_idx + 1
        coeffs = {
            "rmnc": jnp.asarray(booz.rmnc[surf_idx]),
            "zmns": jnp.asarray(booz.zmns[surf_idx]),
            "lmns": jnp.asarray(booz.lmns[surf_idx]),
            "bmnc": jnp.asarray(booz.bmnc[surf_idx]),
        }
        surface = init_surface(
            grid["theta_arr"],
            grid["phi_arr"],
            coeffs,
            jnp.asarray(booz.ixm),
            jnp.asarray(booz.ixn),
            nfp=booz.nfp,
            max_m_mode=max_m_mode,
            max_n_mode=max_n_mode,
            curr_pol=jnp.asarray(booz.curr_pol[surf_idx]),
            curr_tor=jnp.asarray(booz.curr_tor[surf_idx]),
            iota=jnp.asarray(booz.iota[surf_idx]),
            grid=grid,
            calc_cur=bool(control.calc_cur),
        )
        work = _estimate_rational_work(
            float(booz.iota[surf_idx]),
            control.acc_req,
            nstep_per=control.nstep_per,
            npart=control.npart,
            multra=control.multra,
        )
        surface_params = params
        skip_rational_correction = False
        approximation_note = None
        pending_limit_error = False
        if work_limit is not None and work["field_periods"] > work_limit:
            if policy == "approximate":
                skip_rational_correction = True
                approximation_note = (
                    "skipping the expensive rational-surface correction after the "
                    f"base integration because the estimated field periods "
                    f"({int(work['field_periods'])}) exceed the limit ({work_limit})"
                )
            else:
                pending_limit_error = True
        if progress:
            for line in _surface_preflight_message(
                local_idx=local_idx,
                total=len(surf_indices),
                flux_index=flux_index,
                s_val=float(booz.es[surf_idx]),
                iota=float(booz.iota[surf_idx]),
                nmodes=int(len(booz.ixm)),
                booz_nfp=int(booz.nfp),
                control=control,
                bmref=float(surface.bmref),
                b_min=float(surface.b_min),
                b_max=float(surface.b_max),
                work=work,
                work_limit=work_limit,
                policy=policy,
                approximation_note=approximation_note,
            ):
                print(line)
        if pending_limit_error:
            # Raised after the preflight report so the user sees the estimate first.
            _raise_rational_work_limit(
                flux_index=flux_index,
                s_val=float(booz.es[surf_idx]),
                iota=float(booz.iota[surf_idx]),
                work=work,
                work_limit=work_limit,
            )
        if legacy_writer is not None and control.write_output_files:
            legacy_writer.write_surface_files(surface.fields)

        env = RhsEnv(
            splines=surface.splines,
            grid=grid,
            eta=jnp.array([0.0]),
            bmod0=surface.bmref,
            iota=jnp.asarray(booz.iota[surf_idx]),
            curr_pol=jnp.asarray(booz.curr_pol[surf_idx]),
            curr_tor=jnp.asarray(booz.curr_tor[surf_idx]),
        )
        if use_jax:
            use_python_loop = False
            convergence_logger = None
            use_host_convergence = False
            if progress:
                if use_python_loop:
                    print("NEO_JAX: solving epsilon effective with the Python parity backend")
                else:
                    print("NEO_JAX: solving epsilon effective with the JAX backend")
            if write_diagnostic and diag_backend == "jax" and not use_python_loop:
                if progress:
                    print("NEO_JAX: write_diagnostic enabled; using JAX backend with diagnostic callback")
                logger = DiagnosticLogger()
                out = flint_bo_jax(
                    surface,
                    surface_params,
                    env,
                    nfp=booz.nfp,
                    rt0=rt0,
                    skip_rational_correction=skip_rational_correction,
                    diagnostic_callback=logger.callback,
                    diagnostic_trap_callback=logger.trap_callback if write_trap_debug else None,
                    # The JAX kernel gets (n, j1) shifted down by one; the
                    # Python flint_bo path receives the raw env values.
                    diagnostic_snapshot=(
                        (diagnostic_snapshot[0] - 1, diagnostic_snapshot[1] - 1)
                        if diagnostic_snapshot
                        else None
                    ),
                    diagnostic_snapshot_callback=logger.snapshot_callback if diagnostic_snapshot else None,
                )
                # Wait for the device so all host callbacks have fired.
                jnp.asarray(out["y2"]).block_until_ready()
                jnp.asarray(out["bigint"]).block_until_ready()
                etamin = float(surface.b_min / surface.bmref)
                etamax = float(surface.b_max / surface.bmref)
                heta = (etamax - etamin) / (surface_params.npart - 1)
                coeps = float(np.pi * rt0 * rt0 * heta / (8.0 * np.sqrt(2.0)))
                psi_ind_diag = int(local_idx + 1)
                if force_psi1 and (control.fluxs_arr is not None and len(control.fluxs_arr) == 1):
                    psi_ind_diag = 1
                logger.write_add(
                    psi_ind=psi_ind_diag,
                    npart=surface_params.npart,
                    meta={
                        "b_min": float(surface.b_min),
                        "b_max": float(surface.b_max),
                        "bmref": float(surface.bmref),
                        "coeps": coeps,
                        "y2": float(out["y2"]),
                        "y3": float(out["y3"]),
                    },
                )
                logger.write_bigint(
                    psi_ind=psi_ind_diag,
                    multra=surface_params.multra,
                    hit_rat=int(out["hit_rat"]),
                    nintfp=int(out["nintfp"]),
                    y2=float(out["y2"]),
                    y3=float(out["y3"]),
                    coeps=coeps,
                    bigint=np.asarray(out["bigint"]),
                )
            else:
                if write_diagnostic and progress:
                    print("NEO_JAX: write_diagnostic enabled; using Python-loop backend to emit diagnostic.dat")
                if write_diagnostic or use_python_loop:
                    out = flint_bo(
                        surface,
                        surface_params,
                        env,
                        nfp=booz.nfp,
                        rt0=rt0,
                        diagnostic=write_diagnostic,
                        diagnostic_trap=write_trap_debug,
                        diagnostic_snapshot=diagnostic_snapshot,
                        collect_convergence=bool(control.write_integrate),
                        skip_rational_correction=skip_rational_correction,
                    )
                else:
                    if control.write_integrate:
                        convergence_logger = ConvergenceLogger()
                        # A step-trace file requires the host-side period/step callbacks.
                        use_host_convergence = convergence_logger._step_log_path is not None
                    out = flint_bo_jax_fn(
                        surface,
                        surface_params,
                        env,
                        nfp=booz.nfp,
                        rt0=rt0,
                        skip_rational_correction=skip_rational_correction,
                        convergence_callback=None if convergence_logger is None or use_host_convergence else convergence_logger.callback,
                        convergence_period_callback=None if convergence_logger is None or not use_host_convergence else convergence_logger.period_callback,
                        convergence_step_callback=None if convergence_logger is None or not use_host_convergence else convergence_logger.step_callback,
                        convergence_reset_callback=None if convergence_logger is None else convergence_logger.reset,
                        strict_parity=legacy_mode,
                    )
                    jnp.asarray(out["y2"]).block_until_ready()
                    if convergence_logger is not None:
                        out["convergence_history"] = convergence_logger.rows
        else:
            out = flint_bo(
                surface,
                surface_params,
                env,
                nfp=booz.nfp,
                rt0=rt0,
                diagnostic=write_diagnostic,
                diagnostic_trap=write_trap_debug,
                diagnostic_snapshot=diagnostic_snapshot,
                collect_convergence=bool(control.write_integrate),
                skip_rational_correction=skip_rational_correction,
            )

        if control.ref_swi == 1:
            b_ref = bmref_g
            r_ref = rt0
        elif control.ref_swi == 2:
            b_ref = float(surface.bmref)
            r_ref = rt0
        else:
            raise ValueError(f"Unsupported ref_swi: {control.ref_swi}")
        scale = (b_ref / float(surface.bmref)) ** 2 * (r_ref / rt0) ** 2
        epstot = float(out["epstot"] * scale)
        epspar = np.asarray(out["epspar"]) * scale
        out["rational_surface_policy"] = policy
        out["max_rational_field_periods"] = work_limit if work_limit is not None else 0
        out["requested_acc_req"] = float(control.acc_req)
        out["effective_acc_req"] = float(surface_params.acc_req)
        out["approximation_used"] = bool(skip_rational_correction)
        out["approximation_note"] = approximation_note or ""
        out["estimated_rational_field_periods"] = int(work["field_periods"])
        if skip_rational_correction and progress:
            print(
                "NEO_JAX: approximate rational correction enabled "
                "(skipping expensive rational-surface correction). "
                "For the full exact legacy correction, rerun with "
                "max_rational_field_periods=0 or "
                "NEO_JAX_MAX_RATIONAL_FIELD_PERIODS=0."
            )

        # Accumulate r_eff: the first surface uses s itself, later ones the
        # difference to the previous requested surface.
        s = float(booz.es[surf_idx])
        if local_idx == 0:
            dpsi = s
        else:
            dpsi = s - float(booz.es[surf_indices[local_idx - 1]])
        r_eff = r_eff + float(out["drdpsi"] * dpsi)
        result = NeoSurfaceResult(
            flux_index=flux_index,
            s=s,
            r_eff=r_eff,
            iota=float(booz.iota[surf_idx]),
            b_ref=b_ref,
            r_ref=r_ref,
            epsilon_effective=epstot,
            epsilon_effective_by_class=epspar,
            ctrone=float(out.get("ctrone", 0.0)),
            ctrtot=float(out.get("ctrtot", 0.0)),
            bareph=float(out.get("bareph", 0.0)),
            barept=float(out.get("barept", 0.0)),
            yps=float(out.get("yps", 0.0)),
            diagnostics=out,
        )

        if write_diagnostic:
            # Only the Python flint_bo path returns these; the JAX diagnostic
            # path already wrote its files through DiagnosticLogger.
            diagnostic_events = out.pop("diagnostic_events", None)
            diagnostic_meta = out.pop("diagnostic_meta", None)
            psi_ind_diag = int(local_idx + 1)
            if force_psi1 and (control.fluxs_arr is not None and len(control.fluxs_arr) == 1):
                psi_ind_diag = 1
            if diagnostic_events is not None and diagnostic_meta is not None:
                _write_diagnostic_files(
                    events=diagnostic_events,
                    meta=diagnostic_meta,
                    psi_ind=psi_ind_diag,
                )
                _write_diagnostic_bigint(
                    bigint=np.asarray(out["bigint"]),
                    multra=params.multra,
                    hit_rat=int(out["hit_rat"]),
                    nintfp=int(out["nintfp"]),
                    y2=float(out["y2"]),
                    y3=float(out["y3"]),
                    coeps=float(diagnostic_meta["coeps"]),
                    psi_ind=psi_ind_diag,
                )
            if progress:
                print("NEO_JAX: wrote diagnostic.dat, diagnostic_add.dat, diagnostic_bigint.dat")

        if control.calc_cur:
            if progress:
                print("NEO_JAX: solving parallel current")
            cur_params = CurrentParams(
                npart_cur=control.npart_cur,
                alpha_cur=control.alpha_cur,
                nstep_per=control.nstep_per,
                nfp=booz.nfp,
                write_cur_inte=bool(control.write_cur_inte),
            )
            current_out = flint_cur_fn(surface, cur_params, env)
            out["current"] = current_out
            if legacy_writer is not None:
                legacy_writer.append_current(cur_file=control.cur_file, psi_ind=local_idx + 1, current_out=current_out)
                if control.write_cur_inte and current_out.get("history_rows") is not None:
                    legacy_writer.write_current_history(current_out["history_rows"])

        if legacy_writer is not None and control.write_integrate:
            convergence_history = out.get("convergence_history")
            if convergence_history is not None:
                legacy_writer.write_conver(convergence_history)
        if legacy_writer is not None and extension is not None:
            legacy_writer.append_neolog(psi_ind=local_idx + 1, out=out, epstot=epstot)
        results.append(result)
        if progress:
            print(
                f"NEO_JAX: epstot={result['epstot']:.6e} reff={result['reff']:.6e} iota={result['iota']:.6e}"
            )
    return NeoResults(results)


def run_neo_from_boozmn(
    boozmn_path: str,
    control: ControlParams,
    *,
    use_jax: bool = True,
    progress: bool = False,
    extension: str | None = None,
    legacy_mode: bool = False,
    max_rational_field_periods: int | None = DEFAULT_MAX_RATIONAL_FIELD_PERIODS,
    rational_surface_policy: str | None = None,
) -> NeoResults:
    """Read a ``boozmn`` file with ``read_boozmn`` and run :func:`run_neo_from_boozer` on it.

    The mode cut-offs and ``fluxs_arr`` from *control* are forwarded to the
    reader; all keyword options are forwarded unchanged to the driver.
    """
    booz = read_boozmn(
        boozmn_path,
        max_m_mode=control.max_m_mode,
        max_n_mode=control.max_n_mode,
        fluxs_arr=control.fluxs_arr,
        extension=extension,
    )
    return run_neo_from_boozer(
        booz,
        control,
        use_jax=use_jax,
        progress=progress,
        extension=extension,
        legacy_mode=legacy_mode,
        max_rational_field_periods=max_rational_field_periods,
        rational_surface_policy=rational_surface_policy,
    )