import numpy as np
import jax.numpy as jnp
from jax import jit, config
import equinox as eqx
from diffrax import Kvaerno3
from .hydrogen import hydrogen_model
from .helium import helium_model
from .array_with_padding import array_with_padding
config.update("jax_enable_x64", True)
class recomb_model(eqx.Module):
    """
    Complete recombination model implementation.

    Combines helium and hydrogen recombination calculations with
    reionization modeling to compute full ionization history.

    Attributes:
    -----------
    integration_spacing : float
        Step size in log scale factor for integration
    lna_axis_full : array
        Full log scale factor grid for recombination evolution
    z1 : float
        Final redshift for evolution
    twog_redshift : float
        Redshift at which two-photon processes become negligible
    He4equil_redshift : float
        Redshift threshold for HeII+III equilibrium phase
    idx_4He_equil : array
        Indices for HeII+III equilibrium phase on full grid

    Methods:
    --------
    get_history : Compute complete recombination and reionization history (units: dimensionless)
    """

    integration_spacing : jnp.float64
    lna_axis_full : jnp.ndarray
    z1 : jnp.float64
    twog_redshift : jnp.float64
    He4equil_redshift : jnp.float64
    idx_4He_equil : jnp.ndarray

    def __init__(self, integration_spacing=5.0e-4, z0=8000., z1=0.):
        """
        Initialize complete recombination model.

        Sets up time grids and parameters for helium recombination,
        hydrogen recombination, and reionization phases.

        Parameters:
        -----------
        integration_spacing : float, optional
            Step size for integration (default: 5.0e-4)
        z0 : float, optional
            Initial redshift (default: 8000.)
        z1 : float, optional
            Final redshift (default: 0.)
        """
        self.integration_spacing = integration_spacing
        self.z1 = z1

        # Define time axes: uniform grid in ln(a) = -ln(1+z) running from z0
        # towards z1. jnp.arange excludes the stop value, so the last grid
        # point lies strictly before ln(a) at z1.
        self.lna_axis_full = jnp.arange(
            -jnp.log(1 + z0), -jnp.log(1 + z1), self.integration_spacing
        )

        self.twog_redshift = 701.
        self.He4equil_redshift = 3601.  # generous

        # The threshold is compared as -log(He4equil_redshift), i.e. the
        # stored value is used directly in place of (1+z). Everything on the
        # grid at or before that point belongs to the HeII+III equilibrium
        # phase.
        self.idx_4He_equil = jnp.where(
            self.lna_axis_full <= -jnp.log(self.He4equil_redshift)
        )[0]

    @jit
    def __call__(self, h, omega_b, omega_cdm, Neff, YHe, z_reion = 11, Delta_z_reion = 0.5, z_reion_He = 3.5, Delta_z_reion_He = 0.5, exp_reion = 1.5, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=1024):
        """
        Compute complete recombination and reionization history.

        JIT-compiled entry point: forwards every argument unchanged to
        ``get_history``.

        Parameters:
        -----------
        h : float
            Hubble parameter
        omega_b : float
            The baryon density Omega_b h^2
        omega_cdm : float
            The density of Cold Dark Matter Omega_cdm h^2
        Neff : float
            Effective number of neutrinos
        YHe : float
            Helium fraction
        z_reion : float, optional
            Reionization redshift (default: 11)
        Delta_z_reion : float, optional
            Reionization transition width (default: 0.5)
        z_reion_He : float, optional
            Reionization redshift of singly-ionized helium (default: 3.5)
        Delta_z_reion_He : float, optional
            Reionization transition width for singly-ionized helium (default: 0.5)
        exp_reion : float, optional
            Power of 1+z appearing in tanh argument during reionization (default: 1.5)
        rtol : float, optional
            Relative tolerance for ODE solver (default: 1e-6)
        atol : float, optional
            Absolute tolerance for ODE solver (default: 1e-9)
        solver : diffrax.Solver, optional
            ODE solver instance (default: Kvaerno3())
        max_steps : int, optional
            Maximum solver steps (default: 1024)

        Returns:
        --------
        tuple
            (xe_full_reion, lna_full, Tm, lna_Tm) - complete ionization history
            with reionization, log scale factor, matter temperature, and temperature grid
        """
        return self.get_history(
            h, omega_b, omega_cdm, Neff, YHe,
            z_reion, Delta_z_reion, z_reion_He, Delta_z_reion_He, exp_reion,
            rtol, atol, solver, max_steps,
        )

    # do not jit, use call instead
    def get_history(self, h, omega_b, omega_cdm, Neff, YHe, z_reion = 11, Delta_z_reion = 0.5, z_reion_He = 3.5, Delta_z_reion_He = 0.5, exp_reion=1.5, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=1024):
        """
        Compute complete recombination and reionization history.

        Runs the helium model on the HeII+III equilibrium part of the grid,
        hands its result to the hydrogen model, then adds tanh-shaped
        reionization terms to the resulting free-electron fraction.

        Parameters:
        -----------
        h : float
            Hubble parameter
        omega_b : float
            The baryon density Omega_b h^2
        omega_cdm : float
            The density of Cold Dark Matter Omega_cdm h^2
        Neff : float
            Effective number of neutrinos
        YHe : float
            Helium fraction
        z_reion : float, optional
            Reionization redshift of hydrogen and neutral helium (default: 11)
        Delta_z_reion : float, optional
            Reionization transition width for hydrogen and neutral helium (default: 0.5)
        z_reion_He : float, optional
            Reionization redshift of singly-ionized helium (default: 3.5)
        Delta_z_reion_He : float, optional
            Reionization transition width for singly-ionized helium (default: 0.5)
        exp_reion : float, optional
            Power of 1+z appearing in tanh argument during reionization (default: 1.5)
        rtol : float, optional
            Relative tolerance for ODE solver (default: 1e-6).
            Accepted but not forwarded to the sub-models by this method.
        atol : float, optional
            Absolute tolerance for ODE solver (default: 1e-9).
            Accepted but not forwarded to the sub-models by this method.
        solver : diffrax.Solver, optional
            ODE solver instance (default: Kvaerno3()).
            Accepted but not forwarded to the sub-models by this method.
        max_steps : int, optional
            Maximum solver steps (default: 1024).
            Accepted but not forwarded to the sub-models by this method.

        Returns:
        --------
        tuple
            (xe_full_reion, lna_full, Tm, lna_Tm) - complete ionization history
            with reionization, log scale factor, matter temperature, and temperature grid
        """
        # NOTE(review): rtol, atol, solver and max_steps are part of the
        # signature but neither sub-model call below receives them. Confirm
        # whether helium_model / hydrogen_model should be given these values.

        ### Helium and Hydrogen Recombination ###
        lna_axis_4Heequil = self.lna_axis_full[self.idx_4He_equil]
        xe_4He, lna_4He = helium_model(lna_axis_4Heequil)(h, omega_b, omega_cdm, Neff, YHe)

        # The hydrogen model is built from the helium output, the final
        # ln(a) = -ln(1+z1), the helium grid's `lastval`, and twog_redshift.
        xe_full, lna_full, Tm, lna_Tm = hydrogen_model(
            xe_4He,
            lna_4He,
            -jnp.log(1 + self.z1),
            lna_4He.lastval,
            self.twog_redshift,
        )(h, omega_b, omega_cdm, Neff, YHe)

        ### Hydrogen Reionization ###
        # We patch a simple tanh solution to the tail of the electron fraction result.
        fHe = YHe / 4 / (1-YHe)
        z = 1/jnp.exp(lna_full.arr) - 1

        # The tanh is taken in y = (1+z)**exp_reion; the width in y is the
        # first-order image of Delta_z_reion, dy/dz evaluated at z_reion.
        y = (1+z)**(exp_reion)
        y_reion = (1+z_reion)**(exp_reion)
        Delta_y_reion = exp_reion * (1+z_reion)**(exp_reion-1) * Delta_z_reion
        tanh_arg = (y_reion - y) / Delta_y_reion
        xe_reion_correction = (1+fHe)/2 * (1 + jnp.tanh(tanh_arg))

        ### Helium Reionization ###
        # The above accounts for hydrogen and the first ionization level of helium.
        # Let's also account for the second ionization of helium:
        tanh_arg_He = (z_reion_He - z)/Delta_z_reion_He
        xe_HeII_reion_correction = fHe/2 * (1 + jnp.tanh(tanh_arg_He))

        # Both corrections are added on top of the recombination result.
        xe_full_arr = xe_reion_correction + xe_HeII_reion_correction + xe_full.arr
        xe_full_reion = array_with_padding(xe_full_arr)
        ### End of Reionization ###

        # best return the whole array-with-padding object
        # so we can interpolate over the padding
        return (xe_full_reion, lna_full, Tm, lna_Tm)