# Source code for neo_jax.plotting

"""Plotting helpers for NEO_JAX outputs."""

from __future__ import annotations

from typing import Iterable, Tuple

import numpy as np

from .results import NeoResults


def plot_epsilon_effective(
    results: NeoResults,
    *,
    ax=None,
    x: str = "s",
    label: str | None = None,
    **plot_kwargs,
) -> Tuple[object, object]:
    """Plot epsilon effective vs a radial coordinate.

    Parameters
    ----------
    results : NeoResults
        Results object. ``epsilon_effective`` is read for the y-axis and the
        attribute selected by ``x`` (``s``, ``sqrt_s`` or ``r_eff``) for the
        x-axis; both are converted with ``numpy.asarray``.
    ax : matplotlib Axes, optional
        Axes to draw on. When omitted, a new figure and axes are created with
        ``matplotlib.pyplot.subplots``.
    x : {"s", "sqrt_s", "r_eff"}
        Radial coordinate on the x-axis (case-insensitive). Default is ``s``.
        The aliases ``"psi"`` (plots ``s``), ``"sqrt(s)"`` and ``"reff"`` are
        also accepted.
    label : str, optional
        Line label. A legend is added only when ``label`` is non-empty.
    **plot_kwargs
        Extra keyword arguments forwarded to ``ax.plot`` (e.g. ``color``,
        ``linestyle``). ``marker`` defaults to ``"o"`` and may be overridden.

    Returns
    -------
    tuple
        ``(fig, ax)``: the figure owning the axes and the axes drawn on.

    Raises
    ------
    ValueError
        If ``x`` is not a supported coordinate name.

    Notes
    -----
    matplotlib is imported lazily, and only when ``ax`` is not supplied.
    """
    # Each accepted spelling of ``x`` -> (results attribute, x-axis label).
    axis_choices = {
        "s": ("s", "s"),
        "psi": ("s", "s"),
        "sqrt_s": ("sqrt_s", "sqrt(s)"),
        "sqrt(s)": ("sqrt_s", "sqrt(s)"),
        "r_eff": ("r_eff", "r_eff"),
        "reff": ("r_eff", "r_eff"),
    }
    try:
        attr, xlabel = axis_choices[x.lower()]
    except KeyError:
        raise ValueError(
            f"Unsupported x-axis '{x}'; expected one of 's', 'sqrt_s', 'r_eff'"
        ) from None

    # Resolve the data before touching matplotlib so that invalid input does
    # not leave an empty figure registered with pyplot.
    x_vals = np.asarray(getattr(results, attr))
    eps_eff = np.asarray(results.epsilon_effective)

    if ax is None:
        import matplotlib.pyplot as plt  # lazy: only needed to create a figure

        fig, ax = plt.subplots()
    else:
        fig = ax.figure

    # Caller-supplied keywords take precedence over the default marker.
    style = {"marker": "o", **plot_kwargs}
    ax.plot(x_vals, eps_eff, label=label, **style)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("epsilon_effective")
    if label:
        ax.legend()
    return fig, ax