Source code for neo_jax.grids
"""Grid preparation utilities."""
from __future__ import annotations
from typing import Dict
import jax
import jax.numpy as jnp
[docs]
def prepare_grids(theta_n: int, phi_n: int, nfp: int) -> Dict[str, jnp.ndarray | float | int]:
    """Build uniformly spaced theta and phi grids with their step sizes.

    Mirrors the grid construction in `neo_prep.f90`.

    Theta spans ``[0, 2*pi]`` and phi spans ``[0, 2*pi / nfp]``. Each
    interval is divided into ``n - 1`` equal steps, so the ``n`` samples
    start at the lower bound and advance by one step per index.

    Args:
        theta_n: Number of theta samples.
        phi_n: Number of phi samples.
        nfp: Divisor applied to ``2*pi`` to obtain the phi upper bound.

    Returns:
        A dict with the keys ``theta_start``, ``theta_end``, ``phi_start``,
        ``phi_end``, ``theta_int``, ``phi_int``, ``theta_arr``, ``phi_arr``,
        ``theta_n``, ``phi_n``, ``mt`` and ``mp``. ``mt`` and ``mp`` are
        always ``1``.
    """
    two_pi = 2.0 * jnp.pi

    def _as_scalar(step):
        # Values that are already jax.Array instances are handed back
        # unchanged; anything else is coerced to a Python float.
        return step if isinstance(step, jax.Array) else float(step)

    # Interval bounds for both angles.
    theta_lo, theta_hi = 0.0, two_pi
    phi_lo, phi_hi = 0.0, two_pi / nfp

    # Step sizes: n samples cover the interval in n - 1 steps.
    theta_step = (theta_hi - theta_lo) / (theta_n - 1)
    phi_step = (phi_hi - phi_lo) / (phi_n - 1)

    # Sample positions: lower bound plus an integer multiple of the step.
    theta_samples = theta_lo + theta_step * jnp.arange(theta_n)
    phi_samples = phi_lo + phi_step * jnp.arange(phi_n)

    grids = {
        "theta_start": theta_lo,
        "theta_end": theta_hi,
        "phi_start": phi_lo,
        "phi_end": phi_hi,
        "theta_int": _as_scalar(theta_step),
        "phi_int": _as_scalar(phi_step),
        "theta_arr": theta_samples,
        "phi_arr": phi_samples,
        "theta_n": theta_n,
        "phi_n": phi_n,
        "mt": 1,
        "mp": 1,
    }
    return grids