"""User-friendly configuration for NEO_JAX runs."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Sequence
from .control import ControlParams
@dataclass(frozen=True)
class NeoConfig:
    """High-level configuration for NEO_JAX runs.

    Surface selections may be specified as:

    - Integers (1-based NEO surface indices), or
    - Floats in [0, 1], interpreted as normalized toroidal flux ``s``.

    ``max_rational_field_periods`` is a safeguard for near-zero-iota surfaces.
    Set it to ``0`` to disable the guard explicitly.

    ``rational_surface_policy`` controls what happens if the estimated
    rational-surface correction exceeds that limit:

    - ``"error"``: fail fast with a detailed diagnostic (default)
    - ``"approximate"``: skip the expensive rational-surface correction and
      return the base integration result with explicit diagnostics

    :meth:`from_control` and :meth:`to_control` copy the remaining fields
    to/from ``ControlParams`` attributes of the same name, except that
    ``surfaces`` corresponds to ``ControlParams.fluxs_arr``.
    ``max_rational_field_periods`` and ``rational_surface_policy`` are not
    transferred by either method.
    """

    surfaces: Sequence[int | float] | None = None
    theta_n: int = 64
    phi_n: int = 64
    max_m_mode: int = 0
    max_n_mode: int = 0
    npart: int = 40
    multra: int = 2
    acc_req: float = 0.02
    no_bins: int = 50
    nstep_per: int = 20
    nstep_min: int = 200
    nstep_max: int = 500
    calc_nstep_max: int = 0
    max_rational_field_periods: Optional[int] = 100_000
    rational_surface_policy: str = "error"
    ref_swi: int = 2
    write_progress: bool = False
    write_diagnostic: bool = False

    @classmethod
    def from_control(cls, control: ControlParams) -> "NeoConfig":
        """Build a config from a ``ControlParams`` object.

        ``control.fluxs_arr`` is copied into a new list, as :meth:`to_control`
        does in the opposite direction. This means later mutation of
        ``control`` cannot leak into this frozen config. ``write_progress``
        and ``write_diagnostic`` are coerced to ``bool``.

        ``max_rational_field_periods`` and ``rational_surface_policy`` are not
        read from ``control``; they keep their class defaults.

        Args:
            control: Source parameters. Only the attributes read below are
                accessed.

        Returns:
            A new ``NeoConfig``.
        """
        fluxs = control.fluxs_arr
        return cls(
            # Copy rather than alias: a frozen dataclass must not share a
            # mutable list with the (mutable) ControlParams it came from.
            surfaces=list(fluxs) if fluxs is not None else None,
            theta_n=control.theta_n,
            phi_n=control.phi_n,
            max_m_mode=control.max_m_mode,
            max_n_mode=control.max_n_mode,
            npart=control.npart,
            multra=control.multra,
            acc_req=control.acc_req,
            no_bins=control.no_bins,
            nstep_per=control.nstep_per,
            nstep_min=control.nstep_min,
            nstep_max=control.nstep_max,
            calc_nstep_max=control.calc_nstep_max,
            ref_swi=control.ref_swi,
            write_progress=bool(control.write_progress),
            write_diagnostic=bool(control.write_diagnostic),
        )

    def to_control(self, *, in_file: str = "boozmn", out_file: str = "neo_out") -> ControlParams:
        """Convert to a ControlParams object (for CLI compatibility).

        Args:
            in_file: Value stored as ``ControlParams.in_file``.
            out_file: Value stored as ``ControlParams.out_file``.

        Returns:
            A ``ControlParams`` whose tunables are taken from this config.
            ``surfaces`` is copied into ``fluxs_arr`` as a new list (or
            ``None``). Booleans are written as ``0``/``1`` ints. Switches that
            this config does not expose are pinned to the fixed values below.
            ``max_rational_field_periods`` and ``rational_surface_policy`` are
            not transferred.
        """
        return ControlParams(
            in_file=in_file,
            out_file=out_file,
            fluxs_arr=list(self.surfaces) if self.surfaces is not None else None,
            theta_n=self.theta_n,
            phi_n=self.phi_n,
            max_m_mode=self.max_m_mode,
            max_n_mode=self.max_n_mode,
            npart=self.npart,
            multra=self.multra,
            acc_req=self.acc_req,
            no_bins=self.no_bins,
            nstep_per=self.nstep_per,
            nstep_min=self.nstep_min,
            nstep_max=self.nstep_max,
            calc_nstep_max=self.calc_nstep_max,
            # Fixed switches: not configurable through NeoConfig.
            eout_swi=1,
            lab_swi=0,
            inp_swi=0,
            ref_swi=self.ref_swi,
            write_progress=int(self.write_progress),
            write_output_files=0,
            spline_test=0,
            write_integrate=0,
            write_diagnostic=int(self.write_diagnostic),
            # Current-calculation settings are always disabled here.
            calc_cur=0,
            cur_file="neo_cur",
            npart_cur=0,
            alpha_cur=0.0,
            write_cur_inte=0,
        )