Source code for neo_jax.results

"""User-friendly result containers for NEO_JAX."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Iterator, Mapping, Sequence

from .data_models import NeoOutputs

import numpy as np


# Lookup table from every accepted key/attribute name to the canonical
# ``NeoSurfaceResult`` attribute it resolves to.  Each entry below lists one
# canonical attribute followed by all names that resolve to it; the resulting
# dict keeps the listed order.
_ALIAS_MAP = {
    alias: canonical
    for canonical, aliases in (
        ("flux_index", ("flux_index", "surface_index")),
        ("s", ("psi", "s")),
        ("sqrt_s", ("sqrt_s",)),
        ("r_eff", ("reff", "r_eff")),
        ("iota", ("iota",)),
        ("b_ref", ("b_ref", "bref")),
        ("r_ref", ("r_ref", "rref")),
        ("epsilon_effective", ("epstot", "epsilon_effective", "eps_eff")),
        ("epsilon_effective_by_class", ("epspar", "epsilon_effective_by_class")),
        ("ctrone", ("ctrone",)),
        ("ctrtot", ("ctrtot",)),
        ("bareph", ("bareph",)),
        ("barept", ("barept",)),
        ("yps", ("yps",)),
        ("diagnostics", ("diagnostics",)),
    )
    for alias in aliases
}


@dataclass(frozen=True)
class NeoSurfaceResult:
    """Immutable record of the values stored for one flux surface.

    Values can be read in three ways: as plain attributes, through the
    short alias properties defined below (``epstot``, ``reff``, ``psi``,
    ``epspar``, ``sqrt_s``), or with string keys via ``result[key]`` /
    ``result.get(key)``, where the key is resolved through ``_ALIAS_MAP``.
    """

    flux_index: int
    s: float
    r_eff: float
    iota: float
    b_ref: float
    r_ref: float
    epsilon_effective: float
    epsilon_effective_by_class: np.ndarray
    ctrone: float
    ctrtot: float
    bareph: float
    barept: float
    yps: float
    diagnostics: Mapping[str, object]

    @property
    def epstot(self) -> float:
        """Alias of ``epsilon_effective``."""
        return self.epsilon_effective

    @property
    def sqrt_s(self) -> float:
        """Square root of ``s``; negative ``s`` is clamped to zero first."""
        non_negative_s = max(self.s, 0.0)
        return float(np.sqrt(non_negative_s))

    @property
    def reff(self) -> float:
        """Alias of ``r_eff``."""
        return self.r_eff

    @property
    def psi(self) -> float:
        """Alias of ``s``."""
        return self.s

    @property
    def epspar(self) -> np.ndarray:
        """Alias of ``epsilon_effective_by_class``."""
        return self.epsilon_effective_by_class

    def __getitem__(self, key: str):
        """Return the value registered under ``key`` in ``_ALIAS_MAP``.

        Raises:
            KeyError: if ``key`` is not a known name or alias.
        """
        attribute = _ALIAS_MAP.get(key)
        if attribute is not None:
            return getattr(self, attribute)
        raise KeyError(key)

    def get(self, key: str, default=None):
        """Return ``self[key]``, or ``default`` when ``key`` is unknown."""
        attribute = _ALIAS_MAP.get(key)
        return default if attribute is None else getattr(self, attribute)

    def to_dict(self) -> dict:
        """Return the surface values as a plain dictionary.

        Some values appear under two keys (for example ``"s"`` and
        ``"psi"``); the ``diagnostics`` mapping is included as-is.
        """
        # (output key, attribute it is read from), in output order.
        exported = (
            ("flux_index", "flux_index"),
            ("s", "s"),
            ("psi", "s"),
            ("sqrt_s", "sqrt_s"),
            ("r_eff", "r_eff"),
            ("reff", "r_eff"),
            ("iota", "iota"),
            ("b_ref", "b_ref"),
            ("r_ref", "r_ref"),
            ("epstot", "epsilon_effective"),
            ("epspar", "epsilon_effective_by_class"),
            ("ctrone", "ctrone"),
            ("ctrtot", "ctrtot"),
            ("bareph", "bareph"),
            ("barept", "barept"),
            ("yps", "yps"),
            ("diagnostics", "diagnostics"),
        )
        return {out_key: getattr(self, attr) for out_key, attr in exported}
class NeoResults(Sequence[NeoSurfaceResult]):
    """Immutable sequence of per-surface results with column-style access.

    Integer and slice indexing behave like indexing the underlying tuple
    (a slice therefore yields a plain tuple).  A string key, or an
    attribute whose name is listed in ``_ALIAS_MAP``, instead gathers
    that quantity across all surfaces, e.g. ``results["epstot"]`` or
    ``results.iota``.
    """

    def __init__(self, results: Iterable[NeoSurfaceResult]):
        # Materialise once so the container can be iterated repeatedly
        # and cannot be mutated through the caller's iterable.
        self._results = tuple(results)

    def __len__(self) -> int:  # pragma: no cover - trivial
        return len(self._results)

    def __iter__(self) -> Iterator[NeoSurfaceResult]:  # pragma: no cover - trivial
        return iter(self._results)

    def __getitem__(self, key):
        """Return one surface (int), a tuple of surfaces (slice) or a column (str).

        Raises:
            KeyError: if ``key`` is a string that is not a known name or alias.
        """
        if isinstance(key, str):
            return self._collect(key)
        return self._results[key]

    def __getattr__(self, name: str):
        """Expose every name in ``_ALIAS_MAP`` as a collected column.

        Only called when normal attribute lookup fails; any other name
        raises ``AttributeError``.
        """
        if name in _ALIAS_MAP:
            return self._collect(name)
        raise AttributeError(name)

    def _collect(self, key: str) -> np.ndarray | list:
        """Gather the quantity named ``key`` from every surface.

        Returns:
            A list of the per-surface mappings for ``"diagnostics"``;
            otherwise a NumPy array whose first axis runs over surfaces.

        Raises:
            KeyError: if ``key`` is not a known name or alias.
        """
        mapped = _ALIAS_MAP.get(key)
        if mapped is None:
            raise KeyError(key)
        if mapped == "diagnostics":
            return [res.diagnostics for res in self._results]
        if mapped == "epsilon_effective_by_class":
            if not self._results:
                # np.stack rejects an empty list; return an empty 2-D array
                # so an empty container still yields an array.
                return np.empty((0, 0))
            return np.stack([res.epsilon_effective_by_class for res in self._results])
        # "sqrt_s" is read through the per-surface property so that the
        # collected column agrees with ``NeoSurfaceResult.sqrt_s`` (which
        # clamps negative ``s`` to zero instead of producing NaN).
        return np.array([getattr(res, mapped) for res in self._results])

    def to_dicts(self) -> list[dict]:
        """Return ``to_dict()`` of every surface, in order."""
        return [res.to_dict() for res in self._results]
def _as_array(value):
    """Convert ``value`` to a NumPy array if it has at least one dimension.

    Returns ``None`` when the conversion raises or when the result is a
    zero-dimensional (scalar-like) array.
    """
    try:
        converted = np.asarray(value)
    except Exception:
        # Best-effort helper: anything NumPy cannot convert is reported
        # as "not an array" rather than propagated.
        return None
    return converted if converted.ndim > 0 else None
def neo_outputs_to_results(
    outputs: NeoOutputs,
    *,
    flux_indices: Sequence[int] | None = None,
) -> NeoResults:
    """Convert JAX-friendly ``NeoOutputs`` into ``NeoResults``.

    Reads ``eps_eff``, ``eps_par``, ``ctr_one``, ``ctr_tot`` and
    ``diagnostics`` from ``outputs``.  The number of surfaces is the
    length of the first axis of ``eps_eff``.

    Args:
        outputs: Object providing the attributes listed above;
            ``diagnostics`` may be ``None`` or empty.
        flux_indices: Optional explicit surface indices.  When omitted,
            ``diagnostics["flux_index"]`` is used if it is array-like;
            otherwise surfaces are numbered ``1..n``.

    Returns:
        A ``NeoResults`` holding one ``NeoSurfaceResult`` per surface.
        The keys ``s``, ``r_eff``, ``iota``, ``b_ref``, ``r_ref``,
        ``bareph``, ``barept`` and ``yps`` are looked up in the
        diagnostics: a missing key becomes ``0.0`` and a scalar is
        repeated for every surface.  Each surface's ``diagnostics`` dict
        carries every diagnostics entry; entries whose first axis has
        length ``n`` are sliced per surface, all others are copied
        through unchanged.
    """
    diag = outputs.diagnostics or {}
    eps_eff = np.asarray(outputs.eps_eff)
    eps_par = np.asarray(outputs.eps_par)
    ctr_one = np.asarray(outputs.ctr_one)
    ctr_tot = np.asarray(outputs.ctr_tot)
    n = int(eps_eff.shape[0])

    def _series(key: str, default: float = 0.0) -> np.ndarray:
        """Return diagnostics entry ``key`` as an array over surfaces."""
        arr = np.asarray(diag.get(key, default))
        if arr.shape == ():
            # Scalar (or missing) entry: repeat it for every surface.
            return np.full((n,), float(arr))
        return arr

    def _surface_item(arr: np.ndarray, idx: int):
        """Return the entry of a per-surface diagnostics array for ``idx``."""
        item = arr[idx]
        if arr.ndim != 1:
            return item
        try:
            return float(item)
        except (TypeError, ValueError):
            # Non-numeric per-surface entries (e.g. labels) are kept as-is
            # instead of aborting the whole conversion.
            return item

    s = _series("s", 0.0)
    r_eff = _series("r_eff", 0.0)
    iota = _series("iota", 0.0)
    b_ref = _series("b_ref", 0.0)
    r_ref = _series("r_ref", 0.0)
    bareph = _series("bareph", 0.0)
    barept = _series("barept", 0.0)
    yps = _series("yps", 0.0)

    if flux_indices is not None:
        flux_index = np.asarray(flux_indices)
    else:
        flux_index = diag.get("flux_index")
        if flux_index is not None:
            flux_index = np.asarray(flux_index)
    if flux_index is None or flux_index.shape == ():
        # No usable per-surface indices: number the surfaces from 1.
        flux_index = np.arange(1, n + 1, dtype=int)

    # Classify every diagnostics entry once, outside the surface loop, so
    # each value is converted to an array a single time.  An entry is
    # treated as per-surface when its first axis has length ``n``.
    diag_entries = []
    for key, value in diag.items():
        arr = _as_array(value)
        if arr is not None and arr.shape[0] == n:
            diag_entries.append((key, arr, value))
        else:
            diag_entries.append((key, None, value))

    results: list[NeoSurfaceResult] = []
    for idx in range(n):
        surface_diag: dict[str, object] = {
            key: value if arr is None else _surface_item(arr, idx)
            for key, arr, value in diag_entries
        }
        results.append(
            NeoSurfaceResult(
                flux_index=int(flux_index[idx]),
                s=float(s[idx]),
                r_eff=float(r_eff[idx]),
                iota=float(iota[idx]),
                b_ref=float(b_ref[idx]),
                r_ref=float(r_ref[idx]),
                epsilon_effective=float(eps_eff[idx]),
                epsilon_effective_by_class=np.asarray(eps_par[idx]),
                ctrone=float(ctr_one[idx]),
                ctrtot=float(ctr_tot[idx]),
                bareph=float(bareph[idx]),
                barept=float(barept[idx]),
                yps=float(yps[idx]),
                diagnostics=surface_diag,
            )
        )
    return NeoResults(results)