"""Pipeline helpers for vmec_jax -> booz_xform_jax -> neo_jax."""
from __future__ import annotations
from pathlib import Path
from typing import Any, Callable, Mapping, Sequence
from dataclasses import replace
import numpy as np
from .api import run_neo
from .config import NeoConfig
[docs]
def run_boozer_to_neo(
    booz_output: Mapping[str, Any],
    *,
    config: NeoConfig | None = None,
    use_jax: bool = True,
    progress: bool | None = None,
) -> Any:
    """Run NEO_JAX on an output mapping produced by booz_xform_jax.

    This is a thin convenience wrapper: every argument is forwarded
    unchanged to ``run_neo``.

    Args:
        booz_output: Mapping holding the booz_xform_jax output.
        config: Optional NEO configuration, forwarded as ``config``.
        use_jax: Forwarded as ``use_jax``.
        progress: Forwarded as ``progress``.

    Returns:
        Whatever ``run_neo`` returns for these arguments.
    """
    neo_options = {"config": config, "use_jax": use_jax, "progress": progress}
    return run_neo(booz_output, **neo_options)
def _resolve_vmec_wout(
    vmec_source: Any,
    *,
    vmec_kwargs: dict | None = None,
    fast_bcovar: bool = True,
) -> Any:
    """Normalise ``vmec_source`` into a wout-like object.

    Three kinds of input are accepted, checked in this order:

    1. ``str`` / ``pathlib.Path``: handed to
       ``vmec_jax.driver.run_fixed_boundary`` together with ``vmec_kwargs``;
       the resulting run is converted with ``wout_from_fixed_boundary_run``.
    2. An object exposing both ``state`` and ``static``: treated as an
       existing fixed-boundary run and converted the same way.
    3. An object exposing ``rmnc``: returned unchanged.

    Args:
        vmec_source: Path, run-like object, or wout-like object (see above).
        vmec_kwargs: Extra keyword arguments for ``run_fixed_boundary``; only
            used when ``vmec_source`` is a path.
        fast_bcovar: Forwarded to ``wout_from_fixed_boundary_run``.

    Returns:
        The wout-like object (``vmec_source`` itself in case 3).

    Raises:
        ImportError: If a conversion is needed but ``vmec_jax`` cannot be
            imported.
        TypeError: If ``vmec_source`` matches none of the accepted kinds.
    """
    from_path = isinstance(vmec_source, (str, Path))
    from_run = (
        not from_path
        and hasattr(vmec_source, "state")
        and hasattr(vmec_source, "static")
    )

    if not (from_path or from_run):
        # Already wout-like: nothing to compute, and vmec_jax is not needed.
        if hasattr(vmec_source, "rmnc"):
            return vmec_source
        raise TypeError(
            "vmec_source must be a vmec_jax FixedBoundaryRun, WoutData, or input path"
        )

    # vmec_jax is an optional dependency, so import it lazily and only the
    # names this branch actually needs.
    try:
        from vmec_jax.driver import wout_from_fixed_boundary_run

        if from_path:
            from vmec_jax.driver import run_fixed_boundary
    except ImportError as exc:  # pragma: no cover - optional dependency
        raise ImportError(
            "vmec_jax is required for vmec -> booz pipeline. "
            "Install vmec_jax or add it to PYTHONPATH."
        ) from exc

    if from_path:
        run = run_fixed_boundary(vmec_source, **(vmec_kwargs or {}))
    else:
        run = vmec_source
    return wout_from_fixed_boundary_run(run, include_fsq=False, fast_bcovar=fast_bcovar)
[docs]
def run_vmec_boozer_neo(
    vmec_source: Any,
    *,
    booz_xform_fn: Callable[..., Mapping[str, Any]] | None = None,
    booz_kwargs: dict | None = None,
    vmec_kwargs: dict | None = None,
    neo_config: NeoConfig | None = None,
    use_jax: bool = True,
    progress: bool | None = None,
    fast_bcovar: bool = True,
) -> Any:
    """Chain vmec_jax -> booz_xform_jax -> neo_jax in a single call.

    ``vmec_source`` may be a ``vmec_jax.FixedBoundaryRun``, a
    ``vmec_jax.WoutData`` object, or a path to a VMEC input file; it is
    resolved to a wout object by ``_resolve_vmec_wout``.

    Args:
        vmec_source: VMEC run, wout data, or input-file path.
        booz_xform_fn: Boozer transform called as
            ``booz_xform_fn(wout, **booz_kwargs)`` (for example a JAX-native
            function from ``booz_xform_jax.jax_api``). When ``None``,
            ``booz_xform_from_vmec_wout`` is called instead.
        booz_kwargs: Keyword arguments for the Boozer transform.
        vmec_kwargs: Forwarded to ``_resolve_vmec_wout``.
        neo_config: Forwarded to ``run_neo`` as ``config``.
        use_jax: Forwarded to ``run_neo``.
        progress: Forwarded to ``run_neo``.
        fast_bcovar: Forwarded to ``_resolve_vmec_wout``.

    Returns:
        Whatever ``run_neo`` returns for the Boozer output.
    """
    transform_kwargs = booz_kwargs if booz_kwargs else {}
    wout = _resolve_vmec_wout(vmec_source, vmec_kwargs=vmec_kwargs, fast_bcovar=fast_bcovar)

    # NOTE(review): ``booz_xform_from_vmec_wout`` is neither defined nor
    # imported in the visible part of this module -- confirm it is in scope,
    # otherwise the default (``booz_xform_fn=None``) path raises NameError.
    transform = booz_xform_from_vmec_wout if booz_xform_fn is None else booz_xform_fn
    booz_output = transform(wout, **transform_kwargs)

    neo_options = {"config": neo_config, "use_jax": use_jax, "progress": progress}
    return run_neo(booz_output, **neo_options)
[docs]
def run_vmec_boozer_neo_jax(
    vmec_run: Any,
    *,
    booz_kwargs: dict | None = None,
    neo_config: NeoConfig | None = None,
    jax_surface_scan: bool = True,
    progress: bool | None = None,
) -> Any:
    """Run the JAX-native VMEC -> Boozer -> NEO pipeline.

    The Boozer output is built with ``booz_xform_from_vmec_state_jax`` and
    then handed to ``run_neo`` with ``use_jax=True``.

    Args:
        vmec_run: VMEC run object, passed as ``vmec_run=`` to the Boozer step.
        booz_kwargs: Extra keyword arguments for the Boozer step.
        neo_config: Forwarded to ``run_neo`` as ``config``.
        jax_surface_scan: Forwarded to ``run_neo``.
        progress: Forwarded to ``run_neo``.

    Returns:
        Whatever ``run_neo`` returns for the Boozer output.
    """
    transform_kwargs = {} if not booz_kwargs else booz_kwargs

    # NOTE(review): ``booz_xform_from_vmec_state_jax`` is neither defined nor
    # imported in the visible part of this module -- confirm it is in scope.
    booz_output = booz_xform_from_vmec_state_jax(vmec_run=vmec_run, **transform_kwargs)

    neo_options = {
        "config": neo_config,
        "use_jax": True,
        "progress": progress,
        "jax_surface_scan": jax_surface_scan,
    }
    return run_neo(booz_output, **neo_options)
[docs]
def build_vmec_boozer_neo_jax(
    vmec_run: Any,
    *,
    booz_kwargs: dict | None = None,
    neo_config: NeoConfig | None = None,
    jit: bool = True,
):
    """Return a callable ``solve(state)`` for the JAX-native VMEC→Boozer→NEO path.

    Everything that does not depend on the VMEC state passed to ``solve`` is
    computed once, up front: trig tables, Boozer constants and grids, flux
    and profile data, the surface selection and the Boozer mode index set.
    The returned function closes over these, so it is suitable for repeated
    calls (and optional JIT).

    Args:
        vmec_run: Object whose ``state``, ``static``, ``indata`` and
            ``signgs`` attributes are read. ``vmec_run.state`` is only used
            for the precomputation; ``solve`` takes its own ``state``.
        booz_kwargs: Optional Boozer settings. Only the keys ``"mboz"`` and
            ``"nboz"`` are read here. A missing (or zero) ``mboz`` falls back
            to ``max(xm) + 1``; a missing ``nboz`` falls back to
            ``max(|xn|) // nfp``. An explicit ``nboz=0`` is honoured.
        neo_config: NEO configuration; ``NeoConfig()`` when omitted.
        jit: When true, the returned function is wrapped in ``jax.jit``.

    Returns:
        A function ``solve(state)`` returning the result of
        ``run_neo_from_boozer_jax`` for that state.

    Raises:
        ImportError: If ``jax``/``booz_xform_jax`` or the required
            ``vmec_jax`` modules cannot be imported.
    """
    booz_kwargs = booz_kwargs or {}
    cfg = neo_config or NeoConfig()
    try:
        import jax
        import jax.numpy as jnp
        from booz_xform_jax.jax_api import prepare_booz_xform_constants, booz_xform_jax_impl
    except ImportError as exc:  # pragma: no cover
        raise ImportError("booz_xform_jax is required for build_vmec_boozer_neo_jax") from exc
    try:
        from vmec_jax.booz_input import booz_xform_inputs_from_state
        from vmec_jax.energy import flux_profiles_from_indata
        from vmec_jax.profiles import eval_profiles
        from vmec_jax.vmec_tomnsp import vmec_trig_tables
    except ImportError as exc:  # pragma: no cover
        raise ImportError("vmec_jax with booz_input is required for build_vmec_boozer_neo_jax") from exc
    from .driver import run_neo_from_boozer_jax
    from .io import booz_xform_to_boozerdata_jax

    # Precompute static Boozer constants from the current state.
    inputs0 = booz_xform_inputs_from_state(
        state=vmec_run.state,
        static=vmec_run.static,
        indata=vmec_run.indata,
        signgs=int(vmec_run.signgs),
    )
    # Host-side copies of the mode tables, converted once and reused below.
    xm = np.asarray(inputs0.xm)
    xn = np.asarray(inputs0.xn)
    nyq_m = np.asarray(inputs0.xm_nyq)
    nyq_n = np.asarray(inputs0.xn_nyq)
    nfp_int = int(inputs0.nfp)
    lasym = bool(vmec_run.static.cfg.lasym)

    # Trig tables sized by the largest Nyquist mode numbers (0 if no modes).
    mmax = int(np.max(nyq_m)) if nyq_m.size else 0
    nmax = int(np.max(np.abs(nyq_n))) // nfp_int if nyq_n.size else 0
    trig = vmec_trig_tables(
        ntheta=int(vmec_run.static.cfg.ntheta),
        nzeta=int(vmec_run.static.cfg.nzeta),
        nfp=nfp_int,
        mmax=mmax,
        nmax=nmax,
        lasym=lasym,
        dtype=np.asarray(inputs0.rmnc).dtype,
        cache=True,
    )

    # Boozer resolution. ``mboz`` keeps its original fallback (None or 0 means
    # "use the default"). ``nboz`` is only defaulted when it is absent, so an
    # explicit ``nboz=0`` is not mistaken for "unset".
    mboz_val = int(booz_kwargs.get("mboz") or (np.max(xm) + 1))
    nboz_requested = booz_kwargs.get("nboz")
    if nboz_requested is None:
        nboz_val = int(np.max(np.abs(xn)) // nfp_int)
    else:
        nboz_val = int(nboz_requested)

    constants, grids = prepare_booz_xform_constants(
        nfp=nfp_int,
        mboz=mboz_val,
        nboz=nboz_val,
        asym=lasym,
        xm=xm,
        xn=xn,
        xm_nyq=nyq_m,
        xn_nyq=nyq_n,
    )

    ns_full = int(inputs0.rmnc.shape[0])
    # Midpoints of consecutive ``static.s`` entries; the profile evaluation
    # grid additionally gets the first ``static.s`` entry prepended.
    s_half_full = jnp.asarray(0.5 * (vmec_run.static.s[:-1] + vmec_run.static.s[1:]))
    s_half_eval = jnp.concatenate([vmec_run.static.s[:1], s_half_full], axis=0)
    profiles_half = eval_profiles(vmec_run.indata, s_half_eval)
    flux = flux_profiles_from_indata(vmec_run.indata, vmec_run.static.s, signgs=int(vmec_run.signgs))

    if cfg.surfaces is None:
        # No selection requested: every midpoint surface is used.
        surface_indices = None
        s_selected = s_half_full
    else:
        s_vals = list(np.asarray(s_half_full))
        surface_indices_list = []
        for val in cfg.surfaces:
            if isinstance(val, float) and 0.0 <= val <= 1.0:
                # A float in [0, 1] is a value of s: pick the nearest midpoint.
                # NOTE(review): assumes ns_full does not exceed len(s_vals)
                # -- TODO confirm against booz_xform_inputs_from_state.
                best = min(range(ns_full), key=lambda i: abs(s_vals[i] - val))
                surface_indices_list.append(best)
            else:
                # Anything else is treated as a 1-based surface index.
                surface_indices_list.append(int(val) - 1)
        surface_indices = jnp.asarray(surface_indices_list, dtype=jnp.int32)
        s_selected = jnp.take(s_half_full, surface_indices, axis=0)

    control = cfg.to_control()
    # Modes already filtered in the JAX pipeline; skip re-masking in Fourier sums.
    control = replace(control, max_m_mode=-1, max_n_mode=-1)

    # Precompute mode indices for JIT-safe slicing.
    xm_b_np = np.asarray(grids.xm_b)
    xn_b_np = np.asarray(grids.xn_b)
    # A non-positive limit in the config means "keep every mode".
    max_m = int(cfg.max_m_mode) if cfg.max_m_mode > 0 else int(np.max(np.abs(xm_b_np)))
    max_n = int(cfg.max_n_mode) if cfg.max_n_mode > 0 else int(np.max(np.abs(xn_b_np)))
    mode_indices = np.where((np.abs(xm_b_np) <= max_m) & (np.abs(xn_b_np) <= max_n))[0]

    def _solve(state):
        # Only ``state`` varies between calls; everything else is closed over.
        inputs = booz_xform_inputs_from_state(
            state=state,
            static=vmec_run.static,
            indata=vmec_run.indata,
            signgs=int(vmec_run.signgs),
            trig=trig,
            flux=flux,
            profiles_half=profiles_half,
        )
        booz_out = booz_xform_jax_impl(
            rmnc=inputs.rmnc,
            zmns=inputs.zmns,
            lmns=inputs.lmns,
            bmnc=inputs.bmnc,
            bsubumnc=inputs.bsubumnc,
            bsubvmnc=inputs.bsubvmnc,
            iota=inputs.iota,
            xm=inputs.xm,
            xn=inputs.xn,
            xm_nyq=inputs.xm_nyq,
            xn_nyq=inputs.xn_nyq,
            constants=constants,
            grids=grids,
            bmns=inputs.bmns,
            bsubumns=inputs.bsubumns,
            bsubvmns=inputs.bsubvmns,
            surface_indices=surface_indices,
        )
        # Attach the surface metadata read by the Boozer-data conversion.
        booz_out["s_b"] = s_selected
        booz_out["ns_b"] = ns_full
        if surface_indices is not None:
            # ``jlist`` is the 1-based counterpart of the 0-based indices.
            booz_out["jlist"] = surface_indices + 1
        booz = booz_xform_to_boozerdata_jax(
            booz_out,
            max_m_mode=cfg.max_m_mode,
            max_n_mode=cfg.max_n_mode,
            nfp_override=nfp_int,
            mode_indices=mode_indices,
        )
        return run_neo_from_boozer_jax(booz, control, skip_fourier_mask=True)

    if jit:
        _solve = jax.jit(_solve)
    return _solve