Source code for gwBOB.mismatch_utils

#pyright: reportUnreachable=false
# JAX implementation of a time-domain mismatch search.

#Notes:
#This code is only meant for merger-ringdown searches. It performs a time domain integration since we want to calculate the mismatch over a fixed time window.

# EOB/surrogate data are aligned with the NR data at peak before being passed in, so the optimal time shift should be close to 0.

# Instead of a grid-based phase search, the overlap is maximized over a constant phase offset analytically, by taking the absolute value of the complex overlap integral.
#https://journals.aps.org/prd/pdf/10.1103/PhysRevD.85.122006 section 4, text after eq 4.1


from functools import partial
from jax import jit,vmap
import jax.numpy as jnp
from jax import debug
#jax.config.update("jax_log_compiles", True)

@jit
def time_shift(h_complex, t, t_shift):
    """
    Shift a complex time series in time (JAX compatible).

    The series is re-sampled on its own time grid at ``t - t_shift``, so the
    result approximates ``h(t - t_shift)`` by linear interpolation. Real and
    imaginary parts are interpolated separately. Query times that fall outside
    ``[t[0], t[-1]]`` take the first/last sample value (``jnp.interp`` default
    behaviour), not zero.

    Parameters
    ----------
    h_complex : complex array
        The complex-valued time series to be shifted.
    t : array
        The time array on which ``h_complex`` is sampled (``jnp.interp``
        expects it to be increasing).
    t_shift : float
        The amount to shift the time series by.

    Returns
    -------
    h_shifted : complex array
        The shifted time series, sampled on ``t``.
    """
    query_times = t - t_shift

    def resample(component):
        # Linear interpolation of one real-valued component onto query_times.
        return jnp.interp(query_times, t, component)

    return resample(h_complex.real) + 1j * resample(h_complex.imag)
@partial(jit, static_argnames=('integration_points',))
def mismatch_trapz(
    h1_padded, t1_padded,  # Original, unshifted Model data
    h2_padded, t2_padded,  # NR data
    t_peak_nr,
    t0_relative, tf_relative,
    integration_points,
):
    """
    JAX compatible mismatch function.

    Calculates the phase-maximized mismatch between two complex time series
    over the window ``[t_peak_nr + t0_relative, t_peak_nr + tf_relative]``
    using trapezoidal integration on a uniform grid of ``integration_points``
    samples. Both series are linearly interpolated onto that grid; samples
    outside a series' own time array are treated as zero.

    The overlap is maximized over a constant phase offset analytically: the
    absolute value of the complex inner product ``int conj(h1) h2 dt`` is used
    instead of its real part.

    Parameters
    ----------
    h1_padded : complex array
        The complex-valued time series of the model data.
    t1_padded : array
        The time array of the model data (increasing).
    h2_padded : complex array
        The complex-valued time series of the NR data.
    t2_padded : array
        The time array of the NR data (increasing).
    t_peak_nr : float
        The peak time of the NR data.
    t0_relative : float
        The start of the integration window, relative to ``t_peak_nr``.
    tf_relative : float
        The end of the integration window, relative to ``t_peak_nr``.
    integration_points : int
        The number of integration points (static under ``jit``).

    Returns
    -------
    mismatch : float
        ``1 - |<h1, h2>| / (||h1|| ||h2||)``. Equals 1 when either series is
        identically zero over the window (a small epsilon guards the division).
    """

    def interp_complex_zero_fill(t_query, t_src, h_src):
        # Real and imaginary parts are interpolated separately; left/right=0
        # makes regions outside t_src contribute nothing to the integrals.
        real_part = jnp.interp(t_query, t_src, h_src.real, left=0.0, right=0.0)
        imag_part = jnp.interp(t_query, t_src, h_src.imag, left=0.0, right=0.0)
        return real_part + 1j * imag_part

    t_start_abs = t_peak_nr + t0_relative
    t_end_abs = t_peak_nr + tf_relative
    t_integ = jnp.linspace(t_start_abs, t_end_abs, integration_points)

    h1_integ = interp_complex_zero_fill(t_integ, t1_padded, h1_padded)
    h2_integ = interp_complex_zero_fill(t_integ, t2_padded, h2_padded)

    # Complex inner product <h1, h2> = int conj(h1) * h2 dt.
    inner_12 = jnp.trapezoid(jnp.conj(h1_integ) * h2_integ, t_integ)

    # Squared norms ||h||^2 = int |h|^2 dt.
    norm1_sq = jnp.trapezoid(h1_integ.real ** 2 + h1_integ.imag ** 2, t_integ)
    norm2_sq = jnp.trapezoid(h2_integ.real ** 2 + h2_integ.imag ** 2, t_integ)

    epsilon = 1e-20  # avoids 0/0 when a waveform vanishes over the window
    # |<h1, h2>| is the overlap maximized over a constant phase offset.
    maximized_overlap = jnp.abs(inner_12) / (
        jnp.sqrt(norm1_sq) * jnp.sqrt(norm2_sq) + epsilon
    )
    return 1.0 - maximized_overlap
#uncomment if we want to use cubic spline integration # @partial(jit, static_argnames=('integration_points')) # def mismatch_interpax( # h1_padded, t1_padded, #model data # h2_padded, t2_padded, #nr data # t_peak_nr, # t0_relative, tf_relative,integration_points): # t_start_abs = t_peak_nr + t0_relative # t_end_abs = t_peak_nr + tf_relative # t_integ = jnp.linspace(t_start_abs,t_end_abs,integration_points) # # Resample h2 onto t1's grid. # # left=0.0, right=0.0 ensures padded regions outside t2's domain become zero. # h1_common = jnp.interp(t_integ,t1_padded,h1_padded, left=0.0, right=0.0) # h2_common = jnp.interp(t_integ, t2_padded, h2_padded, left=0.0, right=0.0) # h1_integ = jnp.interp(t_integ, t1_padded, h1_padded.real, left=0.0, right=0.0) + \ # 1j * jnp.interp(t_integ, t1_padded, h1_padded.imag, left=0.0, right=0.0) # h2_integ = jnp.interp(t_integ, t2_padded, h2_padded.real, left=0.0, right=0.0) + \ # 1j * jnp.interp(t_integ, t2_padded, h2_padded.imag, left=0.0, right=0.0) # numerator_integrand = jnp.conj(h1_integ) * h2_integ # denom1_integrand = jnp.real(jnp.conj(h1_integ) * h1_integ) # denom2_integrand = jnp.real(jnp.conj(h2_integ) * h2_integ) # numerator_integral = interpax.CubicSpline( # x = t_integ, # y = numerator_integrand, # check=False # ).integrate(t_start_abs, t_end_abs) # denom1_sq = interpax.CubicSpline( # x = t_integ, # y = denom1_integrand, # check=False # ).integrate(t_start_abs, t_end_abs) # denom2_sq = interpax.CubicSpline( # x = t_integ, # y = denom2_integrand, # check=False # ).integrate(t_start_abs, t_end_abs) # denominator1 = jnp.sqrt(jnp.real(denom1_sq)) # denominator2 = jnp.sqrt(jnp.real(denom2_sq)) # epsilon = 1e-20 # #we take the absolute value of numerator_integral because that corresponds to the maximum overlap/ideal phase shift # maximized_overlap = jnp.abs(numerator_integral) / (denominator1 * denominator2 + epsilon) # #best_phi0 = -jnp.angle(numerator_integral) # mismatch = 1.0 - maximized_overlap # return mismatch
@partial(jit, static_argnames=('t0', 'tf', 'coarse_window', 'coarse_t_num',
                               'fine_window', 'fine_t_num', 'integration_points',
                               'return_shift'))
def find_best_mismatch_padded(
    padded_t_model, padded_h_model,
    padded_t_nr, padded_h_nr,
    nr_peak_time_batch,
    t0, tf,
    coarse_window, coarse_t_num,
    fine_window, fine_t_num,
    integration_points,
    return_shift=False,
):
    '''
    JAX compatible mismatch search function.

    For each waveform pair in the batch, searches over time shifts of the model
    data and returns the smallest mismatch found (computed with
    ``mismatch_trapz``). The search has two stages:

    1. a coarse grid of ``coarse_t_num`` trial shifts spanning
       ``[-coarse_window, coarse_window]``;
    2. a fine grid of ``fine_t_num`` trial shifts spanning
       ``[s1 - fine_window, s1 + fine_window]``, where ``s1`` is the best
       coarse shift.

    The fine result is used only if it is strictly better than the coarse one.
    A shift ``s`` evaluates the model at ``t - s`` (see ``time_shift``).

    All arguments listed in ``static_argnames`` are static under ``jit``, so
    each distinct value triggers a recompilation.

    Parameters
    ----------
    padded_t_model : array
        Time arrays of the model data, one row per waveform in the batch.
    padded_h_model : complex array
        Complex-valued model time series, one row per waveform.
    padded_t_nr : array
        Time arrays of the NR data, one row per waveform.
    padded_h_nr : complex array
        Complex-valued NR time series, one row per waveform.
    nr_peak_time_batch : array
        The NR peak time for each waveform in the batch.
    t0 : float
        Start of the integration window, relative to the NR peak time.
    tf : float
        End of the integration window, relative to the NR peak time.
    coarse_window : float
        Half-width of the coarse time-shift search range.
    coarse_t_num : int
        Number of trial time shifts in the coarse search grid.
    fine_window : float
        Half-width of the fine time-shift search range around the best
        coarse shift.
    fine_t_num : int
        Number of trial time shifts in the fine search grid.
    integration_points : int
        Number of integration points used by ``mismatch_trapz``.
    return_shift : bool, optional
        If True, also return the time shift that produced each best mismatch.
        Defaults to False (original behaviour).

    Returns
    -------
    mismatch : array
        The best mismatch for each waveform in the batch.
    t_shift : array
        Only when ``return_shift`` is True: the time shift selected for each
        waveform.
    '''

    def find_best_for_one_waveform(t_m, h_m, t_n, h_n, nr_peak):
        @vmap
        def do_search(t_shift):
            # Mismatch of the shifted model against the (unshifted) NR data.
            h_m_shifted = time_shift(h_m, t_m, t_shift)
            return mismatch_trapz(
                h_m_shifted, t_m,
                h_n, t_n,
                nr_peak,
                t0, tf, integration_points,
            )

        # Stage 1: coarse grid centred on zero shift.
        t_range_1 = jnp.linspace(-coarse_window, coarse_window, coarse_t_num)
        mismatches_1 = do_search(t_range_1)
        min_idx_1 = jnp.argmin(mismatches_1)
        mismatch_1 = mismatches_1[min_idx_1]
        t_shift_1 = t_range_1[min_idx_1]

        # Stage 2: fine grid centred on the best coarse shift.
        t_range_2 = jnp.linspace(
            t_shift_1 - fine_window, t_shift_1 + fine_window, fine_t_num
        )
        mismatches_2 = do_search(t_range_2)
        min_idx_2 = jnp.argmin(mismatches_2)
        mismatch_2 = mismatches_2[min_idx_2]
        t_shift_2 = t_range_2[min_idx_2]

        # The fine grid need not contain t_shift_1 exactly (e.g. even
        # fine_t_num), so keep the coarse result unless the fine one beats it.
        is_fine_search_better = mismatch_2 < mismatch_1
        final_mismatch = jnp.where(is_fine_search_better, mismatch_2, mismatch_1)
        if return_shift:
            final_t_shift = jnp.where(is_fine_search_better, t_shift_2, t_shift_1)
            return final_mismatch, final_t_shift
        return final_mismatch

    # --- Apply the entire 2-stage search to the batch of waveforms ---
    return vmap(find_best_for_one_waveform)(
        padded_t_model, padded_h_model,
        padded_t_nr, padded_h_nr,
        nr_peak_time_batch,
    )