# Source code for neo_jax.surface

"""Surface initialization and spline construction."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Optional, Tuple

import math
import numpy as np

import jax
import jax.numpy as jnp

from .fourier import derived_quantities, fourier_sums
from .geometry import neo_zeros2d
from .splines import spl2d

Array = jax.Array


@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True)
class SurfaceData:
    """Per-surface results of ``init_surface``, registered as a JAX pytree.

    Every attribute is a pytree child and there is no static auxiliary data,
    so instances can be flattened/unflattened by ``jax.tree_util`` and hence
    passed through JAX transformations.
    """

    # Field strength at the refined B minimum / maximum (as populated by init_surface).
    b_min: Array
    b_max: Array
    # (theta, phi) location of the refined B minimum.
    theta_bmin: Array
    phi_bmin: Array
    # (theta, phi) location of the refined B maximum.
    theta_bmax: Array
    phi_bmax: Array
    # Reference field strength; init_surface sets it equal to b_max.
    bmref: Array
    # Grid fields keyed by name (Fourier sums merged with derived quantities).
    fields: Dict[str, Array]
    # Spline coefficients keyed "b_spl", "g_spl", "k_spl", "p_spl" and
    # optionally "q_spl" (see build_splines).
    splines: Dict[str, Array]

    def tree_flatten(self):
        """Return ``(children, aux_data)``: children in field order, aux is None."""
        extrema = (self.b_min, self.b_max)
        locations = (self.theta_bmin, self.phi_bmin, self.theta_bmax, self.phi_bmax)
        tail = (self.bmref, self.fields, self.splines)
        return extrema + locations + tail, None

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild an instance from the children produced by ``tree_flatten``."""
        del aux  # no static data is stored
        return cls(*children)
def build_splines(
    fields: Dict[str, Array],
    theta_int: float,
    phi_int: float,
    mt: int,
    mp: int,
    calc_cur: bool = False,
) -> Dict[str, Array]:
    """Build the 2-D splines of the surface fields.

    Args:
        fields: Grid fields; must contain ``"b"``, ``"sqrg11"``, ``"kg"`` and
            ``"pard"``. ``"bqtphi"`` is optional.
        theta_int, phi_int, mt, mp: Passed through unchanged to ``spl2d``.
        calc_cur: When true and ``"bqtphi"`` is present, also build ``"q_spl"``.

    Returns:
        Dict with ``"b_spl"``, ``"g_spl"``, ``"k_spl"``, ``"p_spl"`` and, if
        requested and available, ``"q_spl"``.
    """

    def _spline(name: str) -> Array:
        return spl2d(fields[name], theta_int, phi_int, mt, mp)

    splines = {
        "b_spl": _spline("b"),
        "g_spl": _spline("sqrg11"),
        "k_spl": _spline("kg"),
        "p_spl": _spline("pard"),
    }
    if calc_cur and "bqtphi" in fields:
        splines["q_spl"] = _spline("bqtphi")
    return splines


def _select_extremum_index(
    b: Array,
    theta_arr: Array,
    phi_arr: Array,
    ixm: Array,
    ixn: Array,
    bmnc: Array,
    max_m_mode: int,
    max_n_mode: int,
    *,
    find_max: bool,
) -> Tuple[Array, Array]:
    """Select the grid index of the B extremum with a Fortran-like tie-breaker.

    Grid points within ``1e-12 * max(1, |extremum|)`` of the extremum are
    candidates. Candidates are visited in Fortran (column-major) order, i.e.
    with the theta index varying fastest. They are ranked by re-evaluating the
    cosine series ``sum(bmnc * cos(m*theta - n*phi))`` in double precision,
    restricted to ``|m| <= max_m_mode`` and ``|n| <= max_n_mode``. The first
    candidate in that order wins exact ties.

    If no candidate passes the tolerance test, the function falls back to a
    plain arg-extremum of the values in Fortran order. This happens when ``b``
    holds NaN, or when the extremum is infinite (``inf - inf`` is NaN).

    The inputs are converted with ``np.asarray`` and processed in Python, so
    this is not JAX-traceable; see ``_select_extremum_index_jax``.

    Returns:
        ``(i, j)``: theta index (axis 0 of ``b``) and phi index (axis 1).
    """
    b_np = np.asarray(b)
    if find_max:
        extremum = b_np.max()
        tol = 1.0e-12 * max(1.0, abs(extremum))
        mask = b_np >= (extremum - tol)
    else:
        extremum = b_np.min()
        tol = 1.0e-12 * max(1.0, abs(extremum))
        mask = b_np <= (extremum + tol)

    # Fortran order: theta index varies fastest, so walk the transpose
    # (shape (n_phi, n_theta)) in C order.
    shape_t = b_np.T.shape
    candidates = np.flatnonzero(mask.T)
    if candidates.size == 0:
        # Use the field values, not the (all-False) mask, so the result is the
        # true arg-extremum in Fortran order, as in _select_extremum_index_jax.
        flat_b = b_np.T.reshape(-1)
        idx = int(np.argmax(flat_b)) if find_max else int(np.argmin(flat_b))
        j, i = np.unravel_index(idx, shape_t)
        return jnp.asarray(i), jnp.asarray(j)

    theta_np = np.asarray(theta_arr)
    phi_np = np.asarray(phi_arr)
    # Filter the retained modes once rather than on every evaluation; the
    # per-term arithmetic and summation order are unchanged.
    modes = [
        (float(m), float(n), float(coeff))
        for m, n, coeff in zip(np.asarray(ixm), np.asarray(ixn), np.asarray(bmnc))
        if abs(m) <= max_m_mode and abs(n) <= max_n_mode
    ]

    def b_at(theta: float, phi: float) -> float:
        # Explicit left-to-right loop (not sum()/fsum) to keep rounding stable.
        total = 0.0
        for m, n, coeff in modes:
            total += coeff * math.cos(m * theta - n * phi)
        return total

    best_val = None
    best_i = None
    best_j = None
    for idx in candidates:
        j, i = np.unravel_index(int(idx), shape_t)
        val = b_at(float(theta_np[i]), float(phi_np[j]))
        # Strict comparison: earlier (Fortran-order) candidates win ties.
        if best_val is None or (val > best_val if find_max else val < best_val):
            best_val, best_i, best_j = val, i, j
    return jnp.asarray(best_i), jnp.asarray(best_j)


def _select_extremum_index_jax(
    b: Array,
    *,
    find_max: bool,
) -> Tuple[Array, Array]:
    """JAX-traceable extremum selection without the NumPy tie-breaker.

    Takes the first arg-max/arg-min of ``b`` in Fortran (column-major) order,
    i.e. with the theta index varying fastest.

    Returns:
        ``(i, j)``: theta index (axis 0 of ``b``) and phi index (axis 1).
    """
    b_t = b.T  # C-order flattening of the transpose == Fortran order of b
    flat = jnp.reshape(b_t, (-1,))
    idx = jnp.argmax(flat) if find_max else jnp.argmin(flat)
    j, i = jnp.unravel_index(idx, b_t.shape)
    return jnp.asarray(i), jnp.asarray(j)
def init_surface(
    theta_arr: Array,
    phi_arr: Array,
    coeffs: Dict[str, Array],
    ixm: Array,
    ixn: Array,
    nfp: int,
    max_m_mode: int,
    max_n_mode: int,
    curr_pol: Array,
    curr_tor: Array,
    iota: Array,
    grid: Dict[str, float],
    calc_cur: bool = False,
    use_jax: bool = False,
    skip_mask: bool = False,
    *,
    newton_tol: float = 1.0e-10,
    newton_maxiter: int = 100,
) -> SurfaceData:
    """Initialize a single flux surface.

    Steps:

    1. Evaluate the Fourier sums on the ``(theta_arr, phi_arr)`` grid and
       merge in the derived quantities.
    2. Build the field splines.
    3. Locate the grid points of minimal and maximal ``B``.
    4. Refine both points with ``neo_zeros2d`` on the ``B`` spline.
    5. Evaluate ``B`` at the refined points with ``neo_eval``.

    Args:
        theta_arr, phi_arr: Grid coordinates; ``fields["b"]`` is indexed as
            ``b[i_theta, j_phi]``.
        coeffs: Fourier coefficients. ``"rmnc"``, ``"zmns"``, ``"lmns"`` and
            ``"bmnc"`` are required; ``"lasym"``, ``"rmns"``, ``"zmnc"``,
            ``"lmnc"`` and ``"bmns"`` are optional.
        ixm, ixn: Mode numbers matching the coefficient arrays.
        nfp, max_m_mode, max_n_mode, skip_mask: Forwarded to ``fourier_sums``.
            The mode limits are also used by the NumPy tie-breaker.
        curr_pol, curr_tor, iota: Forwarded to ``derived_quantities``.
        grid: Must provide ``"theta_int"``, ``"phi_int"``, ``"mt"`` and
            ``"mp"`` for ``build_splines``. It is also forwarded to
            ``neo_zeros2d`` and ``neo_eval``.
        calc_cur: Also build ``"q_spl"`` when ``"bqtphi"`` is available.
        use_jax: Use the traceable argmin/argmax selection instead of the
            NumPy tie-breaker. The NumPy path converts ``b`` to a host array,
            so this is presumably required under ``jax.jit``.
        newton_tol, newton_maxiter: Tolerance and iteration cap forwarded to
            ``neo_zeros2d``. The defaults equal the previous hard-coded values.

    Returns:
        ``SurfaceData`` holding the following:

        - the refined extrema and their locations;
        - ``bmref`` set to ``b_max``;
        - the merged field dictionary;
        - the splines.
    """
    from .geometry import neo_eval

    fourier = fourier_sums(
        theta_arr,
        phi_arr,
        coeffs["rmnc"],
        coeffs["zmns"],
        coeffs["lmns"],
        coeffs["bmnc"],
        ixm,
        ixn,
        nfp=nfp,
        max_m_mode=max_m_mode,
        max_n_mode=max_n_mode,
        skip_mask=skip_mask,
        lasym=bool(coeffs.get("lasym", False)),
        rmns=coeffs.get("rmns"),
        zmnc=coeffs.get("zmnc"),
        lmnc=coeffs.get("lmnc"),
        bmns=coeffs.get("bmns"),
    )
    derived = derived_quantities(fourier, curr_pol=curr_pol, curr_tor=curr_tor, iota=iota)
    fields = {**fourier, **derived}
    splines = build_splines(
        fields, grid["theta_int"], grid["phi_int"], grid["mt"], grid["mp"], calc_cur
    )

    # Initial guesses for the B extrema on the grid. Fortran's MINLOC/MAXLOC
    # scan column-major (theta index fastest), and both selectors follow that
    # order. The NumPy path additionally treats values within a ~1e-12
    # tolerance as ties and breaks them by re-evaluating the cosine series.
    b = fields["b"]
    if use_jax:
        max_i, max_j = _select_extremum_index_jax(b, find_max=True)
        min_i, min_j = _select_extremum_index_jax(b, find_max=False)
    else:
        # NOTE(review): only bmnc reaches the tie-breaker, so bmns is ignored
        # there even when coeffs["lasym"] is set — confirm this is intended.
        tie_args = (b, theta_arr, phi_arr, ixm, ixn, coeffs["bmnc"], max_m_mode, max_n_mode)
        max_i, max_j = _select_extremum_index(*tie_args, find_max=True)
        min_i, min_j = _select_extremum_index(*tie_args, find_max=False)

    # Refine both grid guesses on the B spline (Newton iteration).
    b_spl = splines["b_spl"]
    theta_bmin, phi_bmin, _it_min, _err_min = neo_zeros2d(
        theta_arr[min_i], phi_arr[min_j], newton_tol, newton_maxiter, b_spl, grid
    )
    theta_bmax, phi_bmax, _it_max, _err_max = neo_zeros2d(
        theta_arr[max_i], phi_arr[max_j], newton_tol, newton_maxiter, b_spl, grid
    )

    def _b_at(theta, phi):
        """Evaluate B from the splines at (theta, phi); other outputs are dropped."""
        b_val, *_ = neo_eval(
            theta,
            phi,
            b_spl,
            splines["g_spl"],
            splines["k_spl"],
            splines["p_spl"],
            splines.get("q_spl"),
            grid,
        )
        return b_val

    b_min = _b_at(theta_bmin, phi_bmin)
    b_max = _b_at(theta_bmax, phi_bmax)

    return SurfaceData(
        b_min=b_min,
        b_max=b_max,
        theta_bmin=theta_bmin,
        phi_bmin=phi_bmin,
        theta_bmax=theta_bmax,
        phi_bmax=phi_bmax,
        bmref=b_max,
        fields=fields,
        splines=splines,
    )