Source code for neo_jax.integrate

"""Integration routines for NEO_JAX."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple

import numpy as np

import jax
import jax.numpy as jnp

from .geometry import neo_eval

Array = jax.Array

NPQ = 4


[docs] @jax.tree_util.register_pytree_node_class @dataclass(frozen=True) class RhsState: isw: Array ipa: Array icount: Array ipmax: Array pard0: Array def tree_flatten(self): return (self.isw, self.ipa, self.icount, self.ipmax, self.pard0), None @classmethod def tree_unflatten(cls, aux, children): return cls(*children)
[docs] @jax.tree_util.register_pytree_node_class @dataclass(frozen=True) class RhsEnv: splines: dict grid: dict eta: Array bmod0: Array iota: Array curr_pol: Array | None = None curr_tor: Array | None = None def tree_flatten(self): return (self.splines, self.grid, self.eta, self.bmod0, self.iota, self.curr_pol, self.curr_tor), None @classmethod def tree_unflatten(cls, aux, children): return cls(*children)
[docs] def rhs_bo1(phi: Array, y: Array, state: RhsState, env: RhsEnv) -> Tuple[Array, RhsState]: """Right-hand side for the field-line ODE (port of rhs_bo1.f90).""" theta = y[0] bmod, gval, geodcu, pardeb, _qval = neo_eval( theta, phi, env.splines["b_spl"], env.splines["g_spl"], env.splines["k_spl"], env.splines["p_spl"], env.splines.get("q_spl"), env.grid, ) bmodm2 = 1.0 / (bmod * bmod) bmodm3 = bmodm2 / bmod bra = bmod / env.bmod0 ipass = jnp.where((pardeb * state.pard0 <= 0) & (pardeb > 0), 1, 0).astype(state.isw.dtype) ipmax = jnp.where( (state.ipmax == 0) & (pardeb * state.pard0 <= 0) & (pardeb < 0), 1, state.ipmax, ).astype(state.isw.dtype) pard0 = pardeb dery = jnp.zeros_like(y) dery = dery.at[0].set(env.iota) dery = dery.at[1].set(bmodm2) dery = dery.at[2].set(bmodm2 * gval) dery = dery.at[3].set(geodcu * bmodm3) eta_pos = env.eta > 0 eta_safe = jnp.where(eta_pos, env.eta, jnp.asarray(1.0, dtype=env.eta.dtype)) inv_eta = jnp.where(eta_pos, 1.0 / eta_safe, 0.0) sqeta = jnp.sqrt(eta_safe) inv_sqeta = jnp.where(eta_pos, 1.0 / sqeta, 0.0) subsq = 1.0 - bra * inv_eta mask = (subsq > 0) & eta_pos safe_subsq = jnp.where(mask, subsq, 0.0) sq = jnp.sqrt(safe_subsq) * bmodm2 p_i = jnp.where(mask, sq, 0.0) p_h = jnp.where(mask, sq * (4.0 / bra - inv_eta) * geodcu * inv_sqeta, 0.0) # Update particle state one_i = jnp.array(1, dtype=state.isw.dtype) two_i = jnp.array(2, dtype=state.isw.dtype) zero_i = jnp.array(0, dtype=state.isw.dtype) isw = jnp.where(mask, one_i, jnp.where(state.isw == 1, two_i, jnp.where(state.isw == 2, two_i, zero_i))) icount = state.icount + mask.astype(state.icount.dtype) ipa = state.ipa + ipass * mask.astype(state.ipa.dtype) dery = dery.at[NPQ : NPQ + env.eta.shape[0]].set(p_i) dery = dery.at[NPQ + env.eta.shape[0] : NPQ + 2 * env.eta.shape[0]].set(p_h) new_state = RhsState(isw=isw, ipa=ipa, icount=icount, ipmax=ipmax, pard0=pard0) return dery, new_state
[docs] def rk4_step(phi: Array, y: Array, state: RhsState, env: RhsEnv, h: float) -> Tuple[Array, Array, RhsState]: """Run a single RK4 step, threading RHS state (port of rk4d_bo1.f90).""" hh = h / 2.0 h6 = h / 6.0 k1, state1 = rhs_bo1(phi, y, state, env) y1 = y + hh * k1 k2, state2 = rhs_bo1(phi + hh, y1, state1, env) y2 = y + hh * k2 k3, state3 = rhs_bo1(phi + hh, y2, state2, env) y3 = y + h * k3 k4, state4 = rhs_bo1(phi + h, y3, state3, env) y_new = y + h6 * (k1 + k4 + 2.0 * (k2 + k3)) phi_new = phi + h return phi_new, y_new, state4
def _process_trapped( state: RhsState, iswst: Array, p_i: Array, p_h: Array, bigint: Array, adimax: Array, multra: int, ) -> Tuple[RhsState, Array, Array, Array, Array, Array]: mask2 = state.isw == 2 m_cl = jnp.clip(state.ipa, 1, multra).astype(state.ipa.dtype) def body(i, carry): bigint_acc, adimax_acc = carry def add_fn(carry): bigint_acc, adimax_acc = carry safe_pi = jnp.where(p_i[i] == 0, jnp.array(1.0, dtype=p_i.dtype), p_i[i]) add_on = (p_h[i] * p_h[i]) / safe_pi * iswst[i] idx = m_cl[i] - 1 bigint_acc = bigint_acc.at[idx].add(add_on) adimax_acc = jnp.where(state.ipa[i] == 1, p_i[i], adimax_acc) return bigint_acc, adimax_acc return jax.lax.cond(mask2[i], add_fn, lambda c: c, carry) bigint, adimax = jax.lax.fori_loop(0, p_i.shape[0], body, (bigint, adimax)) iswst = jnp.where(mask2, 1, iswst) p_h = jnp.where(mask2, 0.0, p_h) p_i = jnp.where(mask2, 0.0, p_i) zero_int = jnp.zeros_like(state.isw) isw = jnp.where(mask2, zero_int, state.isw) icount = jnp.where(mask2, zero_int, state.icount) ipa = jnp.where(mask2, zero_int, state.ipa) state = RhsState(isw, ipa, icount, state.ipmax, state.pard0) return state, iswst, p_i, p_h, bigint, adimax
[docs] @jax.tree_util.register_pytree_node_class @dataclass(frozen=True) class FlintParams: npart: int multra: int nstep_per: int nstep_min: int nstep_max: int acc_req: float no_bins: int calc_nstep_max: int def tree_flatten(self): return ( self.npart, self.multra, self.nstep_per, self.nstep_min, self.nstep_max, self.acc_req, self.no_bins, self.calc_nstep_max, ), None @classmethod def tree_unflatten(cls, aux, children): return cls(*children)
[docs] def flint_bo( surface, params: FlintParams, env: RhsEnv, nfp: int, rt0: float | None = None, *, Rmajor: float | None = None, diagnostic: bool = False, diagnostic_trap: bool = False, diagnostic_trap_path: str = "diagnostic_first_trap.dat", diagnostic_snapshot: tuple[int, int] | None = None, diagnostic_snapshot_path: str = "diagnostic_snapshot.dat", collect_convergence: bool = False, skip_rational_correction: bool = False, ): """Python-loop port of flint_bo.f90 (not yet JIT-optimized).""" if rt0 is None: if Rmajor is None: raise ValueError("Either rt0 or Rmajor must be provided.") rt0 = float(Rmajor) elif Rmajor is not None and float(Rmajor) != float(rt0): raise ValueError("rt0 and Rmajor must match if both are provided.") npart = params.npart multra = params.multra ndim = NPQ + 2 * npart # Particle grids etamin = surface.b_min / surface.bmref etamax = surface.b_max / surface.bmref heta = (etamax - etamin) / (npart - 1) etamin = etamin + heta / 2.0 eta = etamin + heta * jnp.arange(npart) # Override env with surface-specific eta and bmod0 env = RhsEnv( splines=env.splines, grid=env.grid, eta=eta, bmod0=surface.bmref, iota=env.iota, curr_pol=env.curr_pol, curr_tor=env.curr_tor, ) coeps = jnp.pi * rt0 * rt0 * heta / (8.0 * jnp.sqrt(2.0)) if env.curr_pol is None or env.curr_tor is None: j_iota_i = 0.0 else: j_iota_i = env.curr_pol + env.iota * env.curr_tor # Initial state y = jnp.zeros(ndim) y = y.at[0].set(surface.theta_bmax) phi = surface.phi_bmax state = RhsState( isw=jnp.zeros(npart, dtype=jnp.int32), ipa=jnp.zeros(npart, dtype=jnp.int32), icount=jnp.zeros(npart, dtype=jnp.int32), ipmax=jnp.array(0, dtype=jnp.int32), pard0=jnp.array(0.0), ) # Initialize pard0 _bmod, _gval, _geodcu, pard0, _qval = neo_eval( surface.theta_bmax, surface.phi_bmax, env.splines["b_spl"], env.splines["g_spl"], env.splines["k_spl"], env.splines["p_spl"], env.splines.get("q_spl"), env.grid, ) state = RhsState(state.isw, state.ipa, state.icount, state.ipmax, pard0) # Main accumulators bigint = jnp.zeros(multra) adimax = jnp.array(0.0) aditot = jnp.array(0.0) iswst = jnp.zeros(npart, dtype=jnp.int32) # Integration parameters nper = nfp hphi = 2.0 * jnp.pi / (params.nstep_per * nper) nstep_max_c = params.nstep_max exist_first_ratfl = 0 hit_rat = 0 nfp_rat = 0 nfl_rat = 0 delta_theta_rat = 0.0 theta0 = surface.theta_bmax phi0 = surface.phi_bmax theta_d_min = 2.0 * jnp.pi n_iota = 1 m_iota = 1 iota_bar_fp = 0.0 n_gap = 0 diagnostic_events = [] if diagnostic else None max_class = 0 trap_written = False snapshot_written = False convergence_history = [] if collect_convergence else None # Main loop for n in range(1, params.nstep_max + 1): for _j1 in range(1, params.nstep_per + 1): phi, y, state = rk4_step(phi, y, state, env, float(hphi)) # Process trapped particle contributions p_i = y[NPQ : NPQ + npart] p_h = y[NPQ + npart : NPQ + 2 * npart] if diagnostic: mask2 = np.asarray(state.isw == 2) if mask2.any(): iswst_np = np.asarray(iswst) event_mask = mask2 & (iswst_np == 1) if event_mask.any(): icount_np = np.asarray(state.icount) ipa_np = np.asarray(state.ipa) p_i_np = np.asarray(p_i) p_h_np = np.asarray(p_h) safe_pi = np.where(p_i_np == 0, 1.0, p_i_np) add_on_np = (p_h_np * p_h_np) / safe_pi * iswst_np idxs = np.nonzero(event_mask)[0] for idx in idxs: diagnostic_events.append( (int(idx + 1), int(icount_np[idx]), int(ipa_np[idx]), float(add_on_np[idx])) ) max_class = max(max_class, int(np.max(ipa_np[event_mask]))) if diagnostic_trap and not trap_written: first_idx = int(idxs[0]) with open(diagnostic_trap_path, "w", encoding="utf-8") as handle: handle.write(f"# first_event_index={first_idx + 1}\n") handle.write(f"# phi={float(phi):.16e} n={n} j1={_j1}\n") handle.write( "# columns: idx isw iswst icount ipa p_i p_h event_mask\n" ) isw_np = np.asarray(state.isw) for ii in range(p_i_np.shape[0]): handle.write( f"{ii + 1:8d} {int(isw_np[ii]):8d} {int(iswst_np[ii]):8d}" f" {int(icount_np[ii]):8d} {int(ipa_np[ii]):8d}" f" {float(p_i_np[ii]):20.10e} {float(p_h_np[ii]):20.10e}" f" {int(event_mask[ii]):8d}\n" ) trap_written = True if diagnostic_snapshot and not snapshot_written: snap_n, snap_j1 = diagnostic_snapshot if n == snap_n and _j1 == snap_j1: mask2 = np.asarray(state.isw == 2) iswst_np = np.asarray(iswst) event_mask = mask2 & (iswst_np == 1) icount_np = np.asarray(state.icount) ipa_np = np.asarray(state.ipa) p_i_np = np.asarray(p_i) p_h_np = np.asarray(p_h) isw_np = np.asarray(state.isw) with open(diagnostic_snapshot_path, "w", encoding="utf-8") as handle: handle.write(f"# phi={float(phi):.16e} n={n} j1={_j1}\n") handle.write("# columns: idx isw iswst icount ipa p_i p_h event_mask\n") for ii in range(p_i_np.shape[0]): handle.write( f"{ii + 1:8d} {int(isw_np[ii]):8d} {int(iswst_np[ii]):8d}" f" {int(icount_np[ii]):8d} {int(ipa_np[ii]):8d}" f" {float(p_i_np[ii]):20.10e} {float(p_h_np[ii]):20.10e}" f" {int(event_mask[ii]):8d}\n" ) snapshot_written = True state, iswst, p_i, p_h, bigint, adimax = _process_trapped( state, iswst, p_i, p_h, bigint, adimax, multra ) y = y.at[NPQ : NPQ + npart].set(p_i) y = y.at[NPQ + npart : NPQ + 2 * npart].set(p_h) if int(state.ipmax) == 1: aditot = aditot + adimax state = RhsState(state.isw, state.ipa, state.icount, jnp.array(0, dtype=jnp.int32), state.pard0) if collect_convergence and convergence_history is not None: epstot_check = 0.0 for m_cl in range(1, multra + 1): epspar_check = float(coeps * bigint[m_cl - 1] * y[1] / (y[2] * y[2])) epstot_check = epstot_check + epspar_check convergence_history.append( ( float(n), epstot_check, float(y[3]), float(y[NPQ + npart - 1] / y[1]), float(aditot / y[1]), ) ) # Rational surface detection theta = y[0] if n <= params.nstep_min: theta_rs = theta - theta0 if n == 1: theta_iota = theta_rs iota_bar_fp = float(theta_iota / (2.0 * jnp.pi)) m = int(jnp.floor(theta_rs / (2.0 * jnp.pi))) theta_rs = theta_rs - m * 2.0 * jnp.pi theta_d = theta_rs if theta_rs <= jnp.pi else theta_rs - 2.0 * jnp.pi if abs(theta_d) < abs(theta_d_min): theta_d_min = theta_d n_iota = n if theta_d >= 0: m_iota = m else: m_iota = m + 1 if n == params.nstep_min: theta_gap = 2.0 * jnp.pi / n_iota n_gap = int(n_iota * int(abs(theta_gap / theta_d_min))) if n_gap > params.nstep_min: nstep_max_c = n_gap else: nstep_max_c = int(n_gap * jnp.ceil(params.nstep_min / n_gap)) if nstep_max_c > params.nstep_max: hit_rat = 1 nfp_rat = int(jnp.ceil(1.0 / params.acc_req / iota_bar_fp)) if iota_bar_fp != 0 else 0 if nfp_rat % n_iota != 0: nfp_rat = nfp_rat + n_iota - (nfp_rat % n_iota) if nfp_rat >= params.nstep_min: exist_first_ratfl = 1 nstep_max_c = nfp_rat nfl_rat = int(jnp.ceil(params.no_bins / n_iota)) delta_theta_rat = float(theta_gap / (nfl_rat + 1)) if params.calc_nstep_max == 1: hit_rat = 0 if hit_rat == 1 and exist_first_ratfl == 0: break if params.calc_nstep_max == 0 and n == nstep_max_c: break nintfp = n y2 = y[1] y3 = y[2] y4 = y[3] y3npart = y[NPQ + npart - 1] # Rational surface correction if hit_rat == 1 and not skip_rational_correction: if exist_first_ratfl == 0: if collect_convergence and convergence_history is not None: convergence_history = [] bigint = jnp.zeros(multra) adimax = jnp.array(0.0) aditot = jnp.array(0.0) y2 = jnp.array(0.0) y3 = jnp.array(0.0) y4 = jnp.array(0.0) y3npart = jnp.array(0.0) for nfl in range(exist_first_ratfl, nfl_rat + 1): bigint_s = jnp.zeros(multra) adimax_s = jnp.array(0.0) aditot_s = jnp.array(0.0) iswst = jnp.zeros(npart, dtype=jnp.int32) state = RhsState( isw=jnp.zeros(npart, dtype=jnp.int32), ipa=jnp.zeros(npart, dtype=jnp.int32), icount=jnp.zeros(npart, dtype=jnp.int32), ipmax=jnp.array(0, dtype=jnp.int32), pard0=state.pard0, ) phi = phi0 y = jnp.zeros(ndim) theta = theta0 + nfl * delta_theta_rat y = y.at[0].set(theta) for _n in range(1, nfp_rat + 1): # Update pard0 at the start of each field period _bmod, _gval, _geodcu, pard0, _qval = neo_eval( y[0], phi, env.splines["b_spl"], env.splines["g_spl"], env.splines["k_spl"], env.splines["p_spl"], env.splines.get("q_spl"), env.grid, ) state = RhsState(state.isw, state.ipa, state.icount, state.ipmax, pard0) for _j1 in range(1, params.nstep_per + 1): phi, y, state = rk4_step(phi, y, state, env, float(hphi)) p_i = y[NPQ : NPQ + npart] p_h = y[NPQ + npart : NPQ + 2 * npart] state, iswst, p_i, p_h, bigint_s, adimax_s = _process_trapped( state, iswst, p_i, p_h, bigint_s, adimax_s, multra ) y = y.at[NPQ : NPQ + npart].set(p_i) y = y.at[NPQ + npart : NPQ + 2 * npart].set(p_h) if int(state.ipmax) == 1: aditot_s = aditot_s + adimax_s state = RhsState(state.isw, state.ipa, state.icount, jnp.array(0, dtype=jnp.int32), state.pard0) if collect_convergence and convergence_history is not None: epstot_check = 0.0 for m_cl in range(1, multra + 1): epspar_check = float(coeps * bigint_s[m_cl - 1] * y[1] / (y[2] * y[2])) epstot_check = epstot_check + epspar_check convergence_history.append( ( float(nfl * nfp_rat + _n), epstot_check, float(y[3]), float(y[NPQ + npart - 1] / y[1]), float(aditot_s / y[1]), ) ) y2_s = y[1] y3_s = y[2] y4_s = y[3] y3npart_s = y[NPQ + npart - 1] bigint = bigint + bigint_s aditot = aditot + aditot_s y2 = y2 + y2_s y3 = y3 + y3_s y4 = y4 + y4_s y3npart = y3npart + y3npart_s n = nfp_rat * (nfl_rat + 1) # Final results epspar = jnp.zeros(multra) epstot = jnp.array(0.0) for m_cl in range(1, multra + 1): epspar = epspar.at[m_cl - 1].set(coeps * bigint[m_cl - 1] * y2 / (y3 * y3)) epstot = epstot + epspar[m_cl - 1] ctrone = aditot / y2 ctrtot = y3npart / y2 bareph = (jnp.pi * ctrone) ** 2 / 8.0 barept = (jnp.pi * ctrtot) ** 2 / 8.0 drdpsi = y2 / y3 yps = y4 * j_iota_i out = { "epspar": epspar, "epstot": epstot, "ctrone": ctrone, "ctrtot": ctrtot, "bareph": bareph, "barept": barept, "drdpsi": drdpsi, "yps": yps, "y2": y2, "y3": y3, "y4": y4, "y3npart": y3npart, "bigint": bigint, "nintfp": nintfp, "hit_rat": hit_rat, "n_iota": n_iota, "m_iota": m_iota, "n_gap": n_gap, "final_n": n, "nfp_rat": nfp_rat, "nfl_rat": nfl_rat, } if collect_convergence and convergence_history is not None: out["convergence_history"] = convergence_history if diagnostic and diagnostic_events is not None: out["diagnostic_events"] = diagnostic_events out["diagnostic_meta"] = { "istepc": len(diagnostic_events), "max_class": max_class, "b_min": float(surface.b_min), "b_max": float(surface.b_max), "bmref": float(surface.bmref), "coeps": float(coeps), "y2": float(y2), "y3": float(y3), "npart": int(npart), } return out
[docs] def flint_bo_jax( surface, params: FlintParams, env: RhsEnv, nfp: int, rt0: float | None = None, *, Rmajor: float | None = None, diagnostic_callback=None, diagnostic_trap_callback=None, diagnostic_snapshot: tuple[int, int] | None = None, diagnostic_snapshot_callback=None, convergence_callback=None, convergence_period_callback=None, convergence_step_callback=None, convergence_reset_callback=None, strict_parity: bool = False, skip_rational_correction: bool = False, ): """JAX-friendly integration loop with rational-surface correction.""" if rt0 is None: if Rmajor is None: raise ValueError("Either rt0 or Rmajor must be provided.") rt0 = float(Rmajor) elif Rmajor is not None and float(Rmajor) != float(rt0): raise ValueError("rt0 and Rmajor must match if both are provided.") npart = params.npart multra = params.multra ndim = NPQ + 2 * npart # Particle grids etamin = surface.b_min / surface.bmref etamax = surface.b_max / surface.bmref heta = (etamax - etamin) / (npart - 1) etamin = etamin + heta / 2.0 eta = etamin + heta * jnp.arange(npart) env = RhsEnv( splines=env.splines, grid=env.grid, eta=eta, bmod0=surface.bmref, iota=env.iota, curr_pol=env.curr_pol, curr_tor=env.curr_tor, ) coeps = jnp.pi * rt0 * rt0 * heta / (8.0 * jnp.sqrt(2.0)) if env.curr_pol is None or env.curr_tor is None: j_iota_i = 0.0 else: j_iota_i = env.curr_pol + env.iota * env.curr_tor y = jnp.zeros(ndim) y = y.at[0].set(surface.theta_bmax) phi0 = surface.phi_bmax phi = phi0 state = RhsState( isw=jnp.zeros(npart, dtype=jnp.int32), ipa=jnp.zeros(npart, dtype=jnp.int32), icount=jnp.zeros(npart, dtype=jnp.int32), ipmax=jnp.array(0, dtype=jnp.int32), pard0=jnp.array(0.0), ) _bmod, _gval, _geodcu, pard0, _qval = neo_eval( surface.theta_bmax, surface.phi_bmax, env.splines["b_spl"], env.splines["g_spl"], env.splines["k_spl"], env.splines["p_spl"], env.splines.get("q_spl"), env.grid, ) state = RhsState(state.isw, state.ipa, state.icount, state.ipmax, pard0) iswst = jnp.zeros(npart, dtype=jnp.int32) bigint = jnp.zeros(multra) adimax = jnp.array(0.0) aditot = jnp.array(0.0) hphi = 2.0 * jnp.pi / (params.nstep_per * nfp) theta0 = surface.theta_bmax theta_d_min = 2.0 * jnp.pi n_iota = jnp.array(1, dtype=jnp.int32) m_iota = jnp.array(1, dtype=jnp.int32) iota_bar_fp = jnp.array(0.0) nstep_max_c = jnp.array(params.nstep_max, dtype=jnp.int32) hit_rat = jnp.array(0, dtype=jnp.int32) exist_first_ratfl = jnp.array(0, dtype=jnp.int32) nfp_rat = jnp.array(0, dtype=jnp.int32) nfl_rat = jnp.array(0, dtype=jnp.int32) delta_theta_rat = jnp.array(0.0) n_gap = jnp.array(0, dtype=jnp.int32) stop = jnp.array(False) n = jnp.array(0, dtype=jnp.int32) if diagnostic_snapshot is None: snap_n = None snap_j = None else: snap_n = jnp.array(diagnostic_snapshot[0], dtype=jnp.int32) snap_j = jnp.array(diagnostic_snapshot[1], dtype=jnp.int32) def integrate_period(carry, emit_diag: bool, step_index): phi, y, state, iswst, bigint, adimax, aditot = carry def rhs_inline(phi_local: Array, y_local: Array, state_local: RhsState) -> Tuple[Array, RhsState]: """Inline RHS evaluation to help XLA fuse neo_eval + RK4.""" theta = y_local[0] bmod, gval, geodcu, pardeb, _qval = neo_eval( theta, phi_local, env.splines["b_spl"], env.splines["g_spl"], env.splines["k_spl"], env.splines["p_spl"], env.splines.get("q_spl"), env.grid, ) bmodm2 = 1.0 / (bmod * bmod) bmodm3 = bmodm2 / bmod bra = bmod / env.bmod0 ipass = jnp.where( (pardeb * state_local.pard0 <= 0) & (pardeb > 0), 1, 0 ).astype(state_local.isw.dtype) ipmax = jnp.where( (state_local.ipmax == 0) & (pardeb * state_local.pard0 <= 0) & (pardeb < 0), 1, state_local.ipmax, ).astype(state_local.isw.dtype) pard0 = pardeb dery = jnp.zeros_like(y_local) dery = dery.at[0].set(env.iota) dery = dery.at[1].set(bmodm2) dery = dery.at[2].set(bmodm2 * gval) dery = dery.at[3].set(geodcu * bmodm3) eta_pos = env.eta > 0 eta_safe = jnp.where(eta_pos, env.eta, jnp.asarray(1.0, dtype=env.eta.dtype)) inv_eta = jnp.where(eta_pos, 1.0 / eta_safe, 0.0) sqeta = jnp.sqrt(eta_safe) inv_sqeta = jnp.where(eta_pos, 1.0 / sqeta, 0.0) subsq = 1.0 - bra * inv_eta mask = (subsq > 0) & eta_pos safe_subsq = jnp.where(mask, subsq, 0.0) sq = jnp.sqrt(safe_subsq) * bmodm2 p_i = jnp.where(mask, sq, 0.0) p_h = jnp.where(mask, sq * (4.0 / bra - inv_eta) * geodcu * inv_sqeta, 0.0) one_i = jnp.array(1, dtype=state_local.isw.dtype) two_i = jnp.array(2, dtype=state_local.isw.dtype) zero_i = jnp.array(0, dtype=state_local.isw.dtype) isw = jnp.where( mask, one_i, jnp.where( state_local.isw == 1, two_i, jnp.where(state_local.isw == 2, two_i, zero_i), ), ) icount = state_local.icount + mask.astype(state_local.icount.dtype) ipa = state_local.ipa + ipass * mask.astype(state_local.ipa.dtype) dery = dery.at[NPQ : NPQ + env.eta.shape[0]].set(p_i) dery = dery.at[NPQ + env.eta.shape[0] : NPQ + 2 * env.eta.shape[0]].set(p_h) new_state = RhsState(isw=isw, ipa=ipa, icount=icount, ipmax=ipmax, pard0=pard0) return dery, new_state def rk4_step_inline( phi_local: Array, y_local: Array, state_local: RhsState ) -> Tuple[Array, Array, RhsState]: """Inline RK4 to increase fusion with neo_eval.""" hh = hphi / 2.0 h6 = hphi / 6.0 k1, state1 = rhs_inline(phi_local, y_local, state_local) y1 = y_local + hh * k1 k2, state2 = rhs_inline(phi_local + hh, y1, state1) y2 = y_local + hh * k2 k3, state3 = rhs_inline(phi_local + hh, y2, state2) y3 = y_local + hphi * k3 k4, state4 = rhs_inline(phi_local + hphi, y3, state3) y_new = y_local + h6 * (k1 + k4 + 2.0 * (k2 + k3)) phi_new = phi_local + hphi return phi_new, y_new, state4 def inner_step(j, inner): phi, y, state, iswst, bigint, adimax, aditot = inner if strict_parity: phi, y, state = rk4_step(phi, y, state, env, hphi) else: phi, y, state = rk4_step_inline(phi, y, state) p_i = y[NPQ : NPQ + npart] p_h = y[NPQ + npart : NPQ + 2 * npart] if diagnostic_callback is not None and emit_diag: mask2 = state.isw == 2 event_mask = mask2 & (iswst == 1) safe_pi = jnp.where(p_i == 0, jnp.array(1.0, dtype=p_i.dtype), p_i) add_on = (p_h * p_h) / safe_pi * iswst def _emit(_): jax.debug.callback(diagnostic_callback, event_mask, state.icount, state.ipa, add_on, ordered=True) return None _ = jax.lax.cond(jnp.any(event_mask), _emit, lambda _: None, operand=None) if diagnostic_trap_callback is not None and emit_diag: mask2 = state.isw == 2 event_mask = mask2 & (iswst == 1) def _emit_trap(_): jax.debug.callback( diagnostic_trap_callback, event_mask, state.isw, iswst, p_i, p_h, state.icount, state.ipa, phi, j, step_index, ordered=True, ) return None _ = jax.lax.cond(jnp.any(event_mask), _emit_trap, lambda _: None, operand=None) if diagnostic_snapshot_callback is not None and emit_diag and snap_n is not None and snap_j is not None: def _emit_snap(_): jax.debug.callback( diagnostic_snapshot_callback, state.isw, iswst, p_i, p_h, state.icount, state.ipa, phi, j, step_index, ordered=True, ) return None snap_cond = (step_index == snap_n) & (j == snap_j) _ = jax.lax.cond(snap_cond, _emit_snap, lambda _: None, operand=None) if convergence_step_callback is not None: jax.debug.callback( convergence_step_callback, step_index + 1, j + 1, state.ipmax, state.isw, state.ipa, p_i, ordered=True, ) if strict_parity: state, iswst, p_i, p_h, bigint, adimax = _process_trapped( state, iswst, p_i, p_h, bigint, adimax, multra ) else: # Inline trapped-particle update to improve fusion in scan body. mask2 = state.isw == 2 m_cl = jnp.clip(state.ipa, 1, multra).astype(state.ipa.dtype) def body(i, carry): bigint_acc, adimax_acc = carry def add_fn(carry): bigint_acc, adimax_acc = carry safe_pi = jnp.where(p_i[i] == 0, jnp.array(1.0, dtype=p_i.dtype), p_i[i]) add_on = (p_h[i] * p_h[i]) / safe_pi * iswst[i] idx = m_cl[i] - 1 bigint_acc = bigint_acc.at[idx].add(add_on) adimax_acc = jnp.where(state.ipa[i] == 1, p_i[i], adimax_acc) return bigint_acc, adimax_acc return jax.lax.cond(mask2[i], add_fn, lambda c: c, carry) bigint, adimax = jax.lax.fori_loop(0, p_i.shape[0], body, (bigint, adimax)) iswst = jnp.where(mask2, 1, iswst) p_h = jnp.where(mask2, 0.0, p_h) p_i = jnp.where(mask2, 0.0, p_i) zero_int = jnp.zeros_like(state.isw) isw = jnp.where(mask2, zero_int, state.isw) icount = jnp.where(mask2, zero_int, state.icount) ipa = jnp.where(mask2, zero_int, state.ipa) state = RhsState(isw, ipa, icount, state.ipmax, state.pard0) y = y.at[NPQ : NPQ + npart].set(p_i) y = y.at[NPQ + npart : NPQ + 2 * npart].set(p_h) aditot = jnp.where(state.ipmax == 1, aditot + adimax, aditot) ipmax = jnp.where( state.ipmax == 1, jnp.array(0, dtype=state.ipmax.dtype), state.ipmax ) state = RhsState(state.isw, state.ipa, state.icount, ipmax, state.pard0) return (phi, y, state, iswst, bigint, adimax, aditot) return jax.lax.fori_loop(0, params.nstep_per, inner_step, carry) def convergence_row(n_val, bigint_val, y_val, aditot_val): epspar_check = coeps * bigint_val * y_val[1] / (y_val[2] * y_val[2]) epstot_check = jnp.sum(epspar_check) return jnp.asarray( [ n_val.astype(y_val.dtype), epstot_check, y_val[3], y_val[NPQ + npart - 1] / y_val[1], aditot_val / y_val[1], ] ) def emit_convergence(n_val, bigint_val, y_val, aditot_val): if convergence_callback is not None: row = convergence_row(n_val, bigint_val, y_val, aditot_val) jax.debug.callback(convergence_callback, row, ordered=True) if convergence_period_callback is not None: epspar_check = coeps * bigint_val * y_val[1] / (y_val[2] * y_val[2]) epstot_check = jnp.sum(epspar_check) period_row = jnp.asarray( [ n_val.astype(y_val.dtype), epstot_check, y_val[3], y_val[NPQ + npart - 1] / y_val[1], y_val[1], ] ) jax.debug.callback(convergence_period_callback, period_row, ordered=True) def update_theta_min(n_val, theta, theta_d_min, n_iota, m_iota, iota_bar_fp): twopi = 2.0 * jnp.pi def body(_): theta_rs = theta - theta0 iota_bar_fp_new = jnp.where(n_val == 1, theta_rs / twopi, iota_bar_fp) m = jnp.floor(theta_rs / twopi) theta_rs_mod = theta_rs - m * twopi theta_d = jnp.where(theta_rs_mod <= jnp.pi, theta_rs_mod, theta_rs_mod - twopi) update = jnp.abs(theta_d) < jnp.abs(theta_d_min) theta_d_min_new = jnp.where(update, theta_d, theta_d_min) n_iota_new = jnp.where(update, n_val, n_iota) m_iota_new = jnp.where( update, jnp.where(theta_d >= 0, m.astype(jnp.int32), (m + 1).astype(jnp.int32)), m_iota ) return theta_d_min_new, n_iota_new, m_iota_new, iota_bar_fp_new return jax.lax.cond( n_val <= params.nstep_min, body, lambda _: (theta_d_min, n_iota, m_iota, iota_bar_fp), operand=None ) def update_rational( n_val, theta_d_min, n_iota, iota_bar_fp, nstep_max_c, hit_rat, exist_first_ratfl, nfp_rat, nfl_rat, delta_theta_rat, n_gap, ): twopi = 2.0 * jnp.pi def body(_): theta_d_min_safe = jnp.where(theta_d_min == 0.0, 1.0e-12, theta_d_min) theta_gap = twopi / n_iota n_gap_new = n_iota * jnp.floor(jnp.abs(theta_gap / theta_d_min_safe)).astype(jnp.int32) nstep_max_c_new = jnp.where( n_gap_new > params.nstep_min, n_gap_new, n_gap_new * jnp.ceil(params.nstep_min / n_gap_new).astype(jnp.int32), ) hit_rat_new = jnp.where(nstep_max_c_new > params.nstep_max, 1, 0).astype(jnp.int32) nfp_rat_new = jnp.where( (hit_rat_new == 1) & (iota_bar_fp != 0.0), jnp.ceil(1.0 / params.acc_req / iota_bar_fp).astype(jnp.int32), 0, ) nfp_rat_new = jnp.where( (hit_rat_new == 1) & (nfp_rat_new % n_iota != 0), nfp_rat_new + n_iota - (nfp_rat_new % n_iota), nfp_rat_new, ) exist_first_ratfl_new = jnp.where(nfp_rat_new >= params.nstep_min, 1, 0).astype(jnp.int32) nstep_max_c_new = jnp.where(exist_first_ratfl_new == 1, nfp_rat_new, nstep_max_c_new) nfl_rat_new = jnp.where( hit_rat_new == 1, jnp.ceil(params.no_bins / n_iota).astype(jnp.int32), nfl_rat, ) delta_theta_rat_new = jnp.where( hit_rat_new == 1, theta_gap / (nfl_rat_new + 1), delta_theta_rat, ) hit_rat_new = jnp.where(params.calc_nstep_max == 1, 0, hit_rat_new).astype(jnp.int32) return ( nstep_max_c_new, hit_rat_new, exist_first_ratfl_new, nfp_rat_new, nfl_rat_new, delta_theta_rat_new, n_gap_new, ) return jax.lax.cond( n_val == params.nstep_min, body, lambda _: (nstep_max_c, hit_rat, exist_first_ratfl, nfp_rat, nfl_rat, delta_theta_rat, n_gap), operand=None, ) def scan_body(carry, _): ( phi, y, state, iswst, bigint, adimax, aditot, n, theta_d_min, n_iota, m_iota, iota_bar_fp, nstep_max_c, hit_rat, exist_first_ratfl, nfp_rat, nfl_rat, delta_theta_rat, n_gap, stop, ) = carry def do_step(_carry): ( phi, y, state, iswst, bigint, adimax, aditot, n, theta_d_min, n_iota, m_iota, iota_bar_fp, nstep_max_c, hit_rat, exist_first_ratfl, nfp_rat, nfl_rat, delta_theta_rat, n_gap, stop, ) = _carry phi, y, state, iswst, bigint, adimax, aditot = integrate_period( (phi, y, state, iswst, bigint, adimax, aditot), True, n, ) n_new = n + 1 emit_convergence(n_new.astype(y.dtype), bigint, y, aditot) theta = y[0] theta_d_min_new, n_iota_new, m_iota_new, iota_bar_fp_new = update_theta_min( n_new, theta, theta_d_min, n_iota, m_iota, iota_bar_fp ) ( nstep_max_c_new, hit_rat_new, exist_first_ratfl_new, nfp_rat_new, nfl_rat_new, delta_theta_rat_new, n_gap_new, ) = update_rational( n_new, theta_d_min_new, n_iota_new, iota_bar_fp_new, nstep_max_c, hit_rat, exist_first_ratfl, nfp_rat, nfl_rat, delta_theta_rat, n_gap, ) stop_new = jnp.where( (params.calc_nstep_max == 0) & (n_new == nstep_max_c_new), True, False, ) stop_new = jnp.where( (hit_rat_new == 1) & (exist_first_ratfl_new == 0), True, stop_new ) return ( phi, y, state, iswst, bigint, adimax, aditot, n_new, theta_d_min_new, n_iota_new, m_iota_new, iota_bar_fp_new, nstep_max_c_new, hit_rat_new, exist_first_ratfl_new, nfp_rat_new, nfl_rat_new, delta_theta_rat_new, n_gap_new, stop_new, ) return jax.lax.cond(stop, lambda c: c, do_step, carry), None init_carry = ( phi, y, state, iswst, bigint, adimax, aditot, n, theta_d_min, n_iota, m_iota, iota_bar_fp, nstep_max_c, hit_rat, exist_first_ratfl, nfp_rat, nfl_rat, delta_theta_rat, n_gap, stop, ) final_carry, _ = jax.lax.scan(scan_body, init_carry, xs=None, length=params.nstep_max) ( phi, y, state, iswst, bigint, adimax, aditot, n, theta_d_min, n_iota, m_iota, iota_bar_fp, nstep_max_c, hit_rat, exist_first_ratfl, nfp_rat, nfl_rat, delta_theta_rat, n_gap, stop, ) = final_carry y2 = y[1] y3 = y[2] y4 = y[3] y3npart = y[NPQ + npart - 1] def rational_correction(_): if convergence_reset_callback is not None: def _reset(_): jax.debug.callback(convergence_reset_callback, ordered=True) return None _ = jax.lax.cond(exist_first_ratfl == 0, _reset, lambda _: None, operand=None) def reset_accumulators(_): zero_bigint = jnp.zeros_like(bigint) return zero_bigint, jnp.array(0.0), jnp.array(0.0), jnp.array(0.0), jnp.array(0.0), jnp.array(0.0) def keep_accumulators(_): return bigint, aditot, y2, y3, y4, y3npart bigint0, aditot0, y20, y30, y40, y3npart0 = jax.lax.cond( exist_first_ratfl == 0, reset_accumulators, keep_accumulators, operand=None ) nfl_count = jnp.maximum(nfl_rat + 1 - exist_first_ratfl, 0).astype(jnp.int32) def nfl_cond(carry): idx, *_ = carry return idx < nfl_count def nfl_body(carry): idx, bigint_acc, aditot_acc, y2_acc, y3_acc, y4_acc, y3npart_acc = carry nfl = idx + exist_first_ratfl phi_local = phi0 y_local = jnp.zeros(ndim) y_local = y_local.at[0].set(theta0 + nfl * delta_theta_rat) state_local = RhsState( isw=jnp.zeros(npart, dtype=jnp.int32), ipa=jnp.zeros(npart, dtype=jnp.int32), icount=jnp.zeros(npart, dtype=jnp.int32), ipmax=jnp.array(0, dtype=jnp.int32), pard0=jnp.array(0.0), ) iswst_local = jnp.zeros(npart, dtype=jnp.int32) bigint_s = jnp.zeros(multra) adimax_s = jnp.array(0.0) aditot_s = jnp.array(0.0) def n_cond(ncarry): n_idx, *_ = ncarry return n_idx < nfp_rat def n_body(ncarry): n_idx, phi_l, y_l, state_l, iswst_l, bigint_s, adimax_s, aditot_s = ncarry _bmod, _gval, _geodcu, pard0, _qval = neo_eval( y_l[0], phi_l, env.splines["b_spl"], env.splines["g_spl"], env.splines["k_spl"], env.splines["p_spl"], env.splines.get("q_spl"), env.grid, ) state_l = RhsState(state_l.isw, state_l.ipa, state_l.icount, state_l.ipmax, pard0) phi_l, y_l, state_l, iswst_l, bigint_s, adimax_s, aditot_s = integrate_period( (phi_l, y_l, state_l, iswst_l, bigint_s, adimax_s, aditot_s), False, n_idx, ) emit_convergence((nfl * nfp_rat + n_idx + 1).astype(y_l.dtype), bigint_s, y_l, aditot_s) return (n_idx + 1, phi_l, y_l, state_l, iswst_l, bigint_s, adimax_s, aditot_s) n_init = ( jnp.array(0, dtype=jnp.int32), phi_local, y_local, state_local, iswst_local, bigint_s, adimax_s, aditot_s, ) n_final = jax.lax.while_loop(n_cond, n_body, n_init) _, phi_local, y_local, state_local, iswst_local, bigint_s, adimax_s, aditot_s = n_final y2_s = y_local[1] y3_s = y_local[2] y4_s = y_local[3] y3npart_s = y_local[NPQ + npart - 1] return ( idx + 1, bigint_acc + bigint_s, aditot_acc + aditot_s, y2_acc + y2_s, y3_acc + y3_s, y4_acc + y4_s, y3npart_acc + y3npart_s, ) nfl_init = (jnp.array(0, dtype=jnp.int32), bigint0, aditot0, y20, y30, y40, y3npart0) nfl_final = jax.lax.while_loop(nfl_cond, nfl_body, nfl_init) _, bigint_out, aditot_out, y2_out, y3_out, y4_out, y3npart_out = nfl_final return bigint_out, aditot_out, y2_out, y3_out, y4_out, y3npart_out def rational_skip(_): return bigint, aditot, y2, y3, y4, y3npart do_rational = (hit_rat == 1) & (nfp_rat > 0) & jnp.logical_not(jnp.asarray(skip_rational_correction)) bigint, aditot, y2, y3, y4, y3npart = jax.lax.cond( do_rational, rational_correction, rational_skip, operand=None ) nintfp = n epspar = jnp.zeros(multra) epstot = jnp.array(0.0) for m_cl in range(1, multra + 1): epspar = epspar.at[m_cl - 1].set(coeps * bigint[m_cl - 1] * y2 / (y3 * y3)) epstot = epstot + epspar[m_cl - 1] ctrone = aditot / y2 ctrtot = y3npart / y2 bareph = (jnp.pi * ctrone) ** 2 / 8.0 barept = (jnp.pi * ctrtot) ** 2 / 8.0 drdpsi = y2 / y3 yps = y4 * j_iota_i return { "epspar": epspar, "epstot": epstot, "ctrone": ctrone, "ctrtot": ctrtot, "bareph": bareph, "barept": barept, "drdpsi": drdpsi, "yps": yps, "y2": y2, "y3": y3, "y4": y4, "y3npart": y3npart, "bigint": bigint, "nintfp": nintfp, "hit_rat": hit_rat, "n_iota": n_iota, "m_iota": m_iota, "n_gap": n_gap, "final_n": jnp.where(hit_rat == 1, nfp_rat * (nfl_rat + 1), nintfp), "nfp_rat": nfp_rat, "nfl_rat": nfl_rat, }