Source code for neo_jax.geometry

"""Geometry evaluation and root finding for NEO_JAX."""

from __future__ import annotations

from typing import Tuple

import jax
import jax.numpy as jnp

from .splines import eva2d_fd_jax, eva2d_jax, eva2d_sd_jax, poi2d_jax

Array = jax.Array


[docs] def neo_eval( theta: Array, phi: Array, b_spl: Array, g_spl: Array, k_spl: Array, p_spl: Array, q_spl: Array | None, grid: dict, ) -> Tuple[Array, Array, Array, Array, Array]: """Evaluate spline fields at a point. Returns (bval, gval, kval, pval, qval). """ ix, iy, dx, dy, ierr = poi2d_jax( grid["theta_int"], grid["phi_int"], grid["mt"], grid["mp"], grid["theta_start"], grid["theta_end"], grid["phi_start"], grid["phi_end"], theta, phi, ) _ = ierr bval = eva2d_jax(b_spl, ix, iy, dx, dy) gval = eva2d_jax(g_spl, ix, iy, dx, dy) kval = eva2d_jax(k_spl, ix, iy, dx, dy) pval = eva2d_jax(p_spl, ix, iy, dx, dy) if q_spl is None: qval = jnp.array(0.0, dtype=bval.dtype) else: qval = eva2d_jax(q_spl, ix, iy, dx, dy) return bval, gval, kval, pval, qval
[docs] def neo_bderiv( theta: Array, phi: Array, b_spl: Array, grid: dict, ) -> Tuple[Array, Array, Array, Array, Array, Array]: """Compute first and second derivatives of B using spline coefficients.""" ix, iy, dx, dy, ierr = poi2d_jax( grid["theta_int"], grid["phi_int"], grid["mt"], grid["mp"], grid["theta_start"], grid["theta_end"], grid["phi_start"], grid["phi_end"], theta, phi, ) _ = ierr fderiv = eva2d_fd_jax(b_spl, ix, iy, dx, dy) sderiv = eva2d_sd_jax(b_spl, ix, iy, dx, dy) f = fderiv[0] g = fderiv[1] dfdx = sderiv[0] dfdy = sderiv[1] dgdx = sderiv[1] dgdy = sderiv[2] return f, g, dfdx, dfdy, dgdx, dgdy
[docs] def neo_zeros2d( theta: Array, phi: Array, eps: float, iter_ma: int, b_spl: Array, grid: dict, ) -> Tuple[Array, Array, Array, Array]: """Newton solver for finding extrema of B. Returns (theta, phi, iter, error) where error=0 on success. """ f, g, dfdx, dfdy, dgdx, dgdy = neo_bderiv(theta, phi, b_spl, grid) def cond(state): _, _, _, _, _, _, _, _, it, converged = state return jnp.logical_and(it < iter_ma, jnp.logical_not(converged)) def body(state): x, y, f, g, dfdx, dfdy, dgdx, dgdy, it, _ = state det = dfdx * dgdy - dfdy * dgdx x_n = x + (dfdy * g - f * dgdy) / det y_n = y + (f * dgdx - dfdx * g) / det f_n, g_n, dfdx_n, dfdy_n, dgdx_n, dgdy_n = neo_bderiv(x_n, y_n, b_spl, grid) x_err = jnp.where(jnp.abs(x_n) > eps, jnp.abs((x_n - x) / x_n), jnp.abs(x_n - x)) y_err = jnp.where(jnp.abs(y_n) > eps, jnp.abs((y_n - y) / y_n), jnp.abs(y_n - y)) max_err = jnp.maximum(x_err, y_err) converged = max_err < eps return (x_n, y_n, f_n, g_n, dfdx_n, dfdy_n, dgdx_n, dgdy_n, it + 1, converged) init_state = (theta, phi, f, g, dfdx, dfdy, dgdx, dgdy, 0, False) theta_f, phi_f, *_rest, it, converged = jax.lax.while_loop(cond, body, init_state) error = jnp.where(converged, 0, 1) return theta_f, phi_f, it, error