Source code for neo_jax.io
"""I/O helpers for NEO_JAX."""
from __future__ import annotations
from pathlib import Path
from typing import Optional, Sequence
import math
import numpy as np
try: # Optional JAX support for end-to-end pipelines
import jax
import jax.numpy as jnp
_JAX_AVAILABLE = True
except Exception: # pragma: no cover - optional dependency
jax = None # type: ignore
jnp = None # type: ignore
_JAX_AVAILABLE = False
from .data_models import BoozerData
try:
import netCDF4 # type: ignore
except Exception: # pragma: no cover - optional dependency
netCDF4 = None
def _require_netcdf4() -> None:
    """Fail fast with ImportError when the optional netCDF4 package is absent.

    ``netCDF4`` is the module-level name bound by the guarded import at the top
    of this module; it is ``None`` when that import failed.
    """
    if netCDF4 is not None:
        return
    raise ImportError("netCDF4 is required to read boozmn files")
[docs]
def resolve_control_path(extension: Optional[str] = None) -> Path:
    """Resolve NEO control file paths following xneo conventions.

    All paths are relative to the current working directory.

    Without an extension (``None`` or ``""``) only ``neo.in`` is considered.
    With an extension, the first existing file among ``neo_param.<ext>``,
    ``neo_param.in`` and ``neo_in.<ext>`` (in that order) is returned.

    Raises:
        FileNotFoundError: if no candidate file exists.
    """
    if extension:
        search_order = (
            Path(f"neo_param.{extension}"),
            Path("neo_param.in"),
            Path(f"neo_in.{extension}"),
        )
        found = next((option for option in search_order if option.exists()), None)
        if found is None:
            raise FileNotFoundError(f"No control file found for extension '{extension}'")
        return found

    default_file = Path("neo.in")
    if not default_file.exists():
        raise FileNotFoundError("neo.in not found")
    return default_file
def _extension_candidates(base: str | Path, extension: str) -> list[Path]:
    """Build candidate boozmn paths from a base name and a run extension.

    Name construction:

    * if ``str(base)`` occurs anywhere in ``extension``, ``extension`` is used
      verbatim (it already names the file, e.g. ``"boozmn_w7x.nc"``);
    * an extension starting with ``"."`` or ``"_"`` is appended directly;
    * otherwise both ``<base>_<ext>`` and ``<base>.<ext>`` are tried.

    For each name that does not already end in ``.nc``, two further
    candidates follow: the name with ``.nc`` appended, then the name with its
    last suffix replaced by ``.nc``. Appending comes first because extensions
    may legitimately contain dots (e.g. ``"li383_1.4m"``); a suffix swap alone
    would turn ``boozmn_li383_1.4m`` into ``boozmn_li383_1.nc``.

    Duplicates are dropped while preserving order. Nothing is checked on disk.
    """
    base_str = str(base)
    if base_str in extension:
        stems = [extension]
    elif extension.startswith((".", "_")):
        stems = [f"{base_str}{extension}"]
    else:
        stems = [f"{base_str}_{extension}", f"{base_str}.{extension}"]

    expanded: list[Path] = []
    for stem in stems:
        path = Path(stem)
        variants = [path]
        if path.suffix != ".nc":
            variants.append(path.with_name(path.name + ".nc"))
            variants.append(path.with_suffix(".nc"))
        for variant in variants:
            if variant not in expanded:
                expanded.append(variant)
    return expanded
[docs]
def resolve_boozmn_path(base: str | Path, extension: str | None = None) -> Path:
    """Resolve a boozmn file from a base name and optional extension.

    Candidates are tried in this order and the first existing path wins:

    1. ``base`` itself.
    2. When ``extension`` is given, the names from
       ``_extension_candidates(base, extension)``.
    3. If ``base`` does not end in ``.nc``: ``base`` with ``.nc`` appended,
       then ``base`` with its last suffix replaced by ``.nc``.
    4. ``boozmn`` and ``boozmn.nc`` relative to the current working
       directory, unless ``base`` is itself named ``boozmn``.

    Raises:
        FileNotFoundError: if none of the candidates exist.
    """
    path = Path(base)
    if path.exists():
        return path
    if extension:
        for candidate in _extension_candidates(path, extension):
            if candidate.exists():
                return candidate
    if path.suffix != ".nc":
        # Append ".nc" before trying a suffix swap: VMEC extensions may contain
        # dots (e.g. "boozmn_li383_1.4m"), and with_suffix() alone would look
        # for "boozmn_li383_1.nc" instead.
        appended = path.with_name(path.name + ".nc")
        replaced = path.with_suffix(".nc")
        for nc_path in (appended, replaced):
            if nc_path.exists():
                return nc_path
    if path.name != "boozmn":
        for candidate in (Path("boozmn"), Path("boozmn.nc")):
            if candidate.exists():
                return candidate
    raise FileNotFoundError(f"Boozmn file not found: {base}")
[docs]
def read_boozmn_metadata(path: str | Path) -> dict:
    """Read minimal metadata (ns_b, jlist) from a boozmn file.

    The file is located with ``resolve_boozmn_path`` (no extension).

    Returns:
        ``{"ns_b": int, "jlist": list[int]}``. When the file has no ``jlist``
        variable, ``jlist`` defaults to ``[1, ..., ns_b]``.

    Raises:
        ImportError: if netCDF4 is not installed.
        FileNotFoundError: if the file cannot be located.
    """
    _require_netcdf4()
    resolved = resolve_boozmn_path(path, None)
    with netCDF4.Dataset(resolved) as dataset:  # type: ignore[union-attr]
        variables = dataset.variables
        surface_count = int(variables["ns_b"][:])
        surface_list = (
            np.array(variables["jlist"][:], dtype=int).tolist()
            if "jlist" in variables
            else list(range(1, surface_count + 1))
        )
    return {"ns_b": surface_count, "jlist": surface_list}
def _transpose_if_needed(arr: np.ndarray, pack_len: int) -> np.ndarray:
    """Return ``arr`` oriented with its packed-surface axis first.

    ``arr`` is returned unchanged (same object) when its leading axis has
    length ``pack_len``; otherwise, if its second axis has that length,
    ``arr.T`` is returned.

    Raises:
        ValueError: if no suitable axis has length ``pack_len``. This also
            covers 0-D/1-D input, which previously escaped as an IndexError.
    """
    shape = np.shape(arr)
    if len(shape) >= 1 and shape[0] == pack_len:
        return arr
    if len(shape) >= 2 and shape[1] == pack_len:
        return arr.T
    raise ValueError(
        f"Unexpected boozmn array shape {shape}: expected an axis of length {pack_len}"
    )
def _select_modes(ixm: np.ndarray, ixn: np.ndarray, max_m: int, max_n: int):
    """Boolean mask of Fourier modes with ``|m| <= max_m`` and ``|n| <= max_n``.

    The mask is computed with ``jax.numpy`` when JAX is installed and ``ixm``
    is a ``jax.Array``; otherwise NumPy is used.
    """
    on_jax = _JAX_AVAILABLE and isinstance(ixm, jax.Array)  # type: ignore[arg-type]
    backend = jnp if on_jax else np  # type: ignore[union-attr]
    m_ok = backend.abs(ixm) <= max_m
    n_ok = backend.abs(ixn) <= max_n
    return m_ok & n_ok
[docs]
def read_boozmn(
    path: str | Path,
    *,
    max_m_mode: int = 0,
    max_n_mode: int = 0,
    fluxs_arr: Optional[Sequence[int]] = None,
    extension: str | None = None,
) -> BoozerData:
    """Read a boozmn netCDF file and return BoozerData.

    This function follows the packing conventions used in STELLOPT's
    read_boozer_mod, but only retains the surfaces requested.

    Args:
        path: File or base name, resolved with ``resolve_boozmn_path``.
        max_m_mode: Keep modes with ``|ixm_b| <= max_m_mode``; a non-positive
            value means ``mboz_b - 1``.
        max_n_mode: Keep modes with ``|ixn_b| <= max_n_mode``; a non-positive
            value means ``nboz_b * nfp_b``.
        fluxs_arr: Surface labels (entries of ``jlist``) to extract, in the
            given order. ``None`` extracts every packed surface.
        extension: Optional run extension forwarded to ``resolve_boozmn_path``.

    Returns:
        BoozerData with one row per requested surface. ``sqrtg00`` is the
        ``(m, n) = (0, 0)`` entry of ``gmn_b`` (0.0 when ``gmn_b`` or that mode
        is missing). ``pprime`` is a forward difference of ``pres_b`` (0.0 when
        ``pres_b`` is missing or ``surf >= ns_b``).

    Raises:
        ImportError: if netCDF4 is not installed.
        FileNotFoundError: if the file cannot be located.
        ValueError: if a requested surface is not in ``jlist``, or a packed
            array has no axis matching the packed surface count.
    """
    _require_netcdf4()
    booz_path = resolve_boozmn_path(path, extension)
    with netCDF4.Dataset(booz_path) as ds:  # type: ignore[union-attr]
        var = ds.variables
        nfp = int(var["nfp_b"][:])
        ns_b = int(var["ns_b"][:])
        mboz_b = int(var["mboz_b"][:])
        nboz_b = int(var["nboz_b"][:])
        ixm_b = np.array(var["ixm_b"][:], dtype=int)
        ixn_b = np.array(var["ixn_b"][:], dtype=int)
        iota_b = np.array(var["iota_b"][:], dtype=float)
        buco_b = np.array(var["buco_b"][:], dtype=float)
        bvco_b = np.array(var["bvco_b"][:], dtype=float)
        pres_b = np.array(var["pres_b"][:], dtype=float) if "pres_b" in var else None
        rmnc_raw = np.array(var["rmnc_b"][:], dtype=float)
        zmns_raw = np.array(var["zmns_b"][:], dtype=float)
        pmns_raw = np.array(var["pmns_b"][:], dtype=float)
        bmnc_raw = np.array(var["bmnc_b"][:], dtype=float)
        gmn_raw = np.array(var["gmn_b"][:], dtype=float) if "gmn_b" in var else None
        jlist = (
            np.atleast_1d(np.array(var["jlist"][:], dtype=int)) if "jlist" in var else None
        )

    # Number of packed surfaces. jlist has one entry per packed surface; without
    # it, use the mode count to recognise a (modes, surfaces) layout. Taking the
    # count from rmnc_raw.shape[0] unconditionally would make the transposition
    # below a no-op.
    n_modes = ixm_b.shape[0]
    if jlist is not None:
        pack_len = jlist.shape[0]
    elif (
        rmnc_raw.ndim == 2
        and rmnc_raw.shape[0] == n_modes
        and rmnc_raw.shape[1] != n_modes
    ):
        pack_len = rmnc_raw.shape[1]
    else:
        pack_len = rmnc_raw.shape[0]
    if jlist is None:
        jlist = np.arange(1, pack_len + 1, dtype=int)

    rmnc_pack = _transpose_if_needed(rmnc_raw, pack_len)
    zmns_pack = _transpose_if_needed(zmns_raw, pack_len)
    pmns_pack = _transpose_if_needed(pmns_raw, pack_len)
    bmnc_pack = _transpose_if_needed(bmnc_raw, pack_len)
    gmn_pack = _transpose_if_needed(gmn_raw, pack_len) if gmn_raw is not None else None

    max_m = max_m_mode if max_m_mode > 0 else mboz_b - 1
    max_n = max_n_mode if max_n_mode > 0 else nboz_b * nfp
    mode_mask = _select_modes(ixm_b, ixn_b, max_m, max_n)
    ixm = ixm_b[mode_mask]
    ixn = ixn_b[mode_mask]
    # Position of the (m, n) = (0, 0) mode in the *unfiltered* mode list;
    # gmn_pack is indexed with it directly (not through mode_mask).
    mode0 = np.where((ixm_b == 0) & (ixn_b == 0))[0]
    mode0_idx = int(mode0[0]) if len(mode0) else None

    pack_index = {int(surf): idx for idx, surf in enumerate(jlist)}
    surfaces = list(fluxs_arr) if fluxs_arr is not None else list(jlist)

    # Radial step of the full mesh; guarded (as in booz_xform_to_boozerdata)
    # so a single-surface file does not raise ZeroDivisionError.
    hs = 1.0 / (ns_b - 1) if ns_b > 1 else 0.0

    rmnc, zmns, lmns, bmnc = [], [], [], []
    es, iota, curr_pol, curr_tor = [], [], [], []
    pprime, sqrtg00 = [], []
    for surf in surfaces:
        pack_idx = pack_index.get(int(surf))
        if pack_idx is None:
            raise ValueError(f"Surface {surf} not found in boozmn jlist")
        rmnc.append(rmnc_pack[pack_idx, mode_mask])
        zmns.append(zmns_pack[pack_idx, mode_mask])
        # lmns is derived from pmns_b with the factor -nfp / (2*pi).
        lmns.append(-pmns_pack[pack_idx, mode_mask] * nfp / (2.0 * math.pi))
        bmnc.append(bmnc_pack[pack_idx, mode_mask])
        # (surf - 1.5) * hs: presumably the VMEC half-grid s of surface surf.
        es.append((surf - 1.5) * hs)
        # Profile arrays are indexed by the 1-based surface label itself.
        iota.append(iota_b[surf - 1])
        curr_pol.append(bvco_b[surf - 1])
        curr_tor.append(buco_b[surf - 1])
        if pres_b is not None and surf < ns_b:
            pprime.append((pres_b[surf] - pres_b[surf - 1]) / hs)
        else:
            pprime.append(0.0)
        if gmn_pack is not None and mode0_idx is not None:
            sqrtg00.append(float(gmn_pack[pack_idx, mode0_idx]))
        else:
            sqrtg00.append(0.0)

    return BoozerData(
        rmnc=np.asarray(rmnc),
        zmns=np.asarray(zmns),
        lmns=np.asarray(lmns),
        bmnc=np.asarray(bmnc),
        ixm=np.asarray(ixm),
        ixn=np.asarray(ixn),
        es=np.asarray(es),
        iota=np.asarray(iota),
        curr_pol=np.asarray(curr_pol),
        curr_tor=np.asarray(curr_tor),
        nfp=nfp,
        pprime=np.asarray(pprime),
        sqrtg00=np.asarray(sqrtg00),
    )
[docs]
def booz_xform_to_boozerdata(
    booz: object,
    *,
    max_m_mode: int = 0,
    max_n_mode: int = 0,
    fluxs_arr: Optional[Sequence[int]] = None,
    use_jax: bool | None = None,
) -> BoozerData:
    """Convert booz_xform-style arrays into BoozerData.

    The input object can be a mapping or an object with attributes matching the
    boozmn variable names (e.g., ``rmnc_b``, ``zmns_b``, ``pmns_b``).

    Args:
        booz: ``dict`` or attribute-bearing object providing ``nfp_b``,
            ``ixm_b``, ``ixn_b``, ``iota_b``, ``buco_b``, ``bvco_b``,
            ``rmnc_b``, ``zmns_b``, ``pmns_b`` and ``bmnc_b``. It may also
            provide ``s_b`` and ``ns_b``, looked up the same way.
        max_m_mode, max_n_mode: Keep modes with ``|m| <= max_m_mode`` and
            ``|n| <= max_n_mode``. Non-positive values keep every mode.
        fluxs_arr: 1-based row indices of the surfaces to keep. ``None`` or an
            empty sequence keeps all rows. Lists, tuples and NumPy/JAX arrays
            are all accepted.
        use_jax: If true and JAX is installed, delegate to
            ``booz_xform_to_boozerdata_jax``. If ``None``, delegate only when
            ``rmnc_b`` is a JAX array.

    Returns:
        BoozerData built from NumPy arrays, unless delegated to the JAX path.

    Raises:
        KeyError: if a required field is missing from ``booz``.
    """

    def _get(name: str):
        if isinstance(booz, dict) and name in booz:
            return booz[name]
        if hasattr(booz, name):
            return getattr(booz, name)
        raise KeyError(f"Missing field {name} in Boozer data")

    def _has(name: str) -> bool:
        # Same lookup rules as _get, so presence checks agree with reads
        # (properties/slots are not visible in __dict__).
        return (isinstance(booz, dict) and name in booz) or hasattr(booz, name)

    def _asarray(obj, *, dtype=None):
        if isinstance(obj, np.ma.MaskedArray):
            obj = obj.filled()
        return np.asarray(obj, dtype=dtype)

    sample = _get("rmnc_b")
    if use_jax is None:
        use_jax = _JAX_AVAILABLE and isinstance(sample, jax.Array)  # type: ignore[arg-type]
    if use_jax and _JAX_AVAILABLE:
        return booz_xform_to_boozerdata_jax(
            booz,
            max_m_mode=max_m_mode,
            max_n_mode=max_n_mode,
            fluxs_arr=fluxs_arr,
        )

    # NumPy path. The JAX case returned above; use_jax=True without JAX
    # installed also falls through to here.
    nfp = int(np.asarray(_get("nfp_b")).squeeze())
    ixm_b = _asarray(_get("ixm_b"), dtype=int)
    ixn_b = _asarray(_get("ixn_b"), dtype=int)
    iota_b = _asarray(_get("iota_b"), dtype=float)
    buco_b = _asarray(_get("buco_b"), dtype=float)
    bvco_b = _asarray(_get("bvco_b"), dtype=float)
    rmnc_raw = _asarray(_get("rmnc_b"), dtype=float)
    zmns_raw = _asarray(_get("zmns_b"), dtype=float)
    pmns_raw = _asarray(_get("pmns_b"), dtype=float)
    bmnc_raw = _asarray(_get("bmnc_b"), dtype=float)
    if rmnc_raw.shape[0] == ixm_b.shape[0]:
        # Stored as (modes, surfaces): put surfaces first.
        rmnc_raw = rmnc_raw.T
        zmns_raw = zmns_raw.T
        pmns_raw = pmns_raw.T
        bmnc_raw = bmnc_raw.T
    ns_b = rmnc_raw.shape[0]

    max_m = max_m_mode if max_m_mode > 0 else int(np.max(np.abs(np.asarray(ixm_b))))
    max_n = max_n_mode if max_n_mode > 0 else int(np.max(np.abs(np.asarray(ixn_b))))
    mode_mask = _select_modes(ixm_b, ixn_b, max_m, max_n)
    ixm = ixm_b[mode_mask]
    ixn = ixn_b[mode_mask]

    # Materialise the request before testing it: ``if fluxs_arr:`` raises
    # ValueError for multi-element NumPy/JAX arrays.
    requested = [int(s) for s in fluxs_arr] if fluxs_arr is not None else []
    surfaces = requested if requested else list(range(1, ns_b + 1))

    if _has("s_b"):
        s_vals = np.asarray(_get("s_b"), dtype=float)
    else:
        ns_full = int(np.asarray(_get("ns_b")).squeeze()) if _has("ns_b") else ns_b
        hs = 1.0 / (ns_full - 1) if ns_full > 1 else 0.0
        s_vals = (np.arange(1, ns_b + 1, dtype=float) - 1.5) * hs

    rows = [surf - 1 for surf in surfaces]
    return BoozerData(
        rmnc=np.asarray([rmnc_raw[i, mode_mask] for i in rows]),
        zmns=np.asarray([zmns_raw[i, mode_mask] for i in rows]),
        lmns=np.asarray([-pmns_raw[i, mode_mask] * nfp / (2.0 * math.pi) for i in rows]),
        bmnc=np.asarray([bmnc_raw[i, mode_mask] for i in rows]),
        ixm=np.asarray(ixm),
        ixn=np.asarray(ixn),
        es=np.asarray([s_vals[i] for i in rows]),
        iota=np.asarray([iota_b[i] for i in rows]),
        curr_pol=np.asarray([bvco_b[i] for i in rows]),
        curr_tor=np.asarray([buco_b[i] for i in rows]),
        nfp=nfp,
    )
[docs]
def booz_xform_to_boozerdata_jax(
    booz: object,
    *,
    max_m_mode: int = 0,
    max_n_mode: int = 0,
    fluxs_arr: Optional[Sequence[int]] = None,
    nfp_override: int | None = None,
    mode_indices: Optional[Sequence[int]] = None,
) -> BoozerData:
    """JAX-friendly conversion from booz_xform outputs to BoozerData.

    Field lookup matches ``booz_xform_to_boozerdata``: ``booz`` may be a dict
    or an object with boozmn-named attributes.

    Args:
        booz: Source of ``nfp_b``, ``ixm_b``, ``ixn_b``, ``iota_b``,
            ``buco_b``, ``bvco_b``, ``rmnc_b``, ``zmns_b``, ``pmns_b`` and
            ``bmnc_b``. It may also provide ``s_b``, ``ns_b`` and ``jlist``.
        max_m_mode, max_n_mode: Mode cut-offs, used when ``mode_indices`` is
            ``None``. Non-positive values keep every mode.
        fluxs_arr: 1-based row indices of the surfaces to keep. ``None`` or an
            empty sequence keeps all rows.
        nfp_override: Field-period count to use instead of ``nfp_b``.
        mode_indices: Explicit mode indices, gathered with ``jnp.take``
            instead of building a boolean mode mask.

    Returns:
        BoozerData whose array fields are JAX arrays and whose ``nfp`` is an int.

    Raises:
        ImportError: if JAX is not installed.
        KeyError: if a required field is missing from ``booz``.
    """
    if not _JAX_AVAILABLE:  # pragma: no cover - optional
        raise ImportError("JAX is required for booz_xform_to_boozerdata_jax")

    def _get(name: str):
        if isinstance(booz, dict) and name in booz:
            return booz[name]
        if hasattr(booz, name):
            return getattr(booz, name)
        raise KeyError(f"Missing field {name} in Boozer data")

    def _has(name: str) -> bool:
        # Same lookup rules as _get, so presence checks agree with reads.
        return (isinstance(booz, dict) and name in booz) or hasattr(booz, name)

    def _asarray(obj, *, dtype=None):
        if isinstance(obj, np.ma.MaskedArray):
            obj = obj.filled()
        return jnp.asarray(obj, dtype=dtype)

    nfp = nfp_override if nfp_override is not None else _asarray(_get("nfp_b")).reshape(())[()]
    ixm_b = _asarray(_get("ixm_b"), dtype=jnp.int32)
    ixn_b = _asarray(_get("ixn_b"), dtype=jnp.int32)
    iota_b = _asarray(_get("iota_b"))
    buco_b = _asarray(_get("buco_b"))
    bvco_b = _asarray(_get("bvco_b"))
    rmnc_raw = _asarray(_get("rmnc_b"))
    zmns_raw = _asarray(_get("zmns_b"))
    pmns_raw = _asarray(_get("pmns_b"))
    bmnc_raw = _asarray(_get("bmnc_b"))

    # Ensure surface dimension first (static shape check).
    if rmnc_raw.shape[0] == ixm_b.shape[0]:
        rmnc_raw = rmnc_raw.T
        zmns_raw = zmns_raw.T
        pmns_raw = pmns_raw.T
        bmnc_raw = bmnc_raw.T
    ns_b = rmnc_raw.shape[0]

    if mode_indices is not None:
        mode_idx = _asarray(mode_indices, dtype=jnp.int32)
        ixm = jnp.take(ixm_b, mode_idx, axis=0)
        ixn = jnp.take(ixn_b, mode_idx, axis=0)
        mode_mask = None
    else:
        max_m = max_m_mode if max_m_mode > 0 else jnp.max(jnp.abs(ixm_b))
        max_n = max_n_mode if max_n_mode > 0 else jnp.max(jnp.abs(ixn_b))
        mode_mask = (jnp.abs(ixm_b) <= max_m) & (jnp.abs(ixn_b) <= max_n)
        ixm = ixm_b[mode_mask]
        ixn = ixn_b[mode_mask]

    # Materialise the request before testing it: ``if fluxs_arr:`` raises
    # ValueError for multi-element NumPy/JAX arrays.
    requested = [int(s) for s in fluxs_arr] if fluxs_arr is not None else []
    if requested:
        surface_indices = jnp.asarray([s - 1 for s in requested], dtype=jnp.int32)
    else:
        surface_indices = jnp.arange(ns_b, dtype=jnp.int32)

    rmnc_sel = jnp.take(rmnc_raw, surface_indices, axis=0)
    zmns_sel = jnp.take(zmns_raw, surface_indices, axis=0)
    pmns_sel = jnp.take(pmns_raw, surface_indices, axis=0)
    bmnc_sel = jnp.take(bmnc_raw, surface_indices, axis=0)
    if mode_indices is not None:
        rmnc = jnp.take(rmnc_sel, mode_idx, axis=1)
        zmns = jnp.take(zmns_sel, mode_idx, axis=1)
        lmns = -jnp.take(pmns_sel, mode_idx, axis=1) * nfp / (2.0 * math.pi)
        bmnc = jnp.take(bmnc_sel, mode_idx, axis=1)
    else:
        rmnc = rmnc_sel[:, mode_mask]
        zmns = zmns_sel[:, mode_mask]
        lmns = -pmns_sel[:, mode_mask] * nfp / (2.0 * math.pi)
        bmnc = bmnc_sel[:, mode_mask]

    iota = jnp.take(iota_b, surface_indices, axis=0)
    curr_pol = jnp.take(bvco_b, surface_indices, axis=0)
    curr_tor = jnp.take(buco_b, surface_indices, axis=0)

    if _has("s_b"):
        s_vals = _asarray(_get("s_b"))
        es = jnp.take(s_vals, surface_indices, axis=0)
    else:
        ns_full = int(_asarray(_get("ns_b"))) if _has("ns_b") else int(ns_b)
        hs = 1.0 / (ns_full - 1) if ns_full > 1 else 0.0
        if _has("jlist"):
            # jlist has one label per packed row; select the same rows as the
            # Fourier arrays so es stays aligned with rmnc when fluxs_arr is set.
            labels = jnp.take(_asarray(_get("jlist")), surface_indices, axis=0)
        else:
            labels = surface_indices + 1
        es = (labels - 1.5) * hs

    return BoozerData(
        rmnc=rmnc,
        zmns=zmns,
        lmns=lmns,
        bmnc=bmnc,
        ixm=ixm,
        ixn=ixn,
        es=es,
        iota=iota,
        curr_pol=curr_pol,
        curr_tor=curr_tor,
        # nfp already equals nfp_override when one was given.
        nfp=int(nfp),
    )