Source code for swiftcl.interpolate

import jax
import jax.numpy as jnp
from fftlog.fftlog import FFTLog
from interpax import interp1d


class interp:
    """Integrals of spherical Bessel functions multiplied by a tabulated function.

    The integrand ``f(x)`` (one column per value of ``a``) is decomposed with
    FFTLog into a sum of power laws ``x**p``. The Bessel integral of each power
    law is assembled from an ``a``-dependent factor and the precomputed
    coefficients ``g_r``. To save work, the FFTLog decomposition is only done on
    a thinned subset of ``a`` (``a_thick``). The resulting coefficients are then
    interpolated back onto the full ``a`` grid.

    Three variants are provided:

    - ``int_jl``: ``∫ f(x) j_l(a x) dx``.
    - ``int_k``: the same assembly, with the FFTLog exponents shifted by ``-2``
      in the ``a``-dependent factor. This is intended as an extra inverse-square
      weighting; see the method docstring.
    - ``int_ddjl``: a second-derivative variant,
      ``∫ f(x) d²/dx²[j_l(a x)] dx``, computed with a bias-shifted FFTLog.

    Args (all passed as keyword arguments):
        l (array_like): Multipoles at which to evaluate the integrals.
        N (int): Maximum number of FFTLog points (``Nmax``).
        x (array_like): Points at which ``f`` is sampled.
        f (array_like): 2D array of shape ``(len(x), len(a))``.
        a (array_like): 1D multipliers of the Bessel argument, i.e. ``j_l(a x)``.
        b (float): FFTLog bias.
        g_r (array_like): Precomputed Gamma-function coefficients. They are
            contracted as ``g_r[l, p]``, so the shape is ``(len(l), n_pow)``,
            where ``n_pow`` is the number of FFTLog exponents.
        xmin (float): Lower bound of ``x`` for FFTLog.
        xmax (float): Upper bound of ``x`` for FFTLog.
        N_interp (int): Approximate number of ``a`` samples on which the FFTLog
            decomposition is computed.
        wind (float, optional): FFTLog window parameter. Defaults to 0.2.
        ddjl_bias_shift (float, optional): Offset added to ``b`` for the FFTLog
            used by ``int_ddjl``. Defaults to 0.91.

    Attributes:
        l, N, x, f, a, b, g_r, xmin, xmax, N_interp: The constructor arguments.
        wind (float): FFTLog window parameter.
        ddjl_bias_shift (float): Bias offset used by ``int_ddjl``.
        fft (FFTLog): FFTLog instance with bias ``b``, used by ``int_jl`` and
            ``int_k``.
        power_fft (ndarray): Exponents of ``fft``.
        step (int): Stride (>= 1) between thinned samples of ``a``.
        a_thick (ndarray): Thinned ``a`` grid, ``a[::step]``.
    """

    def __init__(self, **kwargs):
        self.l = kwargs["l"]
        self.N = kwargs["N"]
        self.x = kwargs["x"]
        self.f = kwargs["f"]
        self.a = kwargs["a"]
        self.b = kwargs["b"]
        self.g_r = kwargs["g_r"]
        self.xmin = kwargs["xmin"]
        self.xmax = kwargs["xmax"]
        self.N_interp = kwargs["N_interp"]
        # Optional settings; when omitted, the window is 0.2 and the ddjl
        # bias offset is 0.91.
        self.wind = kwargs.get("wind", 0.2)
        self.ddjl_bias_shift = kwargs.get("ddjl_bias_shift", 0.91)

        self.fft = self._make_fftlog(self.b)
        self.power_fft = self.fft.Pow
        self.step = self._thinning_step(len(self.a), self.N_interp)
        self.a_thick = self.a[:: self.step]

    def int_jl(self):
        """Compute ``∫ f(x) j_l(a x) dx`` with FFTLog and interpolation.

        Returns:
            ndarray: Real array of shape ``(len(l), len(a))``.
        """
        k_pow = self._power_grid(self.a, self.power_fft)
        return self._integrate(self.fft, k_pow)

    def int_k(self):
        """Compute the inverse-square-weighted variant of ``int_jl``.

        This is the same as ``int_jl``, except that the ``a``-dependent factor
        is evaluated at exponents ``p - 2``, i.e. ``a**(1-p) * 2**(p-3)``. The
        intended integrand carries an extra inverse-square weight (described
        elsewhere as ``a^-2`` / ``k^-2``).

        NOTE(review): ``g_r`` is the same matrix used by ``int_jl``. The result
        only corresponds to an ``x**-2`` weight if ``g_r`` was built for the
        shifted exponents. Confirm the intended weighting and how ``g_r`` is
        computed.

        Returns:
            ndarray: Real array of shape ``(len(l), len(a))``.
        """
        k_pow = self._power_grid(self.a, self.power_fft - 2)
        return self._integrate(self.fft, k_pow)

    def int_ddjl(self):
        """Compute ``∫ f(x) d²/dx²[j_l(a x)] dx`` with FFTLog and interpolation.

        The decomposition uses a separate FFTLog whose bias is
        ``b + ddjl_bias_shift``. ``self.fft`` and ``self.power_fft`` are left
        unchanged, so ``int_jl`` and ``int_k`` remain valid after this call.

        NOTE(review): ``g_r`` is contracted against the shifted-bias exponents.
        The caller presumably supplies a ``g_r`` built for that bias; confirm.

        Returns:
            ndarray: Real array of shape ``(len(l), len(a))``.
        """
        fft = self._make_fftlog(self.b + self.ddjl_bias_shift)
        powers = fft.Pow
        p = jnp.asarray(powers)[:, None]
        k_pow = self._power_grid(self.a, powers, two_offset=-3) * (-1 + p) * p
        return self._integrate(fft, k_pow)

    def _make_fftlog(self, bias):
        """Build an FFTLog on this instance's x-range and window with ``bias``."""
        return FFTLog(
            Nmax=self.N,
            xmin=self.xmin,
            xmax=self.xmax,
            bias=bias,
            window=self.wind,
        )

    @staticmethod
    def _thinning_step(n_a, n_interp):
        """Return the stride used to thin ``a``; never smaller than 1."""
        # int(n_a / n_interp) is 0 whenever n_interp > n_a. A zero slice step
        # raises, so fall back to keeping every point.
        return max(1, int(n_a / n_interp))

    @staticmethod
    def _power_grid(a, powers, two_offset=-1):
        """Return ``a**(-1-p) * 2**(two_offset+p)`` on a ``(len(powers), len(a))`` grid.

        The ``a**(-1-p)`` factor comes from substituting ``t = a x`` in
        ``∫ x**p j_l(a x) dx``. The remaining ``p``-dependent constants are
        presumably completed by ``g_r``; confirm against how ``g_r`` is built.
        """
        a_row = jnp.asarray(a)[None, :]
        p_col = jnp.asarray(powers)[:, None]
        return a_row ** (-1 - p_col) * 2 ** (two_offset + p_col)

    def _integrate(self, fft, k_pow):
        """Decompose ``f`` with ``fft``, interpolate onto ``a`` and contract.

        Args:
            fft (FFTLog): Decomposition to use. Its exponents must be the ones
                ``k_pow`` was built from.
            k_pow (ndarray): ``a``-dependent factor, shape ``(n_pow, len(a))``.

        Returns:
            ndarray: ``Re[sqrt(pi) * sum_p g_r[l, p] * w[k, p] * k_pow[p, k]]``,
            shape ``(len(l), len(a))``.
        """
        # Decompose only every `step`-th column of f; rows become the a axis.
        f_thick = jnp.swapaxes(self.f[:, :: self.step], 0, 1)  # (len(a_thick), len(x))
        w_thick = fft.Coef(xin=self.x, f=f_thick, extrap="padding")  # (len(a_thick), n_pow)
        # Interpolate the coefficients back onto the full a grid. Extrapolation
        # covers trailing a values skipped when len(a) is not a multiple of step.
        w = interp1d(self.a, self.a_thick, w_thick, extrap=True)  # (len(a), n_pow)
        total = jnp.einsum("lp, kp, pk -> lk", self.g_r, w, k_pow)
        return jnp.real(jnp.sqrt(jnp.pi) * total)