# Source code for HyRex.hyrex

import numpy as np
import jax.numpy as jnp
from jax import jit, config

# Enable double-precision (64-bit) JAX arrays. This is a process-wide JAX
# setting. It is applied before importing equinox, diffrax and the local
# model modules because arrays created before the flag is set stay 32-bit.
config.update("jax_enable_x64", True)

import equinox as eqx
from diffrax import Kvaerno3

from .hydrogen import hydrogen_model
from .helium import helium_model
from .array_with_padding import array_with_padding

def _reionization_xe(z, YHe, z_reion, Delta_z_reion, z_reion_He,
                     Delta_z_reion_He, exp_reion):
    """
    Reionization contribution to the free-electron fraction.

    Returns the sum of two tanh steps evaluated at redshift(s) ``z``:

    * hydrogen and singly-ionized helium: a tanh in y = (1+z)**exp_reion,
      centred on y_reion = (1+z_reion)**exp_reion, with width
      Delta_y = exp_reion * (1+z_reion)**(exp_reion-1) * Delta_z_reion
      (Delta_z_reion mapped to y to first order). It rises from 0 at high
      redshift to 1 + fHe well below z_reion.
    * second ionization of helium: a tanh in z centred on z_reion_He with
      width Delta_z_reion_He. It rises from 0 to fHe.

    Here fHe = YHe / 4 / (1 - YHe). If YHe is the helium mass fraction,
    this is the He/H number ratio under the approximation m_He = 4 m_H.

    Parameters:
    -----------
    z : array
        Redshift(s) at which to evaluate the correction
    YHe, z_reion, Delta_z_reion, z_reion_He, Delta_z_reion_He, exp_reion :
        Same meaning as in ``recomb_model.get_history``

    Returns:
    --------
    array
        Increment of x_e due to reionization, broadcast against ``z``
    """
    fHe = YHe / 4 / (1 - YHe)

    ### Hydrogen (+ first helium) reionization ###
    y = (1 + z)**(exp_reion)
    y_reion = (1 + z_reion)**(exp_reion)
    Delta_y_reion = exp_reion * (1 + z_reion)**(exp_reion - 1) * Delta_z_reion
    tanh_arg = (y_reion - y) / Delta_y_reion
    xe_reion_correction = (1 + fHe) / 2 * (1 + jnp.tanh(tanh_arg))

    ### Helium reionization (second ionization of helium) ###
    tanh_arg_He = (z_reion_He - z) / Delta_z_reion_He
    xe_HeII_reion_correction = fHe / 2 * (1 + jnp.tanh(tanh_arg_He))

    return xe_reion_correction + xe_HeII_reion_correction


class recomb_model(eqx.Module):
    """
    Complete recombination model implementation.

    Combines helium and hydrogen recombination calculations with
    reionization modeling to compute the full ionization history.

    Attributes:
    -----------
    integration_spacing : float
        Step size in log scale factor (ln a) for integration
    lna_axis_full : array
        Uniform ln a grid for the recombination evolution, from -ln(1+z0)
        up to (jnp.arange semantics: excluding) -ln(1+z1)
    z1 : float
        Final redshift for evolution
    twog_redshift : float
        Redshift at which two-photon processes become negligible
        (passed through to ``hydrogen_model``)
    He4equil_redshift : float
        Threshold for the HeII+III equilibrium phase. It is compared against
        ln a as -ln(He4equil_redshift), so the stored value acts as 1+z.
    idx_4He_equil : array
        Indices into ``lna_axis_full`` with ln a <= -ln(He4equil_redshift),
        i.e. the HeII+III equilibrium phase

    Methods:
    --------
    get_history :
        Compute complete recombination and reionization history
        (units: dimensionless)
    """

    integration_spacing: float
    lna_axis_full: jnp.ndarray
    z1: float
    twog_redshift: float
    He4equil_redshift: float
    idx_4He_equil: jnp.ndarray

    def __init__(self, integration_spacing=5.0e-4, z0=8000., z1=0.):
        """
        Initialize complete recombination model.

        Sets up time grids and parameters for helium recombination,
        hydrogen recombination, and reionization phases.

        Parameters:
        -----------
        integration_spacing : float, optional
            Step size for integration in ln a (default: 5.0e-4)
        z0 : float, optional
            Initial redshift (default: 8000.)
        z1 : float, optional
            Final redshift (default: 0.)
        """
        self.integration_spacing = integration_spacing
        self.z1 = z1

        # Define time axes: uniform in ln a = -ln(1+z); arange excludes the
        # end point -ln(1+z1).
        self.lna_axis_full = jnp.arange(-jnp.log(1 + z0), -jnp.log(1 + z1),
                                        self.integration_spacing)
        self.twog_redshift = 701.
        self.He4equil_redshift = 3601.  # generous
        # Points at or before ln a = -ln(3601), i.e. 1+z >= 3601.
        self.idx_4He_equil = jnp.where(
            self.lna_axis_full <= -jnp.log(self.He4equil_redshift))[0]

    @jit
    def __call__(self, h, omega_b, omega_cdm, Neff, YHe, z_reion=11,
                 Delta_z_reion=0.5, z_reion_He=3.5, Delta_z_reion_He=0.5,
                 exp_reion=1.5, rtol=1e-6, atol=1e-9, solver=Kvaerno3(),
                 max_steps=1024):
        """
        Compute complete recombination and reionization history (jitted).

        This is the JIT-compiled entry point; it simply forwards all
        arguments to ``get_history``.

        Parameters:
        -----------
        h : float
            Hubble parameter
        omega_b : float
            The baryon density Omega_b h^2
        omega_cdm : float
            The density of Cold Dark Matter Omega_cdm h^2
        Neff : float
            Effective number of neutrinos
        YHe : float
            Helium fraction
        z_reion : float, optional
            Reionization redshift of hydrogen and neutral helium (default: 11)
        Delta_z_reion : float, optional
            Reionization transition width (default: 0.5)
        z_reion_He : float, optional
            Reionization redshift of singly-ionized helium (default: 3.5)
        Delta_z_reion_He : float, optional
            Reionization transition width for singly-ionized helium
            (default: 0.5)
        exp_reion : float, optional
            Power of 1+z appearing in tanh argument during reionization
            (default: 3/2)
        rtol : float, optional
            Relative tolerance for ODE solver (default: 1e-6)
        atol : float, optional
            Absolute tolerance for ODE solver (default: 1e-9)
        solver : diffrax.Solver, optional
            ODE solver instance (default: Kvaerno3())
        max_steps : int, optional
            Maximum solver steps (default: 1024)

        NOTE(review): rtol, atol, solver and max_steps are forwarded to
        ``get_history``, which currently does not use them.

        Returns:
        --------
        tuple
            (xe_full_reion, lna_full, Tm, lna_Tm) - complete ionization
            history with reionization, log scale factor, matter temperature,
            and temperature grid
        """
        return self.get_history(h, omega_b, omega_cdm, Neff, YHe, z_reion,
                                Delta_z_reion, z_reion_He, Delta_z_reion_He,
                                exp_reion, rtol, atol, solver, max_steps)

    # do not jit, use call instead
    def get_history(self, h, omega_b, omega_cdm, Neff, YHe, z_reion=11,
                    Delta_z_reion=0.5, z_reion_He=3.5, Delta_z_reion_He=0.5,
                    exp_reion=1.5, rtol=1e-6, atol=1e-9, solver=Kvaerno3(),
                    max_steps=1024):
        """
        Compute complete recombination and reionization history.

        Parameters:
        -----------
        h : float
            Hubble parameter
        omega_b : float
            The baryon density Omega_b h^2
        omega_cdm : float
            The density of Cold Dark Matter Omega_cdm h^2
        Neff : float
            Effective number of neutrinos
        YHe : float
            Helium fraction
        z_reion : float, optional
            Reionization redshift of hydrogen and neutral helium (default: 11)
        Delta_z_reion : float, optional
            Reionization transition width for hydrogen and neutral helium
            (default: 0.5)
        z_reion_He : float, optional
            Reionization redshift of singly-ionized helium (default: 3.5)
        Delta_z_reion_He : float, optional
            Reionization transition width for singly-ionized helium
            (default: 0.5)
        exp_reion : float, optional
            Power of 1+z appearing in tanh argument during reionization
            (default: 3/2)
        rtol : float, optional
            Relative tolerance for ODE solver (default: 1e-6)
        atol : float, optional
            Absolute tolerance for ODE solver (default: 1e-9)
        solver : diffrax.Solver, optional
            ODE solver instance (default: Kvaerno3())
        max_steps : int, optional
            Maximum solver steps (default: 1024)

        NOTE(review): rtol, atol, solver and max_steps are accepted but not
        passed to ``helium_model`` or ``hydrogen_model`` below, so they
        currently have no effect.

        Returns:
        --------
        tuple
            (xe_full_reion, lna_full, Tm, lna_Tm) - complete ionization
            history with reionization, log scale factor, matter temperature,
            and temperature grid
        """
        # Helium phase: evaluated on the HeII+III-equilibrium part of the grid.
        lna_axis_4Heequil = self.lna_axis_full[self.idx_4He_equil]
        xe_4He, lna_4He = helium_model(lna_axis_4Heequil)(
            h, omega_b, omega_cdm, Neff, YHe,)

        # Hydrogen phase, seeded with the helium-phase output. The arguments
        # -ln(1+z1) and lna_4He.lastval presumably bound the integration
        # range (confirm against hydrogen_model).
        xe_full, lna_full, Tm, lna_Tm = hydrogen_model(
            xe_4He, lna_4He, -jnp.log(1 + self.z1), lna_4He.lastval,
            self.twog_redshift)(h, omega_b, omega_cdm, Neff, YHe)

        ### Reionization ###
        # Tanh steps for hydrogen + first helium ionization and for the
        # second helium ionization are added to the recombination result.
        # NOTE(review): the steps are added rather than blended with xe_full,
        # so well after reionization x_e = 1 + 2 fHe + (residual xe_full).
        # Confirm this is intended.
        z = 1/jnp.exp(lna_full.arr) - 1
        xe_full_arr = _reionization_xe(
            z, YHe, z_reion, Delta_z_reion, z_reion_He, Delta_z_reion_He,
            exp_reion) + xe_full.arr
        xe_full_reion = array_with_padding(xe_full_arr)
        ### End of Reionization ###

        # best return the whole array-with-padding object
        # so we can interpolate over the padding
        return (xe_full_reion, lna_full, Tm, lna_Tm)