"""Cubic spline utilities ported from STELLOPT NEO."""
from __future__ import annotations
from typing import Tuple
import jax
import jax.numpy as jnp
Array = jax.Array
[docs]
def splreg(y: Array, h: float) -> Tuple[Array, Array, Array]:
    """Regular (non-periodic) cubic spline coefficients.

    Mirrors ``splreg.f90``.  On the interval ``[x_i, x_i + h]`` the spline is
    ``y[i] + bi[i]*t + ci[i]*t**2 + di[i]*t**3`` with ``t = x - x_i``.  All
    boundary constants are zero, so ``ci[0] == ci[n-1] == 0`` (natural
    end conditions).

    Args:
        y: Samples on an equally spaced grid of ``n`` points.
        h: Grid spacing.

    Returns:
        ``(bi, ci, di)``, each of length ``n``.  ``bi[n-1]`` and ``di[n-1]``
        stay zero because no interval starts at the last node.
    """
    n = y.shape[0]
    # Boundary relations c_0 = k_lo*c_1 + m_lo and c_{n-1} = k_hi*c_{n-2} + m_hi;
    # all zero here, which yields a natural spline.
    k_lo = k_hi = m_lo = m_hi = 0.0
    pivot = -4.0 * h

    # Forward elimination of the tridiagonal system: c_k = alpha[k]*c_{k+1} + beta[k].
    alpha = jnp.zeros(n, dtype=y.dtype).at[0].set(k_lo)
    beta = jnp.zeros(n, dtype=y.dtype).at[0].set(m_lo)

    def eliminate(k, carry):
        alpha, beta = carry
        curvature = -3.0 * ((y[k + 1] - y[k]) - (y[k] - y[k - 1])) / h
        denom = pivot - alpha[k - 1] * h
        return (
            alpha.at[k].set(h / denom),
            beta.at[k].set((h * beta[k - 1] + curvature) / denom),
        )

    # k runs over the interior nodes 1 .. n-2.
    alpha, beta = jax.lax.fori_loop(1, n - 1, eliminate, (alpha, beta))

    # Back substitution, from the right end down to index 0.
    c = jnp.zeros(n, dtype=y.dtype)
    if n >= 2:
        c = c.at[n - 1].set((m_hi + k_hi * beta[n - 2]) / (1.0 - alpha[n - 2] * k_hi))

    def substitute(step, c):
        k = n - 2 - step
        return c.at[k].set(alpha[k] * c[k + 1] + beta[k])

    c = jax.lax.fori_loop(0, n - 1, substitute, c)

    # Linear and cubic coefficients for every interval at once; the entry
    # for the last node is left at zero.
    b = jnp.zeros(n, dtype=y.dtype).at[: n - 1].set(
        (y[1:] - y[:-1]) / h - h * (c[1:] + 2.0 * c[:-1]) / 3.0
    )
    d = jnp.zeros(n, dtype=y.dtype).at[: n - 1].set((c[1:] - c[:-1]) / h / 3.0)
    return b, c, d
[docs]
def spfper(np1: int, dtype=jnp.float64) -> Tuple[Array, Array, Array]:
    """Helper routine for periodic splines (port of spfper.f90).

    Builds a Cholesky-style lower-triangular factor ``L`` of the cyclic
    tridiagonal matrix of size ``np1 - 1``. That matrix has 4 on the diagonal
    and 1 on both off-diagonals and in the two corners (for sizes >= 3).
    ``L`` is returned in compressed form:

    * ``amx1[i]``: diagonal ``L[i, i]``, for ``i < np1 - 1``;
    * ``amx2[i]``: sub-diagonal ``L[i + 1, i]``, for ``i < np1 - 3``;
    * ``amx3[j]``: last row ``L[np1 - 2, j]``, for ``j < np1 - 2``.
      ``amx2[np1 - 3]`` is set equal to ``amx3[np1 - 3]``.

    Args:
        np1: Number of grid points (one more than the matrix size).
        dtype: dtype of the returned arrays.

    Returns:
        ``(amx1, amx2, amx3)``, each of length ``np1``.
    """
    size = np1 - 1  # matrix size
    n1 = size - 1
    base = jnp.zeros(np1, dtype=dtype)
    diag = base.at[0].set(2.0)  # sqrt(4)
    sub = base.at[0].set(0.5)
    last = base.at[0].set(0.5)
    if np1 > 1:
        diag = diag.at[1].set(jnp.sqrt(15.0) / 2.0)  # sqrt(4 - 0.5**2)
        sub = sub.at[1].set(1.0 / diag[1])
        last = last.at[1].set(-0.25 / diag[1])

    def next_row(i, carry):
        diag, sub, last, beta = carry
        beta = 4.0 - 1.0 / beta  # squared diagonal entry of row i
        diag = diag.at[i].set(jnp.sqrt(beta))
        sub = sub.at[i].set(1.0 / diag[i])
        last = last.at[i].set(-last[i - 1] / diag[i] / diag[i - 1])
        return diag, sub, last, beta

    # Rows 2 .. n1-1 (Fortran i = 3..n1); 3.75 is the squared diagonal of row 1.
    if n1 > 2:
        diag, sub, last, _ = jax.lax.fori_loop(2, n1, next_row, (diag, sub, last, 3.75))

    if n1 >= 1:
        # The last-row entry beside the diagonal also absorbs the off-diagonal 1.
        last = last.at[n1 - 1].add(1.0 / diag[n1 - 1])
        sub = sub.at[n1 - 1].set(last[n1 - 1])

    if size >= 1:
        tail = jnp.sum(jnp.square(last[:n1])) if n1 > 0 else 0.0
        diag = diag.at[size - 1].set(jnp.sqrt(4.0 - tail))
    return diag, sub, last
[docs]
def splper(y: Array, h: float) -> Tuple[Array, Array, Array]:
    """Periodic cubic spline coefficients.
    Mirrors `splper.f90`.

    On the interval ``[x_i, x_i + h]`` the spline is
    ``y[i] + bi[i]*t + ci[i]*t**2 + di[i]*t**3`` with ``t = x - x_i``.

    The ``nmx = n - 1`` unknowns ``ci[0..nmx-1]`` solve a cyclic tridiagonal
    system. Its right-hand side is ``3/h**2`` times the second difference of
    ``y``. The system is solved with the lower-triangular factor from
    ``spfper`` (``amx1`` diagonal, ``amx2`` sub-diagonal, ``amx3`` last row):
    a forward substitution into ``yl``, then a back substitution that
    overwrites ``bmx``.

    Args:
        y: Samples on an equally spaced grid of ``n`` points.  ``y[n-1]`` is
            treated as the periodic image of ``y[0]``. The unknowns are
            ``ci[0..n-2]``, the last interval closes onto ``ci[0]``, and the
            coefficients at index ``n-1`` are copied from index 0.
        h: Grid spacing.

    Returns:
        ``(bi, ci, di)``, each of length ``n``.

    NOTE(review): for ``n < 3`` the index ``nmx - 2`` used below is negative
    and wraps around, so the routine appears to assume ``n >= 3``. Confirm.
    """
    n = y.shape[0]
    bmx = jnp.zeros(n, dtype=y.dtype)  # right-hand side; later holds the solution
    yl = jnp.zeros(n, dtype=y.dtype)  # result of the forward substitution
    amx1, amx2, amx3 = spfper(n, dtype=y.dtype)
    # Placeholder; overwritten by the bmx[0] assignment below whenever nmx >= 1.
    bmx = bmx.at[0].set(1.0e30)
    nmx = n - 1  # number of unknowns (node n-1 repeats node 0)
    n1 = nmx - 1
    n2 = nmx - 2
    psi = 3.0 / (h * h)
    if nmx >= 1:
        # Last row: ordinary second difference centred on node nmx - 1.
        bmx = bmx.at[nmx - 1].set((y[nmx] - 2.0 * y[nmx - 1] + y[nmx - 2]) * psi)
        # First row wraps around: the slope entering node 0 is taken from the
        # last interval, y[nmx] - y[nmx - 1].
        bmx = bmx.at[0].set((y[1] - y[0] - y[nmx] + y[nmx - 1]) * psi)

    def bmx_body(i, bmx):
        # Fortran i=3..nmx => Python i=2..nmx-1
        # Interior rows: second difference centred on node i - 1.
        bmx = bmx.at[i - 1].set((y[i] - 2.0 * y[i - 1] + y[i - 2]) * psi)
        return bmx

    if nmx > 2:
        bmx = jax.lax.fori_loop(2, nmx, bmx_body, bmx)
    # Forward substitution, rows 0 .. n1-1 (diagonal amx1, sub-diagonal amx2).
    if n1 >= 1:
        yl = yl.at[0].set(bmx[0] / amx1[0])

    def yl_body(i, yl):
        # Fortran i=2..n1 => Python i=1..n1-1
        yl = yl.at[i].set((bmx[i] - yl[i - 1] * amx2[i - 1]) / amx1[i])
        return yl

    if n1 > 1:
        yl = jax.lax.fori_loop(1, n1, yl_body, yl)
    # Last row of the factor (amx3) couples to every earlier row.
    ss = jnp.sum(yl[:n1] * amx3[:n1])
    yl = yl.at[nmx - 1].set((bmx[nmx - 1] - ss) / amx1[nmx - 1])
    # Back substitution: bmx is overwritten with the solution, starting with
    # the last two unknowns.
    bmx = bmx.at[nmx - 1].set(yl[nmx - 1] / amx1[nmx - 1])
    bmx = bmx.at[n1 - 1].set((yl[n1 - 1] - amx2[n1 - 1] * bmx[nmx - 1]) / amx1[n1 - 1])

    def back_body(i, bmx):
        # Fortran i=n2..1 => Python i = n2-1 .. 0
        # Each remaining unknown depends on its right neighbour (amx2) and on
        # the last unknown via the last-row entry (amx3).
        idx = n2 - 1 - i
        bmx = bmx.at[idx].set(
            (yl[idx] - amx3[idx] * bmx[nmx - 1] - amx2[idx] * bmx[idx + 1])
            / amx1[idx]
        )
        return bmx

    if n2 >= 1:
        bmx = jax.lax.fori_loop(0, n2, back_body, bmx)
    ci = jnp.zeros(n, dtype=y.dtype)
    if nmx >= 1:
        ci = ci.at[:nmx].set(bmx[:nmx])
    bi = jnp.zeros(n, dtype=y.dtype)
    di = jnp.zeros(n, dtype=y.dtype)

    def coeff_body(i, state):
        # Linear and cubic coefficients of interval i (i = 0 .. n1-1).
        bi, di = state
        bi = bi.at[i].set((y[i + 1] - y[i]) / h - h * (ci[i + 1] + 2.0 * ci[i]) / 3.0)
        di = di.at[i].set((ci[i + 1] - ci[i]) / h / 3.0)
        return bi, di

    if n1 >= 1:
        bi, di = jax.lax.fori_loop(0, n1, coeff_body, (bi, di))
    if nmx >= 1:
        # Last interval closes the period: its right-hand curvature is ci[0].
        bi = bi.at[nmx - 1].set((y[nmx] - y[nmx - 1]) / h - h * (ci[0] + 2.0 * ci[nmx - 1]) / 3.0)
        di = di.at[nmx - 1].set((ci[0] - ci[nmx - 1]) / h / 3.0)
    # Fix boundary
    # Node n-1 is the periodic image of node 0, so it gets node 0's coefficients.
    bi = bi.at[n - 1].set(bi[0])
    ci = ci.at[n - 1].set(ci[0])
    di = di.at[n - 1].set(di[0])
    return bi, ci, di
[docs]
def spl2d(f: Array, hx: float, hy: float, mx: int, my: int) -> Array:
    """Two-dimensional cubic spline coefficients.

    First splines every column ``f[:, j]`` along x, then splines each
    resulting x-coefficient along y.

    Args:
        f: Samples on an ``(nx, ny)`` grid.
        hx, hy: Grid spacings in x and y.
        mx, my: 0 selects ``splreg`` (non-periodic) for that axis; any other
            value selects ``splper`` (periodic).

    Returns:
        Array of shape ``(4, 4, nx, ny)``. ``spl[k, l, i, j]`` multiplies
        ``dx**k * dy**l`` in cell ``(i, j)``.
    """
    nx, ny = f.shape
    fit_x = splreg if mx == 0 else splper
    fit_y = splreg if my == 0 else splper

    def along_x(column):
        # (4, nx): values followed by the x-spline coefficients b, c, d.
        return jnp.stack([column, *fit_x(column, hx)], axis=0)

    def along_y(row):
        # (3, ny): y-spline coefficients b, c, d of one row.
        return jnp.stack(fit_y(row, hy), axis=0)

    # xcoef[k, i, j]: k-th x-power coefficient at node (i, j).
    xcoef = jax.vmap(along_x, in_axes=1, out_axes=2)(f)
    # ycoef[k, l - 1, i, j]: y-spline (l = 1..3) of xcoef[k, i, :].
    ycoef = jax.vmap(
        jax.vmap(along_y, in_axes=0, out_axes=1), in_axes=0, out_axes=0
    )(xcoef)

    spl = jnp.zeros((4, 4, nx, ny), dtype=f.dtype)
    spl = spl.at[:, 0].set(xcoef)
    spl = spl.at[:, 1:].set(ycoef)
    return spl
[docs]
def poi2d(
    hx: float,
    hy: float,
    mx: int,
    my: int,
    xmin: float,
    xmax: float,
    ymin: float,
    ymax: float,
    x: float,
    y: float,
):
    """Pointer calculation for spline evaluation (port of poi2d.f90).

    Finds the grid cell that contains ``(x, y)`` and the offset of the point
    from that cell's lower-left node.

    Args:
        hx, hy: Grid spacings.
        mx, my: 0 for a bounded axis. Any other value makes the axis periodic,
            and its coordinate is first folded into ``[min, max]``.
        xmin, xmax, ymin, ymax: Grid extents.
        x, y: Evaluation point. Must be concrete (non-traced) numbers, since
            the checks use Python ``if``.

    Returns:
        ``(ix, iy, dx, dy, ierr)`` with 0-based indices. ``ierr`` is 0 on
        success; 1/2 if ``x`` lies below/above a bounded x range; 3/4 likewise
        for ``y``. The x axis is checked first. On error the other outputs
        are ``0, 0, 0.0, 0.0``.
    """

    def fold(value, lo, hi, periodic, below, above):
        # Offset of ``value`` from ``lo`` together with an error code (0 = ok).
        offset = value - lo
        if periodic:
            span = hi - lo
            if offset < 0.0:
                offset = offset + (1 + int(abs(offset / span))) * span
            elif offset > span:
                offset = offset - (int(abs(offset / span))) * span
            return offset, 0
        if offset < 0.0:
            return offset, below
        if value > hi:
            return offset, above
        return offset, 0

    def cell(offset, step):
        # Truncating division into (cell index, remainder inside the cell).
        scaled = offset / step
        index = int(scaled)
        return index, step * (scaled - index)

    offset_x, code = fold(x, xmin, xmax, mx != 0, 1, 2)
    if code:
        return 0, 0, 0.0, 0.0, code
    ix, dx = cell(offset_x, hx)

    offset_y, code = fold(y, ymin, ymax, my != 0, 3, 4)
    if code:
        return 0, 0, 0.0, 0.0, code
    iy, dy = cell(offset_y, hy)

    return ix, iy, dx, dy, 0
[docs]
def eva2d(spl: Array, ix: int, iy: int, dx: float, dy: float) -> float:
    """Evaluate 2D spline at given cell (port of eva2d.f90).

    Computes ``sum_{k,l} spl[k, l, ix, iy] * dx**k * dy**l`` with Horner's
    scheme. It first evaluates along x for each y-power ``l``, then along y.

    Args:
        spl: Coefficients indexed ``[k, l, ix, iy]`` (k: x power, l: y power).
        ix, iy: 0-based cell indices.
        dx, dy: Offsets from the cell's lower-left node.

    Returns:
        The spline value (a 0-d JAX array).
    """

    def horner(coeffs, t):
        # coeffs[0] + t*(coeffs[1] + t*(coeffs[2] + t*coeffs[3]))
        acc = coeffs[3]
        for power in (2, 1, 0):
            acc = coeffs[power] + t * acc
        return acc

    cell = spl[:, :, ix, iy]
    along_x = jnp.stack([horner(cell[:, q], dx) for q in range(4)], axis=0)
    return horner(along_x, dy)
[docs]
def eva2d_fd(spl: Array, ix: int, iy: int, dx: float, dy: float) -> Array:
    """Evaluate first derivatives of 2D spline (port of eva2d_fd.f90).

    Args:
        spl: Coefficients indexed ``[k, l, ix, iy]`` (k: x power, l: y power).
        ix, iy: 0-based cell indices.
        dx, dy: Offsets from the cell's lower-left node.

    Returns:
        Array ``[df/dx, df/dy]`` with dtype ``spl.dtype``.
    """
    xp = (1.0, dx, dx ** 2, dx ** 3)  # xp[p] == dx**p
    yp = (1.0, dy, dy ** 2, dy ** 3)  # yp[q] == dy**q
    grad = jnp.zeros((2,), dtype=spl.dtype)

    # d/dx of dx**p is p * dx**(p-1); the p = 0 terms vanish.
    for p in (1, 2, 3):
        wx = xp[p - 1] * p
        for q in range(4):
            grad = grad.at[0].add(spl[p, q, ix, iy] * wx * yp[q])

    # d/dy of dy**q is q * dy**(q-1); the q = 0 terms vanish.
    for p in range(4):
        for q in (1, 2, 3):
            wy = yp[q - 1] * q
            grad = grad.at[1].add(spl[p, q, ix, iy] * xp[p] * wy)
    return grad
[docs]
def eva2d_sd(spl: Array, ix: int, iy: int, dx: float, dy: float) -> Array:
    """Evaluate second derivatives of 2D spline (port of eva2d_sd.f90).

    Args:
        spl: Coefficients indexed ``[k, l, ix, iy]`` (k: x power, l: y power).
        ix, iy: 0-based cell indices.
        dx, dy: Offsets from the cell's lower-left node.

    Returns:
        Array ``[d2f/dx2, d2f/dxdy, d2f/dy2]`` with dtype ``spl.dtype``.
    """
    xp = (1.0, dx, dx ** 2, dx ** 3)  # xp[p] == dx**p
    yp = (1.0, dy, dy ** 2, dy ** 3)  # yp[q] == dy**q
    hess = jnp.zeros((3,), dtype=spl.dtype)

    # d2/dx2 of dx**p is p*(p-1)*dx**(p-2); p = 0, 1 vanish.
    for p in (2, 3):
        wx = xp[p - 2] * p * (p - 1)
        for q in range(4):
            hess = hess.at[0].add(spl[p, q, ix, iy] * wx * yp[q])

    # Mixed derivative: p*dx**(p-1) * q*dy**(q-1).
    for p in (1, 2, 3):
        wx = xp[p - 1] * p
        for q in (1, 2, 3):
            wy = yp[q - 1] * q
            hess = hess.at[1].add(spl[p, q, ix, iy] * wx * wy)

    # d2/dy2 of dy**q is q*(q-1)*dy**(q-2); q = 0, 1 vanish.
    for p in range(4):
        for q in (2, 3):
            wy = yp[q - 2] * q * (q - 1)
            hess = hess.at[2].add(spl[p, q, ix, iy] * xp[p] * wy)
    return hess
[docs]
def poi2d_jax(
    hx: float,
    hy: float,
    mx: int,
    my: int,
    xmin: float,
    xmax: float,
    ymin: float,
    ymax: float,
    x: Array,
    y: Array,
):
    """JAX-friendly pointer calculation for spline evaluation.

    Traceable counterpart of ``poi2d`` (port of ``poi2d.f90``), usable under
    ``jax.jit`` / ``jax.vmap``. Traced code cannot return early, so both axes
    are always processed. The results are then combined so that the outputs
    agree with ``poi2d``:

    * On a non-periodic axis (``mx == 0`` / ``my == 0``), ``ierr`` is 1/2 when
      ``x`` is below ``xmin``/above ``xmax``, and 3/4 likewise for ``y``. An
      x error takes precedence over a y error.
    * On a periodic axis the coordinate is folded into ``[min, max]`` with the
      same if/elif rule as ``poi2d``.
    * Whenever ``ierr != 0``, the returned ``ix, iy, dx, dy`` are zero.

    Args:
        hx, hy: Grid spacings in x and y.
        mx, my: 0 for a non-periodic axis, anything else for periodic. May be
            Python ints or traced scalars (``jax.lax.cond`` selects the branch).
        xmin, xmax, ymin, ymax: Grid extents.
        x, y: Evaluation point.

    Returns:
        ``(ix, iy, dx, dy, ierr)``: 0-based ``int32`` cell indices, the offsets
        from the cell's lower-left node, and an ``int32`` error code (0 = ok).
    """

    def _locate_axis(h, periodic_flag, lo, hi, v, code_below, code_above):
        # (cell index, offset inside the cell, error code) for one axis.
        offset = v - lo

        def _bounded(_):
            # Flag, but do not fold, coordinates outside [lo, hi].
            err = jnp.where(
                offset < 0.0,
                jnp.int32(code_below),
                jnp.where(v > hi, jnp.int32(code_above), jnp.int32(0)),
            )
            return offset, err

        def _periodic(_):
            span = hi - lo
            nwrap = jnp.floor(jnp.abs(offset / span))
            # Same if/elif as poi2d: whole periods are removed only from
            # offsets that started above ``span``. Applying that correction to
            # a value just folded up from below (which float rounding can
            # leave a hair above ``span``) would push it far negative.
            folded = jnp.where(
                offset < 0.0,
                offset + (1.0 + nwrap) * span,
                jnp.where(offset > span, offset - nwrap * span, offset),
            )
            return folded, jnp.int32(0)

        folded_offset, err = jax.lax.cond(
            periodic_flag == 0, _bounded, _periodic, operand=None
        )
        scaled = folded_offset / h
        idx = jnp.floor(scaled).astype(jnp.int32)
        return idx, h * (scaled - idx), err

    ix, dx, ierr_x = _locate_axis(hx, mx, xmin, xmax, x, 1, 2)
    iy, dy, ierr_y = _locate_axis(hy, my, ymin, ymax, y, 3, 4)

    # poi2d returns on an x error before inspecting y, so x errors win.
    ierr = jnp.where(ierr_x != 0, ierr_x, ierr_y)
    ok = ierr == 0
    # poi2d reports (0, 0, 0.0, 0.0) together with any non-zero error code.
    ix = jnp.where(ok, ix, 0)
    iy = jnp.where(ok, iy, 0)
    dx = jnp.where(ok, dx, 0.0)
    dy = jnp.where(ok, dy, 0.0)
    return ix, iy, dx, dy, ierr
[docs]
def eva2d_jax(spl: Array, ix: Array, iy: Array, dx: Array, dy: Array) -> Array:
    """JAX-friendly spline evaluation.

    Gathers the ``(4, 4)`` coefficient block of cell ``(ix, iy)`` with
    ``jnp.take``, so the indices may be traced. It then evaluates
    ``sum_{k,l} c[k, l] * dx**k * dy**l`` with Horner's scheme.

    ``ix`` and ``iy`` must be scalar indices. With an array ``ix``, the
    second ``jnp.take`` would index ``ix``'s axis instead of the y axis; use
    ``jax.vmap`` for batches.
    """
    block = jnp.take(jnp.take(spl, ix, axis=2), iy, axis=2)  # (4, 4)
    # Horner in dx for all four y-powers at once -> shape (4,).
    per_y = block[3]
    for p in (2, 1, 0):
        per_y = block[p] + dx * per_y
    # Horner in dy.
    value = per_y[3]
    for q in (2, 1, 0):
        value = per_y[q] + dy * value
    return value
[docs]
def eva2d_fd_jax(spl: Array, ix: Array, iy: Array, dx: Array, dy: Array) -> Array:
    """JAX-friendly first derivatives of spline.

    Gathers cell ``(ix, iy)`` with ``jnp.take`` (indices may be traced;
    ``ix``/``iy`` are expected to be scalars) and differentiates
    ``sum_{k,l} c[k, l] * dx**k * dy**l`` term by term.

    Returns:
        Array ``[df/dx, df/dy]`` with dtype ``spl.dtype``.
    """
    block = jnp.take(jnp.take(spl, ix, axis=2), iy, axis=2)  # (4, 4)
    xp = (1.0, dx, dx ** 2, dx ** 3)  # xp[p] == dx**p
    yp = (1.0, dy, dy ** 2, dy ** 3)  # yp[q] == dy**q
    # d/dx: p * dx**(p-1) * dy**q
    d_dx = sum(
        block[p, q] * (xp[p - 1] * p) * yp[q]
        for p in range(1, 4)
        for q in range(4)
    )
    # d/dy: dx**p * q * dy**(q-1)
    d_dy = sum(
        block[p, q] * xp[p] * (yp[q - 1] * q)
        for p in range(4)
        for q in range(1, 4)
    )
    return jnp.array([d_dx, d_dy], dtype=spl.dtype)
[docs]
def eva2d_sd_jax(spl: Array, ix: Array, iy: Array, dx: Array, dy: Array) -> Array:
    """JAX-friendly second derivatives of spline.

    Gathers cell ``(ix, iy)`` with ``jnp.take`` (indices may be traced;
    ``ix``/``iy`` are expected to be scalars) and differentiates
    ``sum_{k,l} c[k, l] * dx**k * dy**l`` twice, term by term.

    Returns:
        Array ``[d2f/dx2, d2f/dxdy, d2f/dy2]`` with dtype ``spl.dtype``.
    """
    block = jnp.take(jnp.take(spl, ix, axis=2), iy, axis=2)  # (4, 4)
    xp = (1.0, dx, dx ** 2, dx ** 3)  # xp[p] == dx**p
    yp = (1.0, dy, dy ** 2, dy ** 3)  # yp[q] == dy**q
    f_xx = sum(
        block[p, q] * (xp[p - 2] * p * (p - 1)) * yp[q]
        for p in range(2, 4)
        for q in range(4)
    )
    f_xy = sum(
        block[p, q] * (xp[p - 1] * p) * (yp[q - 1] * q)
        for p in range(1, 4)
        for q in range(1, 4)
    )
    f_yy = sum(
        block[p, q] * xp[p] * (yp[q - 2] * q * (q - 1))
        for p in range(4)
        for q in range(2, 4)
    )
    return jnp.array([f_xx, f_xy, f_yy], dtype=spl.dtype)