Source code for HyRex.hydrogen

import numpy as np

import jax.numpy as jnp
from jax import config, lax, grad
import equinox as eqx

from diffrax import diffeqsolve, SaveAt, ODETerm, Tsit5, Kvaerno3, PIDController, ForwardMode, Event

from . import cosmology
from .cosmology import mH, c, hbar, kB
from . import recomb_functions
from .array_with_padding import array_with_padding
config.update("jax_enable_x64", True)

import os
file_dir = os.path.dirname(__file__)

[docs] class hydrogen_model(eqx.Module): """ Hydrogen recombination model implementation. Computes hydrogen ionization fraction evolution through multiple phases: Saha equilibrium, post-Saha expansion, HYREC-2 EMLA with two-photon processes, and late-time EMLA-only evolution. Attributes: ----------- integration_spacing : float Step size in log scale factor for integration swift : array Tabulated SWIFT correction function for recombination rates concrete_axis_size : array Pre-allocated array for post-Saha phase integration xe_4He : array Helium ionization fraction from previous calculation lna_4He : array Log scale factor grid from helium recombination last_4He_lna : float Final log scale factor from helium phase twog_redshift : float Redshift at which two-photon processes become negligible lna_end : float Final log scale factor for hydrogen evolution Methods: -------- get_hydrogen_history : Compute full hydrogen recombination history (units: dimensionless) Saha_equilibrium : Compute Saha equilibrium phase (units: dimensionless) post_Saha_expansion : Compute post-Saha expansion phase (units: dimensionless) solve_emla_twophoton : Solve EMLA with two-photon processes (units: dimensionless) solve_emla : Solve EMLA-only evolution (units: dimensionless) xe_derivative_twophoton : Compute ionization fraction derivative with two-photon (units: dimensionless) xe_tm_derivative : Compute coupled xe and Tm derivatives (units: dimensionless, eV) dxe_dlna_twophoton : Compute two-photon recombination rate (units: dimensionless) get_current_correction_func : Interpolate SWIFT correction function (units: dimensionless) steady_state_equations : Set up steady-state level population equations (units: dimensionless) Lyn_esc_rate : Compute Lyman escape rate (units: s^{-1}) """ integration_spacing : jnp.float64 swift : jnp.array concrete_axis_size : jnp.array xe_4He : jnp.array lna_4He : jnp.array last_4He_lna : jnp.float64 twog_redshift : jnp.float64 lna_end : jnp.float64 def __init__(self,xe_4He,lna_4He,lna_end,last_4He_lna,twog_redshift,integration_spacing = 5.0e-4, Nsteps=800,swift = jnp.array(np.loadtxt(file_dir+"/tabs/fit_swift.dat"))): """ Initialize hydrogen recombination model. Parameters: ----------- xe_4He : array_with_padding Helium ionization fraction from previous calculation lna_4He : array_with_padding Log scale factor array from helium calculation lna_end : array Ending log scale factor last_4He_lna : float Final log scale factor from helium recombination integration_spacing : float, optional Step size for integration (default: 5.0e-4) Nsteps : int, optional Maximum number of integration steps (default: 800) swift : array, optional SWIFT correction function tabulation """ self.integration_spacing = integration_spacing self.swift = swift # Define time axes self.lna_end = lna_end self.concrete_axis_size = jnp.zeros(Nsteps) # pull in helium self.xe_4He = xe_4He self.lna_4He = lna_4He self.last_4He_lna = last_4He_lna self.twog_redshift = twog_redshift def __call__(self, h, omega_b, omega_cdm, Neff, YHe, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=1024): """ Compute hydrogen recombination history. Parameters: ----------- h : float Hubble parameter omega_b : float The baryon density Omega_b h^2 omega_cdm : float The density of Cold Dark Matter Omega_cdm h^2 Neff : float Effective number of neutrinos YHe : float Helium fraction rtol : float, optional Relative tolerance for ODE solver (default: 1e-6) atol : float, optional Absolute tolerance for ODE solver (default: 1e-9) solver : diffrax.Solver, optional ODE solver instance (default: Kvaerno3()) max_steps : int, optional Maximum solver steps (default: 1024) Returns: -------- tuple (xe_full, lna_full, Tm, lna_Tm) - ionization fraction, log scale factor, matter temperature, and temperature grid """ return self.get_hydrogen_history(h, omega_b, omega_cdm, Neff, YHe, rtol, atol, solver, max_steps)
[docs] def get_hydrogen_history(self, h, omega_b, omega_cdm, Neff, YHe, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=1024): """ Compute complete hydrogen recombination history through all phases. Sequentially computes hydrogen ionization fraction through post-Saha expansion, HYREC-2 EMLA with two-photon processes, and late-time EMLA-only evolution phases. Parameters: ----------- h : float Hubble parameter omega_b : float The baryon density Omega_b h^2 omega_cdm : float The density of Cold Dark Matter Omega_cdm h^2 Neff : float Effective number of neutrinos YHe : float Helium fraction rtol : float, optional Relative tolerance for ODE solver (default: 1e-6) atol : float, optional Absolute tolerance for ODE solver (default: 1e-9) solver : diffrax.Solver, optional ODE solver instance (default: Kvaerno3()) max_steps : int, optional Maximum solver steps (default: 1024) Returns: -------- tuple (xe_full, lna_full, Tm, lna_Tm) containing ionization fraction evolution, log scale factor grid, matter temperature, and temperature grid """ # Start computing xe at different phases ################## move to H ################## ### POST SAHA EXPANSION PHASE ### xe_output_post, lna_output_post = self.post_Saha_expansion(self.last_4He_lna+self.integration_spacing, h, omega_b, omega_cdm, Neff, YHe) xe_4He_and_post = self.xe_4He.concat(array_with_padding(xe_output_post)) lna_4He_and_post = self.lna_4He.concat(array_with_padding(lna_output_post)) ### END OF POST SAHA EXPANSION PHASE ### ### HYREC2 EMLA + FULL TWO PHOTON PHASE ### xe_output_2g, lna_output_2g = self.solve_emla_twophoton(lna_4He_and_post.lastval, -jnp.log(self.twog_redshift), xe_4He_and_post.lastval, h, omega_b, omega_cdm, Neff, YHe, rtol, atol, solver, max_steps) xe_4He_post_2g = xe_4He_and_post.concat(array_with_padding(xe_output_2g)) lna_4He_post_2g = lna_4He_and_post.concat(array_with_padding(lna_output_2g)) ### END HYREC2 EMLA + FULL TWO PHOTON PHASE ### ### HYREC2 EMLA ONLY PHASE ### # do not input max_steps xe_output_late, Tm_output_late, lna_output_late = self.solve_emla(lna_4He_post_2g.lastval, xe_4He_post_2g.lastval, h, omega_b, omega_cdm, Neff, YHe, rtol, atol, solver) lna_Tm = array_with_padding(lna_output_late) Tm = array_with_padding(Tm_output_late) xe_4He_post_2g_late = xe_4He_post_2g.concat(array_with_padding(xe_output_late)) lna_4He_post_2g_late = lna_4He_post_2g.concat(lna_Tm) ### END OF HYREC2 EMLA ONLY PHASE ### ### Begin TLA phase ### xe_output_TLA, Tm_output_TLA, lna_output_TLA = self.solve_TLA(lna_4He_post_2g_late.lastval, xe_4He_post_2g_late.lastval, Tm.lastval, h, omega_b, omega_cdm, Neff, YHe) xe_all = xe_4He_post_2g_late.concat(array_with_padding(xe_output_TLA)) lna_all = lna_4He_post_2g_late.concat(array_with_padding(lna_output_TLA)) Tm_all = Tm.concat(array_with_padding(Tm_output_TLA)) lna_Tm_all = lna_Tm.concat(array_with_padding(lna_output_TLA)) ### End TLA ### return (xe_all, lna_all, Tm_all, lna_Tm_all)
[docs] def post_Saha_expansion(self, starting_lna, h, omega_b, omega_cdm, Neff, YHe, threshold=1e-5): """ Compute post-Saha expansion phase with two-photon corrections. Calculates ionization fraction including two-photon processes as perturbative corrections to Saha equilibrium until deviations exceed threshold. Parameters: ----------- starting_lna : float Initial log scale factor h : float Hubble parameter omega_b : float The baryon density Omega_b h^2 omega_cdm : float The density of Cold Dark Matter Omega_cdm h^2 Neff : float Effective number of neutrinos YHe : float Helium fraction threshold : float, optional Threshold for deviation from Saha (default: 1e-5) Returns: -------- tuple (xe_output, lna_output) - ionization fraction and log scale factor arrays """ # Calculate omega_rad today using input Neff. omega_rad = cosmology.omega_rad0(Neff) # Initial conditions z0_local = jnp.exp(-starting_lna) - 1. TCMB = cosmology.TCMB(z0_local) nH = cosmology.nH(z0_local, omega_b, YHe) xe0, _ = recomb_functions.xe_Saha(TCMB, nH) # Assume initially in Saha equilibrium # Pre-allocate xe_output xe_output = jnp.ones_like(self.concrete_axis_size)*jnp.inf lna_output = jnp.ones_like(self.concrete_axis_size)*jnp.inf iz = 0 xe = xe0 stop = False def compute_xe(carry): xe_output, lna_output, xe, iz, stop = carry lna = starting_lna + iz*self.integration_spacing z = jnp.exp(-lna) - 1. # Cosmological parameters TCMB = cosmology.TCMB(z) nH = cosmology.nH(z, omega_b, YHe) H = cosmology.Hubble(z, h, omega_b, omega_cdm, omega_rad) # Saha equilibrium for xe xe_Saha, s = recomb_functions.xe_Saha(TCMB, nH) dxe_Saha_dlna = -(recomb_functions.rydberg / TCMB - 3./2.) * xe_Saha**2 / (2. * xe_Saha + s) # Compute xe using two-photon processes grad_dxedlna_func = grad(self.dxe_dlna_twophoton, argnums=0) grad_dxedlna = grad_dxedlna_func(xe_Saha, TCMB, TCMB, H, nH, 0.0) xe = xe_Saha + dxe_Saha_dlna / grad_dxedlna # Store current xe value in the output array xe_output = xe_output.at[iz].set(xe) lna_output = lna_output.at[iz].set(lna) # Check difference diff = jnp.abs(xe_Saha - xe) stop = diff > threshold # Stop when diff < threshold # Increment index iz = iz + 1 return (xe_output, lna_output, xe, iz, stop) def stop_condition(state): _, _, _, iz, stop = state return (iz < self.concrete_axis_size.size) & (~stop) # Continue until stop condition is met or we run out of space # Initial state: (xe_output, xe, iz, stop flag) initial_state = (xe_output, lna_output, xe, iz, stop) # Run the while loop until the stop condition is met final_state = lax.while_loop(stop_condition, compute_xe, initial_state) # Unpack the final state xe_output_final, lna_output_final, _, _, _ = final_state # Return the electron fraction array and the stopping `lna` value return xe_output_final, lna_output_final
[docs] def xe_derivative_twophoton(self, lna, xe, args): """ Compute ionization fraction derivative including two-photon processes. Derivative function for hydrogen ionization fraction evolution including two-photon transitions and correction functions. Parameters: ----------- lna : float Log scale factor xe : float Current ionization fraction args : tuple h, omega_b, omega_cdm, Neff, YHe; the Hubble parameter, the baryon denisty Omega_b h^2, the CDM density Omega_cdm h^2, the effecgive number of neutrinos, and the helium fraction Returns: -------- float Time derivative dxe/dlna (units: dimensionless) """ h, omega_b, omega_cdm, Neff, YHe = args omega_rad = cosmology.omega_rad0(Neff) z = 1. / jnp.exp(lna) - 1. x1s = 1. - xe # fraction of neutral hydrogen TCMB = cosmology.TCMB(z) # eV nH = cosmology.nH(z, omega_b, YHe) # hydrogen number density, 1/cm^3 H = cosmology.Hubble(z, h, omega_b, omega_cdm, omega_rad) # Hubble parameter, 1/s GammaC = recomb_functions.Gamma_compton(xe, TCMB, YHe) # Compton scattering rate, 1/s Tm = TCMB * (1.-H/GammaC) Delta = self.get_current_correction_func(TCMB, omega_b, omega_cdm, YHe, Neff) dxedlna = self.dxe_dlna_twophoton(xe, TCMB, Tm, H, nH, Delta) return dxedlna
[docs] def solve_emla_twophoton(self, lna_axis_init, lna_axis_final, xe0, h, omega_b, omega_cdm, Neff, YHe, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=1024): """ Solve HYREC-2 EMLA evolution with two-photon processes. Integrates hydrogen recombination including effective multilevel atom approximation with two-photon transitions and correction functions. Parameters: ----------- lna_axis_init : float Initial log scale factor lna_axis_final : float Final log scale factor xe0 : float Initial ionization fraction h : float Hubble parameter omega_b : float The baryon density Omega_b h^2 omega_cdm : float The density of Cold Dark Matter Omega_cdm h^2 Neff : float Effective number of neutrinos YHe : float Helium fraction rtol : float, optional Relative tolerance (default: 1e-6) atol : float, optional Absolute tolerance (default: 1e-9) solver : diffrax.Solver, optional ODE solver (default: Kvaerno3()) max_steps : int, optional Maximum steps (default: 1024) Returns: -------- tuple (xe_output, lna_output) - ionization fraction and log scale factor arrays """ # Initial conditions TCMB_init = cosmology.TCMB(jnp.exp(-lna_axis_init ) - 1.) # Initial CMB temperature initial_state = xe0 term = ODETerm(self.xe_derivative_twophoton) t0 = lna_axis_init # t1 = lna_axis_final t1 = jnp.inf # don't want to double count the boundary lna, so start saving after one step t_arr = jnp.linspace(t0+self.integration_spacing, t0+2*max_steps*self.integration_spacing, 2*max_steps) save_at = SaveAt(ts=t_arr) adjoint=ForwardMode() def lna_check(t, y, args, **kwargs): return t > lna_axis_final sol = diffeqsolve( term, solver, t0=t0, t1=t1, dt0=1e-3, y0=initial_state, args=(h, omega_b, omega_cdm, Neff, YHe), stepsize_controller=PIDController(rtol, atol), saveat=save_at, event=Event(lna_check), adjoint=adjoint ) xe_output = sol.ys lna_output = sol.ts return xe_output, lna_output
[docs] def xe_tm_derivative(self, lna, state, args): """ Compute coupled derivatives for ionization fraction and matter temperature. Derivative function for simultaneous evolution of hydrogen ionization fraction and matter temperature including Compton heating/cooling. Parameters: ----------- lna : float Log scale factor state : array Current state [xe, Tm] args : tuple h, omega_b, omega_cdm, Neff, YHe; the Hubble parameter, the baryon denisty Omega_b h^2, the CDM density Omega_cdm h^2, the effecgive number of neutrinos, and the helium fraction Returns: -------- array Time derivatives [dxe/dlna, dTm/dlna] (units: dimensionless, eV) """ xe, Tm = state h, omega_b, omega_cdm, Neff, YHe = args omega_rad = cosmology.omega_rad0(Neff) z = 1. / jnp.exp(lna) - 1. # redshift z TCMB = cosmology.TCMB(z) # eV nH = cosmology.nH(z, omega_b, YHe) # hydrogen number density, 1/cm^3 H = cosmology.Hubble(z, h, omega_b, omega_cdm, omega_rad) # Hubble parameter, 1/s GammaC = recomb_functions.Gamma_compton(xe, TCMB, YHe) # Compton scattering rate, 1/s Delta = 0.0 dxedlna = self.dxe_dlna_twophoton(xe, TCMB, Tm, H, nH, Delta) dTmdlna = (-2 * H * Tm + GammaC * (TCMB - Tm)) / H return jnp.array([dxedlna, dTmdlna])
[docs] def solve_emla(self, lna0, xe0, h, omega_b, omega_cdm, Neff, YHe,rtol=1e-7, atol=1e-9,solver=Tsit5(),max_steps=4096): """ Solve late-time EMLA evolution without two-photon processes. Integrates hydrogen recombination using effective multilevel atom approximation for late times when two-photon processes are negligible. Parameters: ----------- lna0 : float Log scale factor at which initial xe is given xe0 : float Initial ionization fraction h : float Hubble parameter omega_b : float The baryon density Omega_b h^2 omega_cdm : float The density of Cold Dark Matter Omega_cdm h^2 Neff : float Effective number of neutrinos YHe : float Helium fraction rtol : float, optional Relative tolerance (default: 1e-7) atol : float, optional Absolute tolerance (default: 1e-9) solver : diffrax.Solver, optional ODE solver (default: Tsit5()) max_steps : int, optional Maximum steps (default: 4096) Returns: -------- tuple (xe_output, Tm_output, lna_output) - ionization fraction, matter temperature, and lna arrays """ omega_rad = cosmology.omega_rad0(Neff) t0 = lna0 t1 = jnp.inf # need to go at least twice max_steps to make sure we catch the t1 we actually want t_arr = jnp.linspace(t0+self.integration_spacing, t0+2*max_steps*self.integration_spacing, 2*max_steps) save_at = SaveAt(ts=t_arr) TCMB_init = cosmology.TCMB(jnp.exp(-t0) - 1.) Tm0 = TCMB_init * (1.-cosmology.Hubble(1/jnp.exp(t0) - 1, h, omega_b, omega_cdm, omega_rad)/recomb_functions.Gamma_compton(xe0, TCMB_init, YHe)) initial_state = jnp.array([xe0, Tm0]) term = ODETerm(self.xe_tm_derivative) adjoint=ForwardMode() def temperature_check(t, y, args, **kwargs): lna = t _, Tm = y z = jnp.exp(-lna) - 1 TCMB = cosmology.TCMB(z) TR_MIN = recomb_functions.TR_MIN # Minimum Tcmb in eV T_RATIO_MIN = recomb_functions.T_RATIO_MIN # Minimum Tratio ratio = jnp.minimum(Tm / TCMB, TCMB / Tm) return jnp.logical_or(TCMB < TR_MIN, ratio < T_RATIO_MIN) # stop when true event = Event(temperature_check) sol = diffeqsolve( term, solver, t0=t0, t1=t1, dt0=1e-3, y0=initial_state, args=(h, omega_b, omega_cdm, Neff, YHe), stepsize_controller=PIDController(rtol, atol),saveat=save_at, adjoint=adjoint, max_steps=max_steps, event = event ) xe_output = sol.ys[:, 0] Tm_output = sol.ys[:, 1] return xe_output, Tm_output, sol.ts
[docs] def dxe_dlna_twophoton(self, xe, TCMB, Tm, H, nH, Delta): """ Compute two-photon recombination rate. Calculates ionization fraction evolution rate including two-photon transitions using effective multilevel atom approximation. Parameters: ----------- xe : float Current ionization fraction TCMB : float CMB temperature (units: eV) Tm : float Matter temperature (units: eV) H : float Hubble parameter (units: s^{-1}) nH : float Hydrogen number density (units: cm^{-3}) Delta : float Correction function value Returns: -------- float Recombination rate dxe/dlna (units: dimensionless) """ x1s = 1.-xe # Interpolate transition rates A2s, A2p, B2s, B2p, _, _, R2p2s, R2s2p = recomb_functions.effective_coefficients(TCMB, Tm, H, nH, x1s) # Compute the matrix and source vector forms of steady state equations, then solve the linear system for real and virtual populations. T, S = self.steady_state_equations(xe, H, nH, TCMB, A2s, A2p, B2s, B2p, R2p2s, R2s2p, Delta) X = jnp.linalg.solve(T, S) x2s = X[0] x2p = X[1] return (x2s*B2s + x2p*B2p - xe**2*nH*(A2s+A2p)) / H
[docs] def get_current_correction_func(self, TCMB, omega_b, omega_cdm, YHe, Neff): """ Interpolate correction function for current cosmology. Interpolates correction function and applies cosmological parameter derivatives at current CMB temperature for accurate recombination rates. Parameters: ----------- TCMB : float CMB temperature (units: eV) h : float Hubble parameter omega_b : float The baryon density Omega_b h^2 omega_cdm : float The density of Cold Dark Matter Omega_cdm h^2 Neff : float Effective number of neutrinos YHe : float Helium fraction Returns: -------- float Correction function value (units: dimensionless) """ # Fiducial cosmology values at which the correction functions were tabulated. omega_H_fid = 0.01689 omega_cb_fid = 0.14175 Neff_fid = 3.046 # For the user inputed cosmology currently scanned over. omega_H = omega_b*(1-YHe) omega_cb = omega_b + omega_cdm Delta = jnp.interp(TCMB, kB*self.swift[:, 0], self.swift[:, 1]) dDelta_domcb = jnp.interp(TCMB, kB*self.swift[:, 0], self.swift[:, 2]) dDelta_domH = jnp.interp(TCMB, kB*self.swift[:, 0], self.swift[:, 3]) dDelta_dNeff = jnp.interp(TCMB, kB*self.swift[:, 0], self.swift[:, 4]) return Delta + (omega_cb-omega_cb_fid)*dDelta_domcb + (omega_H-omega_H_fid)*dDelta_domH + (Neff-Neff_fid)*dDelta_dNeff
[docs] def steady_state_equations(self, xe, H, nH, TCMB, A2s, A2p, B2s, B2p, R2p2s, R2s2p, Delta): """ Set up steady-state level population equations. Constructs matrix equation for hydrogen level populations in steady-state approximation for 2s and 2p levels. Parameters: ----------- xe : float Current ionization fraction H : float Hubble parameter (units: s^{-1}) nH : float Hydrogen number density (units: cm^{-3}) TCMB : float CMB temperature (units: eV) A2s : float 2s recombination coefficient (units: cm^3 s^{-1}) A2p : float 2p recombination coefficient (units: cm^3 s^{-1}) B2s : float 2s photoionization coefficient (units: s^{-1}) B2p : float 2p photoionization coefficient (units: s^{-1}) R2p2s : float 2p→2s transition rate (units: s^{-1}) R2s2p : float 2s→2p transition rate (units: s^{-1}) Delta : float Correction function value Returns: -------- tuple (T, S) - transition matrix and source vector for level populations """ T = jnp.zeros((2, 2), dtype="float64") S = jnp.zeros(2, dtype="float64") x1s = 1.-xe # Recombined hydrogen fraction. # List of transition rates needed RLya = self.Lyn_esc_rate(2, H, nH, x1s) # Lyman-alpha escape rate. R2s1s = 8.2206 # Two-photon transition rate from 2s to 1s. R2p1s = RLya / (1.+Delta) # Two-photon transition rate from 2p to 1s, with HYREC-2 fitting correction function. R1s2s = jnp.exp(-recomb_functions.E21/TCMB)*R2s1s R1s2p = 3.*jnp.exp(-recomb_functions.E21/TCMB)*R2p1s # Upper 2x2 part of T matrix. T = T.at[0, 0].set(B2s+R2s2p+R2s1s) T = T.at[0, 1].set(-R2p2s) T = T.at[1, 0].set(-R2s2p) T = T.at[1, 1].set(B2p+R2p2s+R2p1s) # First 2 entries of source vector elements S = S.at[0].set(xe**2*nH*A2s+x1s*R1s2s) S = S.at[1].set(xe**2*nH*A2p+x1s*R1s2p) return (T, S)
[docs] def Lyn_esc_rate(self, n, H, nH, x1s): """ Computes the Lyman-n escape rate, rate at which photons redshift past the Lyman-n line without being absorbed. We use the convention that n=2 is Ly-alpha, n=3 is Ly-beta... Parameters: ----------- n : float Requested Lyman transition level, should be greater than 2. H : float Hubble parameter in s^-1. nH : float Hydrogen number density in cm^-3. x1s : float Fraction of 1s bound hydrogen. Returns: -------- RLyn : float Rate of escape of Lyman-n level. """ lambda_lya = 2.*jnp.pi*hbar*c / recomb_functions.rydberg * 4./3. # Lyman-alpha Wavelength RLya = 8.*jnp.pi*H/3./nH/x1s/lambda_lya**3 # Rate of escape of Lyman-alpha RLyn = (4*(n**2-1)/3/n**2)**3 * RLya # (lambda_lya/lambda_lyn)^3 * RLya return RLyn
[docs] def TLA_xe_deriv(self, lna, state, args): """ Compute coupled derivatives for ionization fraction and matter temperature using Peebles three-level atom. Parameters: ----------- lna : float Log scale factor state : array Current state [xe, Tm] args : tuple h, omega_b, omega_cdm, Neff, YHe; the Hubble parameter, the baryon denisty Omega_b h^2, the CDM density Omega_cdm h^2, the effecgive number of neutrinos, and the helium fraction Returns: -------- array Time derivatives [dxe/dlna, dTm/dlna] (units: dimensionless, eV) """ xe, Tm = state h, omega_b, omega_cdm, Neff, YHe = args xHII = xe # since everything else is fully recombined z = jnp.exp(-lna) - 1 omega_rad = cosmology.omega_rad0(Neff) nH = cosmology.nH(z, omega_b, YHe) H = cosmology.Hubble(z, h, omega_b, omega_cdm, omega_rad) TCMB = cosmology.TCMB(z) C = recomb_functions.peebles_C(z, xHII, H, nH) alpha = recomb_functions.alpha_H(Tm) beta = recomb_functions.beta_H(Tm) # dxe/d(lna) = (1/H) * dxe/dt dxe_dt = C * (beta * (1.0 - xe) - alpha * nH * xe**2) dxe_dloga = dxe_dt / H dTm_dloga = -2.0 * Tm + (recomb_functions.Gamma_compton(xe, TCMB, YHe) / H) * (TCMB - Tm) return jnp.array([dxe_dloga, dTm_dloga])
[docs] def solve_TLA(self, lna0, xe0, Tm0, h, omega_b, omega_cdm, Neff, YHe, rtol=1e-7, atol=1e-9, solver=Kvaerno3(), max_steps = 4096): """ Solve late-time TLA evolution. Integrates hydrogen recombination using Peebles TLA in the region beyond where SWIFT corrections are tabulated. Parameters: ----------- lna0 : float Starting log scale factor xe0 : float Initial ionization fraction Tm0: float Starting matter temperature h : float Hubble parameter omega_b : float The baryon density Omega_b h^2 omega_cdm : float The density of Cold Dark Matter Omega_cdm h^2 Neff : float Effective number of neutrinos YHe : float Helium fraction rtol : float, optional Relative tolerance (default: 1e-7) atol : float, optional Absolute tolerance (default: 1e-9) solver : diffrax.Solver, optional ODE solver (default: Tsit5()) max_steps : int, optional Maximum steps (default: 4096) Returns: -------- tuple (xe_output, Tm_output, lna_output) - ionization fraction, matter temperature, and log scale factor arrays """ t0 = lna0 t1 = jnp.inf # lna_axis.max # need to go at least twice max_steps to make sure we catch t1 t_arr = jnp.linspace(t0+self.integration_spacing, t0+2*max_steps*self.integration_spacing, 2*max_steps) save_at = SaveAt(ts=t_arr) # save_at = SaveAt(ts=lna_axis) # but start saving output at step 1 or later initial_state = jnp.array([xe0, Tm0]) term = ODETerm(self.TLA_xe_deriv) adjoint=ForwardMode() def lna_check(t, y, args, **kwargs): return t > self.lna_end # stop when true event = Event(lna_check) sol = diffeqsolve( term, solver, t0=t0, t1=t1, dt0=1e-3, y0=initial_state, args=(h, omega_b, omega_cdm, Neff, YHe), stepsize_controller=PIDController(rtol, atol),saveat=save_at, adjoint=adjoint, max_steps=max_steps, event=event ) xe_output = sol.ys[:, 0] Tm_output = sol.ys[:, 1] return xe_output, Tm_output, sol.ts