# Source code for gwBOB.BOB_terms_jax

from jax import vmap, jit, checkpoint,jvp, config
import jax.numpy as jnp
from functools import partial

# Turn on JAX's 64-bit mode for this process so arrays default to
# double precision instead of JAX's 32-bit default.
config.update("jax_enable_x64", True)


#NOTES: 
#1. THE OMEGA FUNCTIONS RETURN SMALL OMEGA, w=m*Omega.
#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!THIS IS DIFFERENT THAN IN BOB_terms.py!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

#2. Apart from BOB_news_phase_jax, we don't implement phase functions here since the derivatives for the asymptotic expansion only involve A(t) and w(t)


#JAX versions

# Imaginary unit alias.
# NOTE(review): the functions visible in this module use the literal 1j
# directly rather than this name — confirm whether anything imports it.
I = 1j

# Define the JAX-compatible BOB class
class JAXBOB:
    """Lightweight container holding the BOB parameters in JAX-friendly form.

    Attributes:
        t         : Time samples, converted to a JAX array
        Omega_0   : Initial Condition Frequency
        Omega_QNM : Real part of Quasinormal mode (QNM) frequency/(mode number)
        tau       : Damping time
        Ap        : Peak waveform amplitude
        tp        : Time of peak amplitude
        m         : Mode number, stored as its absolute value
    """

    def __init__(self, t, Omega_0, Omega_QNM, tau, Ap, tp, m):
        # Only the time grid is converted; the scalars are stored as given.
        self.t = jnp.array(t)
        self.Omega_0, self.Omega_QNM = Omega_0, Omega_QNM
        self.tau, self.Ap, self.tp = tau, Ap, tp
        # The sign of m is discarded.
        self.m = abs(m)
def convert_BOB_to_JAXBOB(BOB):
    """Build a JAXBOB from an existing BOB-like object.

    Args:
        BOB : Object exposing the attributes t, Omega_0, Omega_QNM, tau,
              Ap, tp and m

    Returns:
        A new JAXBOB initialised from those attributes.
    """
    init_args = (
        BOB.t,
        BOB.Omega_0,
        BOB.Omega_QNM,
        BOB.tau,
        BOB.Ap,
        BOB.tp,
        BOB.m,
    )
    return JAXBOB(*init_args)
def BOB_amplitude_jax(t, tau, Ap, tp):
    '''
    BOB amplitude evolution, A(t) = Ap / cosh((t - tp) / tau).
    Eq.5 in https://arxiv.org/abs/1810.00040

    Args:
        t   : Time
        tau : Damping time; inverse of the imaginary part of the QNM frequency
        Ap  : Peak Waveform Amplitude
        tp  : Time of peak amplitude

    Returns:
        A(t) : Waveform amplitude at time t
    '''
    return Ap / jnp.cosh((t - tp) / tau)
def BOB_news_freq_jax(t, Omega_0, Omega_QNM, tau, tp, m):
    '''
    Waveform frequency for the news when the BOB amplitude models the news
    (taking t0 = -inf).

    Args:
        t         : Time
        Omega_0   : Initial Condition Frequency
        Omega_QNM : Real part of Quasinormal mode (QNM) frequency/(mode number)
        tau       : Damping time; inverse of the imaginary part of the QNM frequency
        tp        : Time of peak amplitude
        m (int)   : Mode number

    Returns:
        omega - News waveform frequency (small omega, i.e. m * Omega)
    '''
    scaled_time = (t - tp) / tau
    qnm_sq = Omega_QNM**2
    init_sq = Omega_0**2
    sq_difference = qnm_sq - init_sq
    sq_total = qnm_sq + init_sq
    freq_sq = sq_difference * jnp.tanh(scaled_time) / 2. + sq_total / 2.
    # Floor the squared frequency at 1e-12 before taking the root.
    floored = jnp.maximum(freq_sq, 1e-12)
    return m * jnp.sqrt(floored)
def BOB_news_phase_jax(t, Omega_0, Omega_QNM, tau, tp, Phi_0, m=2):
    '''
    Waveform phase for the news when the BOB amplitude models the news
    (taking t0 = -inf).

    Args:
        t         : Time
        Omega_0   : Initial Condition Frequency
        Omega_QNM : Real part of Quasinormal mode (QNM) frequency/(mode number)
        tau       : Damping time; inverse of the imaginary part of the QNM frequency
        tp        : Time of peak amplitude
        Phi_0     : Initial Condition Phase (phi)/(mode number)
        m (int)   : Mode number (defaults to 2)

    Returns:
        phi - News waveform phase
        omega - News waveform frequency
    '''
    # BOB_news_freq_jax returns little omega; divide by m to recover Omega.
    omega = BOB_news_freq_jax(t, Omega_0, Omega_QNM, tau, tp, m)
    Omega = omega / m

    tiny = 1e-40

    def _log_abs_gap(reference):
        # log|Omega - reference|, with an exact zero gap replaced by a tiny
        # positive number so the logarithm stays finite.
        gap = jnp.abs(Omega - reference)
        return jnp.log(jnp.where(gap == 0, tiny, gap))

    log_ratio_qnm = jnp.log(Omega + Omega_QNM) - _log_abs_gap(Omega_QNM)
    log_ratio_init = jnp.log(Omega + Omega_0) - _log_abs_gap(Omega_0)

    half_tau = tau / 2.0
    bracket = Omega_QNM * log_ratio_qnm - Omega_0 * log_ratio_init
    phase = (half_tau * bracket + Phi_0) * m
    return phase, omega
def BOB_psi4_freq_jax(t, Omega_0, Omega_QNM, tau, tp, m):
    '''
    Waveform frequency for psi4 when assuming the BOB amplitude best models
    psi4 (taking t0 = -inf).

    Args:
        t         : Time
        Omega_0   : Initial Condition Frequency
        Omega_QNM : Real part of Quasinormal mode (QNM) frequency/(mode number)
        tau       : Damping time; inverse of the imaginary part of the QNM frequency
        tp        : Time of peak amplitude
        m (int)   : Mode number

    Returns:
        omega - Psi4 waveform frequency (small omega, i.e. m * Omega)
    '''
    init_fourth = Omega_0**4
    half_span = (Omega_QNM**4 - init_fourth) / 2.0
    freq_fourth = init_fourth + half_span * (jnp.tanh((t - tp) / tau) + 1.0)
    # Floor Omega^4 at 1e-12 before taking the fourth root.
    floored = jnp.maximum(freq_fourth, 1e-12)
    return m * jnp.sqrt(jnp.sqrt(floored))
def BOB_strain_freq(t, Omega_0, Omega_QNM, tau, tp, m):
    '''
    Waveform frequency for strain when assuming the BOB amplitude best models
    the strain (taking t0 = -inf).

    Args:
        t         : Time
        Omega_0   : Initial Condition Frequency
        Omega_QNM : Real part of Quasinormal mode (QNM) frequency/(mode number)
        tau       : Damping time; inverse of the imaginary part of the QNM frequency
        tp        : Time of peak amplitude
        m (int)   : Mode number

    Returns:
        omega - Strain waveform frequency (small omega, i.e. m * Omega)
    '''
    # The exponent runs from 1 (early times) down to 0 (late times).
    exponent = (jnp.tanh((t - tp) / tau) - 1) / (-2.)
    frequency_ratio = Omega_0 / Omega_QNM
    return m * Omega_QNM * (frequency_ratio ** exponent)
def BOB_psi4_freq_finite_t0(t, Omega_0, Omega_QNM, tau, t0, tp, m):
    '''
    Waveform frequency for psi4 when assuming the BOB amplitude best models
    psi4 (for finite t0).

    Args:
        t         : Time
        Omega_0   : Initial Condition Frequency
        Omega_QNM : Real part of Quasinormal mode (QNM) frequency/(mode number)
        tau       : Damping time; can also be described as 1/gamma
                    (gamma is the imaginary QNM frequency)
        t0        : Initial Condition time
        tp        : Time of peak amplitude
        m (int)   : Mode number

    Returns:
        omega - Psi4 waveform frequency (small omega, i.e. m * Omega)
    '''
    tanh_now = jnp.tanh((t - tp) / tau)
    tanh_start = jnp.tanh((t0 - tp) / tau)
    init_fourth = Omega_0**4
    slope = (Omega_QNM**4 - init_fourth) / (1 - tanh_start)
    freq_fourth = init_fourth + slope * (tanh_now - tanh_start)
    # Floor Omega^4 at 1e-12 before taking the fourth root.
    floored = jnp.maximum(freq_fourth, 1e-12)
    return m * (jnp.sqrt(jnp.sqrt(floored)))
def BOB_news_freq_finite_t0(t, Omega_0, Omega_QNM, tau, t0, tp, m):
    '''
    Waveform frequency for the news when assuming the BOB amplitude best
    models the news (for finite t0).

    Args:
        t         : Time
        Omega_0   : Initial Condition Frequency
        Omega_QNM : Real part of Quasinormal mode (QNM) frequency/(mode number)
        tau       : Damping time; inverse of the imaginary part of the QNM frequency
        t0        : Initial Condition time
        tp        : Time of peak amplitude
        m (int)   : Mode number

    Returns:
        omega - News waveform frequency (small omega, i.e. m * Omega)
    '''
    tanh_now = jnp.tanh((t - tp) / tau)
    tanh_start = jnp.tanh((t0 - tp) / tau)
    qnm_sq = Omega_QNM**2
    slope = (qnm_sq - Omega_0**2) / (1 - tanh_start)
    freq_sq = qnm_sq + slope * (tanh_now - 1)
    # Floor Omega^2 at 1e-12 before taking the root.
    floored = jnp.maximum(freq_sq, 1e-12)
    return m * jnp.sqrt(floored)
def BOB_strain_freq_finite_t0(t, Omega_0, Omega_QNM, tau, t0, tp, m):
    '''
    Waveform frequency for the strain when assuming the BOB amplitude best
    models the strain (for finite t0).

    Args:
        t         : Time
        Omega_0   : Initial Condition Frequency
        Omega_QNM : Real part of Quasinormal mode (QNM) frequency/(mode number)
        tau       : Damping time; inverse of the imaginary part of the QNM frequency
        t0        : Initial Condition time
        tp        : Time of peak amplitude
        m (int)   : Mode number

    Returns:
        omega - Strain waveform frequency (small omega, i.e. m * Omega)
    '''
    numerator = jnp.tanh((t - tp) / tau) - 1
    denominator = jnp.tanh((t0 - tp) / tau) - 1
    # The exponent equals 1 at t = t0 and tends to 0 at late times.
    exponent = numerator / denominator
    frequency_ratio = Omega_0 / Omega_QNM
    return m * Omega_QNM * (frequency_ratio ** exponent)
def complex_scalar_derivative(g):
    """
    Compute the derivative of a complex scalar function g(t).

    Args:
        g : Function of a single real argument t; its output may be complex.

    Returns:
        A function deriv_g(t) that evaluates g'(t) with a forward-mode
        Jacobian-vector product.
    """
    def deriv_g(t):
        # jax.jvp requires the tangent to have the same dtype and shape as
        # the primal. Building it with ones_like (instead of the Python
        # literal 1.0) keeps this working when t is not of the default float
        # dtype, e.g. a float32 input while 64-bit mode is enabled.
        # For a non-scalar t this yields the directional derivative along an
        # all-ones tangent, which is the elementwise derivative only when g
        # itself acts elementwise.
        unit_tangent = jnp.ones_like(t)
        _, g_prime = jvp(g, (t,), (unit_tangent,))
        return g_prime
    return deriv_g
@partial(jit, static_argnames=('omega_func', 'A_func', 'N'))
def get_series_terms_ad(t, Omega_0, Omega_QNM, tau, Ap, tp, omega_func, A_func, m, N):
    """
    Generates the raw, unsigned series terms [f₀, Df₀, D²f₀, ..., Dⁿf₀] using
    JAX's automatic differentiation for t0 = -inf scenarios, where
    f₀(t) = A(t) / (i ω(t)) and D f = (1 / (i ω(t))) * df/dt.

    Args:
        t          : Time
        Omega_0    : Initial Condition Frequency
        Omega_QNM  : Real part of Quasinormal mode (QNM) frequency/(mode number)
        tau        : Damping time; inverse of the imaginary part of the QNM frequency
        Ap         : Peak waveform amplitude
        tp         : Time of peak amplitude
        omega_func : frequency function, called as
                     omega_func(time, Omega_0, Omega_QNM, tau, tp, m) (static under jit)
        A_func     : amplitude function, called as A_func(time, tau, Ap, tp)
                     (static under jit)
        m          : Mode number
        N          : number of terms in the series (static under jit)

    Returns:
        A 2D array of shape (N+1, len(t)) containing the raw terms.
    """
    def _omega_at(time):
        return omega_func(time, Omega_0, Omega_QNM, tau, tp, m)

    # Base function f₀(t) = A(t) / (i * ω(t))
    def seed_term(time):
        amplitude = A_func(time, tau, Ap, tp)
        return amplitude / (1j * _omega_at(time))

    # Pre-factor of the operator D: g(t) = 1 / (i * ω(t))
    def prefactor(time):
        return 1.0 / (1j * _omega_at(time))

    def _apply_D(prev_term):
        # jax.checkpoint is applied to the previous term before it is
        # differentiated, to save memory for deep derivative chains.
        d_prev = complex_scalar_derivative(checkpoint(prev_term))
        return lambda time: prefactor(time) * d_prev(time)

    # Grow the chain [f₀, Df₀, D²f₀, ...] one application of D at a time.
    chain = [seed_term]
    for _ in range(N):
        chain.append(_apply_D(chain[-1]))

    # Evaluate every term over the whole time array.
    return jnp.stack([vmap(term)(t) for term in chain])
@partial(jit, static_argnames=('omega_func', 'A_func', 'N'))
def get_series_terms_ad_finite_t0(t, Omega_0, Omega_QNM, tau, Ap, tp, t0, omega_func, A_func, m, N):
    """
    Generates the raw, unsigned series terms [f₀, Df₀, D²f₀, ..., Dⁿf₀] using
    JAX's automatic differentiation for finite t0 scenarios, where
    f₀(t) = A(t) / (i ω(t)) and D f = (1 / (i ω(t))) * df/dt.

    Args:
        t          : Time
        Omega_0    : Initial Condition Frequency
        Omega_QNM  : Real part of Quasinormal mode (QNM) frequency/(mode number)
        tau        : Damping time; inverse of the imaginary part of the QNM frequency
        Ap         : Peak waveform amplitude
        tp         : Time of peak amplitude
        t0         : Initial Condition time
        omega_func : frequency function, called as
                     omega_func(time, Omega_0, Omega_QNM, tau, t0, tp, m) (static under jit)
        A_func     : amplitude function, called as A_func(time, tau, Ap, tp)
                     (static under jit)
        m          : Mode number
        N          : number of terms in the series (static under jit)

    Returns:
        A 2D array of shape (N+1, len(t)) containing the raw terms.
    """
    def _omega_at(time):
        return omega_func(time, Omega_0, Omega_QNM, tau, t0, tp, m)

    # Base function f₀(t) = A(t) / (i * ω(t))
    def seed_term(time):
        amplitude = A_func(time, tau, Ap, tp)
        return amplitude / (1j * _omega_at(time))

    # Pre-factor of the operator D: g(t) = 1 / (i * ω(t))
    def prefactor(time):
        return 1.0 / (1j * _omega_at(time))

    def _apply_D(prev_term):
        # jax.checkpoint is applied to the previous term before it is
        # differentiated, to save memory for deep derivative chains.
        d_prev = complex_scalar_derivative(checkpoint(prev_term))
        return lambda time: prefactor(time) * d_prev(time)

    # Grow the chain [f₀, Df₀, D²f₀, ...] one application of D at a time.
    chain = [seed_term]
    for _ in range(N):
        chain.append(_apply_D(chain[-1]))

    # Evaluate every term over the whole time array.
    return jnp.stack([vmap(term)(t) for term in chain])
@jit
def fast_truncated_sum(all_raw_terms):
    """
    Calculates the series sum, including the alternating signs.

    Args:
        all_raw_terms: array of UNSIGNED terms whose leading axis indexes the
                       series order, typically of shape (N+1, n_times). A 1D
                       array of N+1 terms (a single time) and arrays with
                       extra trailing axes are also accepted.

    Returns:
        Sum over the leading axis of all_raw_terms with alternating signs
        (+, -, +, ...); the result has the shape of all_raw_terms without
        its leading axis.
    """
    n_terms = all_raw_terms.shape[0]

    # Signs vector [1, -1, 1, -1, ...], one entry per series term.
    signs = jnp.power(-1.0, jnp.arange(n_terms))

    # Give the signs one singleton axis per trailing axis of the input so
    # they always broadcast along the term axis. A fixed (-1, 1) shape would
    # only be correct for 2D input.
    signs = signs.reshape((n_terms,) + (1,) * (all_raw_terms.ndim - 1))

    # Apply signs and sum along the terms axis (axis=0)
    return jnp.sum(all_raw_terms * signs, axis=0)
@partial(jit, static_argnames=('omega_func', 'A_func', 'N'))
def calculate_strain_from_news(t, Omega_0, Omega_QNM, tau, Ap, tp, omega_func, A_func, m, N):
    '''
    Calculate the strain from the news using the series approximation.

    The raw, unsigned series terms [f₀, Df₀, D²f₀, ..., Dⁿf₀] are produced by
    get_series_terms_ad (t0 = -inf scenarios) and then combined with
    alternating signs by fast_truncated_sum.

    Args:
        t          : Time
        Omega_0    : Initial Condition Frequency
        Omega_QNM  : Real part of Quasinormal mode (QNM) frequency/(mode number)
        tau        : Damping time; inverse of the imaginary part of the QNM frequency
        Ap         : Peak waveform amplitude
        tp         : Time of peak amplitude
        omega_func : frequency function (static under jit)
        A_func     : amplitude function (static under jit)
        m          : Mode number
        N          : number of terms in the series (static under jit)

    Returns:
        raw series terms (excluding alternating signs)
        series sum
    '''
    raw_terms = get_series_terms_ad(
        t, Omega_0, Omega_QNM, tau, Ap, tp, omega_func, A_func, m, N
    )
    series_total = fast_truncated_sum(raw_terms)
    return raw_terms, series_total
@partial(jit, static_argnames=('omega_func', 'A_func', 'N'))
def calculate_strain_from_news_finite_t0(t, Omega_0, Omega_QNM, tau, Ap, tp, t0, omega_func, A_func, m, N):
    '''
    Calculate the strain from the news using the series approximation.

    The raw, unsigned series terms [f₀, Df₀, D²f₀, ..., Dⁿf₀] are produced by
    get_series_terms_ad_finite_t0 (scenarios involving finite t0 values) and
    then combined with alternating signs by fast_truncated_sum.

    Args:
        t          : Time
        Omega_0    : Initial Condition Frequency
        Omega_QNM  : Real part of Quasinormal mode (QNM) frequency/(mode number)
        tau        : Damping time; inverse of the imaginary part of the QNM frequency
        Ap         : Peak waveform amplitude
        tp         : Time of peak amplitude
        t0         : Initial Condition time
        omega_func : frequency function (static under jit)
        A_func     : amplitude function (static under jit)
        m          : Mode number
        N          : number of terms in the series (static under jit)

    Returns:
        raw series terms (excluding alternating signs)
        series sum
    '''
    raw_terms = get_series_terms_ad_finite_t0(
        t, Omega_0, Omega_QNM, tau, Ap, tp, t0, omega_func, A_func, m, N
    )
    series_total = fast_truncated_sum(raw_terms)
    return raw_terms, series_total
def _build_symbolic_series(base_func, g_func, N_order): """ Internal helper to recursively build a list of symbolic derivative functions. Args: base_func: The starting function for the series (T₀). g_func: The function defining the pre-factor for the D operator. N_order: The truncation order for the series. Returns: A list of N+1 functions representing [T₀, T₁, ..., Tₙ]. """ term_funcs = [base_func] for _ in range(N_order): prev_func = term_funcs[-1] deriv_of_prev = complex_scalar_derivative(checkpoint(prev_func)) # The core recursive definition of the D operator next_func = lambda time, g=g_func, d=deriv_of_prev: g(time) * d(time) term_funcs.append(next_func) return term_funcs
@partial(jit, static_argnames=('omega_func', 'A_func', 'N'))
def calculate_strain_from_psi4(t, Omega_0, Omega_QNM, tau, Ap, tp, omega_func, A_func, m, N):
    '''
    Calculate the strain from psi4 using the series approximation, for
    scenarios with t0 = -inf.

    The series is applied twice: an inner alternating sum of
    [f₀, Df₀, ..., Dⁿf₀] built from A(t) / (i ω(t)) gives the news, and that
    sum then plays the role of the amplitude in an outer series of the same
    order, whose raw terms are evaluated over t with JAX's automatic
    differentiation.

    Args:
        t          : Time
        Omega_0    : Initial Condition Frequency
        Omega_QNM  : Real part of Quasinormal mode (QNM) frequency/(mode number)
        tau        : Damping time; inverse of the imaginary part of the QNM frequency
        Ap         : Peak waveform amplitude
        tp         : Time of peak amplitude
        omega_func : frequency function (static under jit)
        A_func     : amplitude function (static under jit)
        m          : Mode number
        N          : number of terms in both the inner and outer series
                     (static under jit)

    Returns:
        raw outer-series terms (excluding alternating signs)
        series sum
    '''
    def _i_omega(time):
        return 1j * omega_func(time, Omega_0, Omega_QNM, tau, tp, m)

    # Pre-factor of the D operator, shared by both series: 1 / (i ω(t)).
    def inverse_i_omega(time):
        return 1.0 / _i_omega(time)

    # --- Inner series: psi4 -> news ---
    def psi4_seed(time):
        amplitude = A_func(time, tau, Ap, tp)
        return amplitude / _i_omega(time)

    news_terms = _build_symbolic_series(psi4_seed, inverse_i_omega, N)

    def news_total(time):
        stacked = jnp.stack([term(time) for term in news_terms])
        return jnp.sum(stacked * jnp.power(-1.0, jnp.arange(N + 1)))

    # --- Outer series: news -> strain ---
    def strain_seed(time):
        # The full inner sum acts as the "amplitude" of the outer series.
        return news_total(time) / _i_omega(time)

    strain_terms = _build_symbolic_series(strain_seed, inverse_i_omega, N)

    # Evaluate the outer terms over the time array, then sum with signs.
    raw_strain_terms = jnp.stack([vmap(term)(t) for term in strain_terms])
    return raw_strain_terms, fast_truncated_sum(raw_strain_terms)
@partial(jit, static_argnames=('omega_func', 'A_func', 'N'))
def calculate_strain_from_psi4_finite_t0(t, Omega_0, Omega_QNM, tau, Ap, tp, t0, omega_func, A_func, m, N):
    '''
    Calculate the strain from psi4 using the series approximation, for
    scenarios with finite t0.

    The series is applied twice: an inner alternating sum of
    [f₀, Df₀, ..., Dⁿf₀] built from A(t) / (i ω(t)) gives the news, and that
    sum then plays the role of the amplitude in an outer series of the same
    order, whose raw terms are evaluated over t with JAX's automatic
    differentiation.

    Args:
        t          : Time
        Omega_0    : Initial Condition Frequency
        Omega_QNM  : Real part of Quasinormal mode (QNM) frequency/(mode number)
        tau        : Damping time; inverse of the imaginary part of the QNM frequency
        Ap         : Peak waveform amplitude
        tp         : Time of peak amplitude
        t0         : Initial Condition time
        omega_func : frequency function (static under jit)
        A_func     : amplitude function (static under jit)
        m          : Mode number
        N          : number of terms in both the inner and outer series
                     (static under jit)

    Returns:
        raw outer-series terms (excluding alternating signs)
        series sum
    '''
    def _i_omega(time):
        return 1j * omega_func(time, Omega_0, Omega_QNM, tau, t0, tp, m)

    # Pre-factor of the D operator, shared by both series: 1 / (i ω(t)).
    def inverse_i_omega(time):
        return 1.0 / _i_omega(time)

    # --- Inner series: psi4 -> news ---
    def psi4_seed(time):
        amplitude = A_func(time, tau, Ap, tp)
        return amplitude / _i_omega(time)

    news_terms = _build_symbolic_series(psi4_seed, inverse_i_omega, N)

    def news_total(time):
        stacked = jnp.stack([term(time) for term in news_terms])
        return jnp.sum(stacked * jnp.power(-1.0, jnp.arange(N + 1)))

    # --- Outer series: news -> strain ---
    def strain_seed(time):
        # The full inner sum acts as the "amplitude" of the outer series.
        return news_total(time) / _i_omega(time)

    strain_terms = _build_symbolic_series(strain_seed, inverse_i_omega, N)

    # Evaluate the outer terms over the time array, then sum with signs.
    raw_strain_terms = jnp.stack([vmap(term)(t) for term in strain_terms])
    return raw_strain_terms, fast_truncated_sum(raw_strain_terms)