"""User-facing API helpers for NEO_JAX."""
from __future__ import annotations
from dataclasses import replace
from pathlib import Path
from typing import Sequence
import numpy as np
from .config import NeoConfig
from .control import ControlParams
from .data_models import BoozerData
from .driver import (
_resolve_rational_surface_policy,
run_neo_from_boozer,
run_neo_from_boozmn,
run_neo_from_boozer_jax,
)
from .io import booz_xform_to_boozerdata, read_boozmn, read_boozmn_metadata
from .results import NeoResults
from .data_models import NeoOutputs
def _control_from_config(config: NeoConfig, *, in_file: str = "boozmn", out_file: str = "neo_out") -> ControlParams:
return config.to_control(in_file=in_file, out_file=out_file)
def _surface_s_from_index(index: int, ns_b: int) -> float:
hs = 1.0 / (ns_b - 1)
return (index - 1.5) * hs
def _resolve_surface_indices(
surfaces: Sequence[int | float] | None,
*,
jlist: Sequence[int] | None,
ns_b: int,
) -> list[int] | None:
if surfaces is None:
return None
resolved: list[int] = []
jlist_arr = list(jlist) if jlist is not None else None
if jlist_arr is not None:
s_vals = [_surface_s_from_index(idx, ns_b) for idx in jlist_arr]
for val in surfaces:
if isinstance(val, float) and 0.0 <= val <= 1.0:
if jlist_arr is not None:
best_idx = min(range(len(jlist_arr)), key=lambda i: abs(s_vals[i] - val))
resolved.append(int(jlist_arr[best_idx]))
else:
idx = int(round(val * (ns_b - 1) + 1.5))
idx = max(1, min(ns_b, idx))
resolved.append(idx)
else:
resolved.append(int(val))
return resolved
[docs]
def run_boozmn(
    boozmn_path: str | Path,
    *,
    config: NeoConfig | None = None,
    surfaces: Sequence[int | float] | None = None,
    use_jax: bool = True,
    progress: bool | None = None,
    jax_surface_scan: bool = False,
) -> NeoResults | NeoOutputs:
    """Run NEO_JAX from a boozmn file using a simplified configuration.

    Args:
        boozmn_path: Path to the boozmn file.
        config: Run configuration. A default ``NeoConfig()`` is used when
            omitted.
        surfaces: Optional surface selection that overrides ``config.surfaces``.
            Integers are used as surface indices. Floats in ``[0, 1]``
            (including NumPy float scalars) are treated as target ``s`` values.
            They are resolved to surface indices using the file's metadata
            (``jlist``/``ns_b``).
        use_jax: Forwarded to ``run_neo_from_boozmn`` / ``run_neo_from_boozer``.
            It is not forwarded to ``run_neo_from_boozer_jax``.
        progress: Progress flag. Defaults to ``config.write_progress`` when
            ``None``.
        jax_surface_scan: If true, the file is first read into ``BoozerData``.
            ``run_neo_from_boozer_jax`` is then used when the resolved
            rational-surface policy is ``"error"``; otherwise
            ``run_neo_from_boozer`` is used.

    Returns:
        The result of the selected driver (``NeoResults`` or ``NeoOutputs``).
    """
    cfg = config if config is not None else NeoConfig()
    if surfaces is not None:
        cfg = replace(cfg, surfaces=list(surfaces))

    if cfg.surfaces is not None:
        # NumPy scalars such as np.float32 are not ``float`` subclasses. Convert
        # them to Python scalars so float targets are recognized by the gate below.
        surface_list = [v.item() if isinstance(v, np.generic) else v for v in cfg.surfaces]
        if any(isinstance(v, float) and 0.0 <= v <= 1.0 for v in surface_list):
            # Metadata is read only when some selector actually needs resolving.
            meta = read_boozmn_metadata(boozmn_path)
            resolved = _resolve_surface_indices(surface_list, jlist=meta["jlist"], ns_b=meta["ns_b"])
            cfg = replace(cfg, surfaces=resolved)

    ctrl = _control_from_config(cfg)
    if progress is None:
        progress = cfg.write_progress
    rational_surface_policy = _resolve_rational_surface_policy(cfg.rational_surface_policy)

    if jax_surface_scan:
        booz = read_boozmn(
            str(boozmn_path),
            max_m_mode=cfg.max_m_mode,
            max_n_mode=cfg.max_n_mode,
            fluxs_arr=cfg.surfaces,
        )
        if rational_surface_policy != "error":
            return run_neo_from_boozer(
                booz,
                ctrl,
                use_jax=use_jax,
                progress=progress,
                max_rational_field_periods=cfg.max_rational_field_periods,
                rational_surface_policy=rational_surface_policy,
            )
        return run_neo_from_boozer_jax(
            booz,
            ctrl,
            max_rational_field_periods=cfg.max_rational_field_periods,
            rational_surface_policy=rational_surface_policy,
        )
    return run_neo_from_boozmn(
        str(boozmn_path),
        ctrl,
        use_jax=use_jax,
        progress=progress,
        max_rational_field_periods=cfg.max_rational_field_periods,
        rational_surface_policy=rational_surface_policy,
    )
[docs]
def run_boozer(
    booz: BoozerData,
    *,
    config: NeoConfig | None = None,
    surfaces: Sequence[int | float] | None = None,
    use_jax: bool = True,
    progress: bool | None = None,
    jax_surface_scan: bool = False,
) -> NeoResults | NeoOutputs:
    """Run NEO_JAX from a BoozerData object (e.g., booz_xform_jax output).

    Args:
        booz: Boozer-coordinate data.
        config: Run configuration. A default ``NeoConfig()`` is used when
            omitted.
        surfaces: Optional surface selection that overrides ``config.surfaces``.
            Floats in ``[0, 1]`` (Python or NumPy floats) are mapped to the
            1-based position of the closest value in ``booz.es``; the first
            match wins on ties. Other entries are converted with ``int()``.
        use_jax: Forwarded to ``run_neo_from_boozer``.
        progress: Progress flag. Defaults to ``config.write_progress`` when
            ``None``.
        jax_surface_scan: If true and the resolved rational-surface policy is
            ``"error"``, ``run_neo_from_boozer_jax`` is used. In every other
            case ``run_neo_from_boozer`` is used.

    Returns:
        The result of the selected driver (``NeoResults`` or ``NeoOutputs``).

    Raises:
        ValueError: If float surface targets are given but ``booz.es`` is empty.
    """
    cfg = config if config is not None else NeoConfig()
    if surfaces is not None:
        cfg = replace(cfg, surfaces=list(surfaces))

    def _is_s_target(value) -> bool:
        # np.float32 and similar types are not ``float`` subclasses, so check
        # np.floating explicitly. Otherwise int() would truncate such a value to 0.
        return isinstance(value, (float, np.floating)) and 0.0 <= value <= 1.0

    surface_list = cfg.surfaces
    if surface_list is not None and any(_is_s_target(v) for v in surface_list):
        s_vals = np.asarray(booz.es, dtype=float)
        if s_vals.size == 0:
            raise ValueError("cannot map normalized surface values: booz.es is empty")
        mapped = []
        for s_target in surface_list:
            if _is_s_target(s_target):
                # argmin returns the first minimum on ties; +1 gives a 1-based position.
                mapped.append(int(np.argmin(np.abs(s_vals - float(s_target)))) + 1)
            else:
                mapped.append(int(s_target))
        cfg = replace(cfg, surfaces=mapped)

    ctrl = _control_from_config(cfg)
    if progress is None:
        progress = cfg.write_progress
    rational_surface_policy = _resolve_rational_surface_policy(cfg.rational_surface_policy)

    # Only the "error" policy is routed to the JAX surface-scan driver. Every
    # other combination uses run_neo_from_boozer.
    if jax_surface_scan and rational_surface_policy == "error":
        return run_neo_from_boozer_jax(
            booz,
            ctrl,
            max_rational_field_periods=cfg.max_rational_field_periods,
            rational_surface_policy=rational_surface_policy,
        )
    return run_neo_from_boozer(
        booz,
        ctrl,
        use_jax=use_jax,
        progress=progress,
        max_rational_field_periods=cfg.max_rational_field_periods,
        rational_surface_policy=rational_surface_policy,
    )
[docs]
def run_neo(
    source: BoozerData | str | Path | object,
    *,
    config: NeoConfig | None = None,
    surfaces: Sequence[int | float] | None = None,
    use_jax: bool = True,
    progress: bool | None = None,
    max_m_mode: int | None = None,
    max_n_mode: int | None = None,
    jax_surface_scan: bool = False,
) -> NeoResults | NeoOutputs:
    """Dispatch a NEO_JAX run according to the type of ``source``.

    * ``str`` or ``Path``: handled by ``run_boozmn``.
    * ``BoozerData``: handled by ``run_boozer``.
    * Anything else (e.g. a booz_xform_jax-like object): handled by
      ``run_booz_xform``.

    ``max_m_mode`` and ``max_n_mode`` are forwarded only on the last branch.
    They are not used for path or ``BoozerData`` sources.
    """
    shared = dict(
        config=config,
        surfaces=surfaces,
        use_jax=use_jax,
        progress=progress,
        jax_surface_scan=jax_surface_scan,
    )
    if isinstance(source, (str, Path)):
        return run_boozmn(source, **shared)
    if isinstance(source, BoozerData):
        return run_boozer(source, **shared)
    return run_booz_xform(source, max_m_mode=max_m_mode, max_n_mode=max_n_mode, **shared)
[docs]
def load_boozmn(
    boozmn_path: str | Path,
    *,
    max_m_mode: int = 0,
    max_n_mode: int = 0,
    surfaces: Sequence[int] | None = None,
) -> BoozerData:
    """Read a boozmn file and return its contents as ``BoozerData``.

    Intended for custom workflows that need the Boozer data itself rather than
    a full NEO run.

    Args:
        boozmn_path: Path to the boozmn file. It is converted with ``str()``
            before reading.
        max_m_mode: Forwarded unchanged to ``read_boozmn``; its meaning is
            defined there.
        max_n_mode: Forwarded unchanged to ``read_boozmn``; its meaning is
            defined there.
        surfaces: Optional surface selection, passed as ``fluxs_arr``.

    Returns:
        The ``BoozerData`` produced by ``read_boozmn``.
    """
    path_text = str(boozmn_path)
    return read_boozmn(
        path_text,
        fluxs_arr=surfaces,
        max_m_mode=max_m_mode,
        max_n_mode=max_n_mode,
    )