import jax.numpy as jnp
from jax import jit, config, lax, grad
import equinox as eqx
from diffrax import diffeqsolve, SaveAt, ODETerm, Kvaerno3, PIDController, ForwardMode, Event
from . import cosmology
from .cosmology import kB, not4
from . import recomb_functions
from .array_with_padding import array_with_padding
config.update("jax_enable_x64", True)
[docs]
class helium_model(eqx.Module):
"""
Helium recombination model implementation.
Computes helium ionization fraction evolution through multiple phases:
HeII+III equilibrium, post-Saha HeII expansion, and full HeII recombination.
Attributes:
-----------
integration_spacing : float
Step size in log scale factor for integration
lna_axis_4Heequil : array
Log scale factor grid for HeII+III equilibrium phase
concrete_axis_size : array
Pre-allocated array for HeII full recombination phase
concrete_axis_size_postSahaHe : array
Pre-allocated array for post-Saha HeII expansion phase
Methods:
--------
get_helium_history : Compute full helium recombination history (units: dimensionless)
xesaha_HeII_III : Compute HeII+III equilibrium phase (units: dimensionless)
post_saha_xHeII : Compute post-Saha HeII expansion (units: dimensionless)
solve_HeII_full : Solve full HeII recombination (units: dimensionless)
xHeII_post_Saha : Compute HeII fraction in post-Saha regime (units: dimensionless)
xH1_Saha : Compute neutral hydrogen fraction in Saha equilibrium (units: dimensionless)
helium_dxHeIIdlna : Compute HeII recombination rate (units: dimensionless)
xe_derivative_HeII : Compute HeII derivative for ODE integration (units: dimensionless)
"""
integration_spacing : jnp.float64
lna_axis_4Heequil : jnp.array
concrete_axis_size : jnp.array
concrete_axis_size_postSahaHe : jnp.array
def __init__(self,lna_axis_4Heequil,integration_spacing = 5.0e-4, Nsteps=800, Nsteps_postSahaHe=4000,z0=8000., z1=20.):
"""
Initialize helium recombination model.
Parameters:
-----------
lna_axis_4Heequil : array
Log scale factor grid for HeII+III equilibrium phase
integration_spacing : float, optional
Step size for integration (default: 5.0e-4)
Nsteps : int, optional
Maximum number of integration steps (default: 800)
Nsteps_postSahaHe : int, optional
Maximum steps for post-Saha HeII phase (default: 4000)
z0 : float, optional
Initial redshift (default: 8000.)
z1 : float, optional
Final redshift (default: 20.)
"""
self.integration_spacing = integration_spacing
self.concrete_axis_size_postSahaHe = jnp.zeros(Nsteps_postSahaHe)
# Define time axes
self.lna_axis_4Heequil = lna_axis_4Heequil
self.concrete_axis_size = jnp.zeros(Nsteps)
@jit
def __call__(self, h, omega_b, omega_cdm, Neff, YHe, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=1024):
"""
Compute helium recombination history.
Parameters:
-----------
h : float
Hubble parameter
omega_b : float
The baryon density Omega_b h^2
omega_cdm : float
The density of Cold Dark Matter Omega_cdm h^2
Neff : float
Effective number of neutrinos
YHe : float
Helium fraction
rtol : float, optional
Relative tolerance for ODE solver (default: 1e-6)
atol : float, optional
Absolute tolerance for ODE solver (default: 1e-9)
solver : diffrax.Solver, optional
ODE solver instance (default: Kvaerno3())
max_steps : int, optional
Maximum solver steps (default: 1024)
Returns:
--------
tuple
(xe_4He, lna_4He) - helium ionization fraction and log scale factor
"""
return self.get_helium_history(h, omega_b, omega_cdm, Neff, YHe, rtol, atol, solver, max_steps)
[docs]
def get_helium_history(self, h, omega_b, omega_cdm, Neff, YHe, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=1024):
"""
Compute complete helium recombination history through all phases.
Sequentially computes helium ionization fraction through HeII+III equilibrium,
post-Saha HeII expansion, and full HeII recombination phases.
Parameters:
-----------
h : float
Hubble parameter
omega_b : float
The baryon density Omega_b h^2
omega_cdm : float
The density of Cold Dark Matter Omega_cdm h^2
Neff : float
Effective number of neutrinos
YHe : float
Helium fraction
rtol : float, optional
Relative tolerance for ODE solver (default: 1e-6)
atol : float, optional
Absolute tolerance for ODE solver (default: 1e-9)
solver : diffrax.Solver, optional
ODE solver instance (default: Kvaerno3())
max_steps : int, optional
Maximum solver steps (default: 1024)
Returns:
--------
tuple
(xe_4He, lna_4He) containing helium ionization fraction evolution
and log scale factor grid
"""
# Compute xe at different phases
# Give it a large enough array of lna to work with
xe_output_4He_equil, lna_output_4He_equil = self.xesaha_HeII_III(self.lna_axis_4Heequil,omega_b, YHe)
xe_output_4He_equil_obj = array_with_padding(xe_output_4He_equil)
lna_output_4He_equil_obj = array_with_padding(lna_output_4He_equil)
# this one MUST start shifted by one redshift bin to avoid overlapping redshifts
xe_output_4He_postSaha, lna_output_4He_postSaha = self.post_saha_xHeII(lna_output_4He_equil_obj.lastval + self.integration_spacing, h, omega_b, omega_cdm, Neff, YHe)
xe_4He_equil_post = xe_output_4He_equil_obj.concat(array_with_padding(xe_output_4He_postSaha))
lna_4He_equil_post = lna_output_4He_equil_obj.concat(array_with_padding(lna_output_4He_postSaha))
xe_output_4He_full, lna_output_4He_full = self.solve_HeII_full(
lna_4He_equil_post.lastval, xe_4He_equil_post.lastval, h, omega_b, omega_cdm, Neff, YHe, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=4096)
xe_4He = xe_4He_equil_post.concat(array_with_padding(xe_output_4He_full))
lna_4He = lna_4He_equil_post.concat(array_with_padding(lna_output_4He_full))
return (xe_4He, lna_4He)
###################### HELIUM RECOMBINATION ######################
# HYREC-2's helium.c expects T in K, but we use eV instead, hence proliferation of
# kB's.
# High tempateratures (z >~ 4000). Function to calculate xe in He II + III equilibrium
# We use this form until xHeIII is 1e-9
[docs]
def xesaha_HeII_III(self, lna_axis, omega_b, YHe, threshold=1e-9):
"""
Compute xe in HeII+III equilibrium phase.
Calculates ionization fraction assuming equilibrium between HeII and HeIII
until HeIII fraction drops below threshold.
Parameters:
-----------
lna_axis : array
Log scale factor grid
omega_b : float
The baryon density Omega_b h^2
YHe : float
Helium fraction
threshold : float, optional
Threshold for HeIII fraction to stop calculation (default: 1e-9)
Returns:
--------
tuple
(xe_output, lna_output) - ionization fraction and log scale factor arrays
"""
# Pre-allocate xe_output
xe_output = jnp.ones_like(lna_axis)*jnp.inf
lna_output = jnp.ones_like(lna_axis)*jnp.inf
iz = int(0)
xe = 1. # arbitrary IC
stop = False
def compute_xe(carry):
xe_output, lna_output, xe, iz, stop = carry
lna = lna_axis[iz]
# Cosmological parameters
z = jnp.exp(-lna) - 1.
TCMB = cosmology.TCMB(z)
nH = cosmology.nH(z, omega_b, YHe)
# compute xHeIII
fHe = YHe/(1.-YHe)/not4 # abundance of helium by number
# Saha ratio xe * xHeIII / xHeII
s = 2.414194e15 * TCMB/kB * jnp.sqrt(TCMB/kB) * jnp.exp(-631462.7 / (TCMB/kB)) / nH
xHeIII = 2 * s * fHe / (1 + s + fHe) / (1 + jnp.sqrt(1 + 4 * s * fHe / (1 + s + fHe)**2))
xe = 1 + fHe + xHeIII
# Store current xe value in the output array
xe_output = xe_output.at[iz].set(xe)
lna_output = lna_output.at[iz].set(lna)
# Check difference
stop = xHeIII < threshold # Stop when xHeIII > threshold
# Increment index
iz = iz + 1
return (xe_output, lna_output, xe, iz, stop)
def stop_condition(state):
_, _, _, iz, stop = state
return (iz < lna_axis.size) & (~stop) # Continue until stop condition is met or we run out of space
# Initial state: (xe_output, xe, iz, stop flag)
initial_state = (xe_output, lna_output, xe, iz, stop)
# Run the while loop until the stop condition is met
final_state = lax.while_loop(stop_condition, compute_xe, initial_state)
# Unpack the final state
xe_output_final, lna_output_final, _, _, _ = final_state
# Return the electron fraction array and lna
return xe_output_final, lna_output_final
[docs]
def xHeII_post_Saha(self, lna, omega_b, YHe):
"""
Compute HeII fraction in post-Saha regime.
Calculates HeII fraction using Saha equilibrium between HeI and HeII.
Parameters:
-----------
lna : float
Log scale factor
h : float
Hubble parameter
omega_b : float
The baryon density Omega_b h^2
omega_cdm : float
The density of Cold Dark Matter Omega_cdm h^2
Neff : float
Effective number of neutrinos
YHe : float
Helium fraction
Returns:
--------
float
HeII fraction (units: dimensionless)
"""
fHe = YHe/(1.-YHe)/not4
z = jnp.exp(-lna) - 1.
TCMB = cosmology.TCMB(z)
nH = cosmology.nH(z, omega_b, YHe)
# Saha ratio xe * xHeII / xHeI
s = 4 * 2.414194e15 * TCMB/kB * jnp.sqrt(TCMB/kB) * jnp.exp(-285325. / (TCMB/kB)) / nH
xHeII = 2 * s * fHe / (1 + s) / (1 + jnp.sqrt(1 + 4 * s * fHe / (1 + s)**2))
return xHeII
[docs]
def xH1_Saha(self, lna, omega_b, YHe):
"""
Compute neutral hydrogen fraction in Saha equilibrium.
Calculates neutral hydrogen fraction assuming Saha equilibrium
between hydrogen ionization and recombination.
Parameters:
-----------
lna : float
Log scale factor
omega_b : float
The baryon density Omega_b h^2
YHe : float
Helium fraction
Returns:
--------
float
Neutral hydrogen fraction (units: dimensionless)
"""
z = jnp.exp(-lna) - 1.
TCMB = cosmology.TCMB(z)
nH = cosmology.nH(z, omega_b, YHe)
xHeII = self.xHeII_post_Saha(lna, omega_b, YHe)
s = 2.4127161187130e15* TCMB/kB * jnp.sqrt(TCMB/kB)*jnp.exp(-157801.37882/(TCMB/kB))/nH
xH1 = jnp.where(s>1e5,(1.+xHeII)/s - (xHeII**2 + 3.*xHeII + 2.)/s**2,\
jnp.where(s==0,1,1.-2./(1.+ xHeII/s + jnp.sqrt((1.+ xHeII/s)*(1.+ xHeII/s) +4./s))))
return xH1
# xHeII near equilibrium
# Returns xHeII-->xe (98 of 1011.3758). Integrate until delta x_e ~ 1e-5
[docs]
def post_saha_xHeII(self, starting_lna, h, omega_b, omega_cdm, Neff, YHe, threshold=1e-5):
"""
Compute post-Saha HeII expansion phase.
Calculates ionization fraction including corrections to Saha equilibrium
until deviations exceed threshold.
Parameters:
-----------
starting_lna : float
Initial log scale factor
h : float
Hubble parameter
omega_b : float
The baryon density Omega_b h^2
omega_cdm : float
The density of Cold Dark Matter Omega_cdm h^2
Neff : float
Effective number of neutrinos
YHe : float
Helium fraction
threshold : float, optional
Threshold for deviation from Saha (default: 1e-5)
Returns:
--------
tuple
(xe_output, lna_output) - ionization fraction and log scale factor arrays
"""
# Pre-allocate xe_output
xe_output = jnp.ones_like(self.concrete_axis_size_postSahaHe)*jnp.inf
lna_output = jnp.ones_like(self.concrete_axis_size_postSahaHe)*jnp.inf
iz = int(0)
xe = 1. # arbitrary
stop = False
def compute_xe(carry):
xe_output, lna_output, xe, iz, stop = carry
lna = starting_lna + iz*self.integration_spacing
xHeII = self.xHeII_post_Saha(lna, omega_b, YHe)
xH1 = self.xH1_Saha(lna, omega_b, YHe)
xe_saha = 1 - xH1 + xHeII
# Do post saha expansion. Assume all hydrogen is ionized.
grad_dxedlna_func = grad(self.helium_dxHeIIdlna, argnums=0)
grad_dxedlna = grad_dxedlna_func(xe_saha, lna, h, omega_b, omega_cdm, Neff, YHe)
dxe_Saha_dlna = grad(self.xHeII_post_Saha,argnums=0)(lna, omega_b, YHe)
xe = xe_saha + dxe_Saha_dlna / grad_dxedlna
# Store current xe value in the output array
xe_output = xe_output.at[iz].set(xe)
lna_output = lna_output.at[iz].set(lna)
# Check difference
diff = jnp.abs(xe_saha - xe)
stop = diff > threshold # Stop when diff < threshold
# Increment index
iz = iz + 1
return (xe_output, lna_output, xe, iz, stop)
def stop_condition(state):
_, _, _, iz, stop = state
return (iz < self.concrete_axis_size_postSahaHe.size) & (~stop) # Continue until stop condition is met or we run out of space
# Initial state: (xe_output, xe, iz, stop flag)
initial_state = (xe_output, lna_output, xe, iz, stop)
# Run the while loop until the stop condition is met
final_state = lax.while_loop(stop_condition, compute_xe, initial_state)
# Unpack the final state
xe_output_final, lna_output_final, _, _, _ = final_state
# Return the electron fraction array and the stopping `lna` value
return xe_output_final, lna_output_final
[docs]
def helium_dxHeIIdlna(self, xe, lna, h, omega_b, omega_cdm, Neff, YHe):
"""
Compute HeII recombination rate.
Calculates rate of change of HeII ionization fraction including
detailed atomic physics and escape probabilities.
Parameters:
-----------
xe : float
Current total ionization fraction
lna : float
Log scale factor
h : float
Hubble parameter
omega_b : float
The baryon density Omega_b h^2
omega_cdm : float
The density of Cold Dark Matter Omega_cdm h^2
Neff : float
Effective number of neutrinos
YHe : float
Helium fraction
Returns:
--------
float
HeII recombination rate dxHeII/dlna (units: dimensionless)
"""
fHe = YHe/(1.-YHe)/not4 # abundance of helium by number
# cosmology
#lna = -jnp.log(1+z)
z = jnp.exp(-lna) - 1.
TCMB = cosmology.TCMB(z)
nH = cosmology.nH(z, omega_b, YHe)# hydrogen number density, 1/cm^3
H = cosmology.Hubble(z, h, omega_b, omega_cdm, cosmology.omega_rad0(Neff)) # Hubble parameter, 1/s
GammaC = recomb_functions.Gamma_compton(xe, TCMB, YHe) # Compton scattering rate, 1/s
# compute xH1 in Saha equilibrium, xHeII in post-saha
xH1 = self.xH1_Saha(lna, omega_b, YHe)
# use xe = xHeII + (1.-xH1)
xHeII = xe - (1.-xH1)
# Saha ratio and abundances
s0 = 2.414194e15 * TCMB/kB * jnp.sqrt(TCMB/kB) / nH * 4
s = s0 * jnp.exp(-285325. / (TCMB/kB))
y2s = jnp.exp(46090. / (TCMB/kB)) / s0
y2p = jnp.exp(39101. / (TCMB/kB)) / s0 * 3
# Continuum opacity and optical depth
etacinv = 9.15776e22 * H / (nH* xH1)
g2pinc = (
1.976e6 / (1 - jnp.exp(-6989. / (TCMB/kB))) +
6.03e6 / (jnp.exp(19754. / (TCMB/kB)) - 1) +
1.06e8 / (jnp.exp(21539. / (TCMB/kB)) - 1) +
2.18e6 / (jnp.exp(28496. / (TCMB/kB)) - 1) +
3.37e7 / (jnp.exp(29224. / (TCMB/kB)) - 1) +
1.04e6 / (jnp.exp(32414. / (TCMB/kB)) - 1) +
1.51e7 / (jnp.exp(32781. / (TCMB/kB)) - 1)
)
# Optical depth and escape probability
tau2p = jnp.float64(4.277e-8 * nH / H * (fHe - xHeII))
dnuline = g2pinc * tau2p / (4 * jnp.pi**2)
tauc = dnuline / etacinv
enh = jnp.sqrt(1 + jnp.pi**2 * tauc) + 7.74 * tauc / (1 + 70 * tauc)
pesc = enh / tau2p
# As per HYREC-2 https://github.com/nanoomlee/HYREC-2/blob/09e8243d0e08edd3603a94dfbc445ae06cafe139/helium.c#L163:
# Effective increase in escape probability via intercombination line
# ratio of optical depth to allowed line = 1.023e-7
# 1-e^-tau23 = absorption prob. in intercom line
# e^[(E21P-E23P)/T] - e^[-(E21P-E23P)*eta_c] = step in intercom line
# relative to N_line in 21P
# divide by tau2p to get effective increase in escape prob.
# factor of 0.964525 is phase space factor for intercom vs allowed line -- (584/591)^3
pesc = pesc + (1.-jnp.exp(-1.023e-7*tau2p))*(0.964525*jnp.exp(2947./(TCMB/kB))-enh*jnp.exp(-6.14e13/etacinv))/tau2p
# Total decay rate
ydown = (50.94 * y2s) + (1.7989e9 * y2p * pesc)
# Recombination rate
return ydown * ((fHe - xHeII) * s - xHeII * (xHeII + 1 - xH1)) / H
[docs]
def xe_derivative_HeII(self, lna, state, args):
"""
Compute HeII derivative for ODE integration.
Derivative function for HeII ionization fraction evolution
used in ODE integration with diffrax.
Parameters:
-----------
lna : float
Log scale factor
state : float
Current HeII ionization state
args : tuple
h, omega_b, omega_cdm, Neff, YHe, omega_rad; the Hubble parameter,
the baryon denisty Omega_b h^2, the CDM density Omega_cdm h^2,
the effecgive number of neutrinos, the helium fraction, and
the radiation energy density (determined by Neff and can be
computed by cosmology.omega_rad0(Neff))
Returns:
--------
float
Time derivative of HeII fraction (units: dimensionless)
"""
h, omega_b, omega_cdm, Neff, YHe = args
#z = 1. / jnp.exp(lna) - 1.
# use xe = xHeII + (1.-xH1)
xe = state + self.xH1_Saha(lna, omega_b, YHe)
return self.helium_dxHeIIdlna(xe, lna, h, omega_b, omega_cdm, Neff, YHe)
[docs]
def solve_HeII_full(self, starting_lna, xe0, h, omega_b, omega_cdm, Neff, YHe, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=1024):
"""
Solve full HeII recombination evolution.
Integrates HeII recombination including detailed atomic physics
until HeII fraction becomes negligible.
Parameters:
-----------
starting_lna : float
Initial log scale factor
xe0 : float
Initial ionization fraction
h : float
Hubble parameter
omega_b : float
The baryon density Omega_b h^2
omega_cdm : float
The density of Cold Dark Matter Omega_cdm h^2
Neff : float
Effective number of neutrinos
YHe : float
Helium fraction
rtol : float, optional
Relative tolerance (default: 1e-6)
atol : float, optional
Absolute tolerance (default: 1e-9)
solver : diffrax.Solver, optional
ODE solver (default: Kvaerno3())
max_steps : int, optional
Maximum steps (default: 1024)
Returns:
--------
tuple
(xe_output, lna_output) - ionization fraction and log scale factor arrays
"""
# Initial conditions
TCMB_init = cosmology.TCMB(jnp.exp(-starting_lna) - 1.)
initial_state = jnp.array([xe0])
term = ODETerm(self.xe_derivative_HeII)
t0 = starting_lna
t1 = jnp.inf
# don't want to double count the boundary lna, so start saving after one step
t_arr = jnp.linspace(t0+self.integration_spacing, t0+max_steps*self.integration_spacing, max_steps)
save_at = SaveAt(ts=t_arr)
adjoint=ForwardMode()
def He_check(t, y, args, **kwargs):
lna = t
xH1 = self.xH1_Saha(lna, omega_b, YHe)
# use xe = xHeII + (1.-xH1)
# y is current state, [0] flattens jnp.array
xHeII = y[0] - (1.-xH1)
return xHeII < 1e-4
sol = diffeqsolve(
term, solver, t0=t0, t1=t1, dt0=1e-3,
y0=initial_state,
args=(h, omega_b, omega_cdm, Neff, YHe),
stepsize_controller=PIDController(rtol, atol),saveat=save_at,
event=Event(He_check),
adjoint=adjoint
)
xe_output = sol.ys[:, 0]
lna_output = sol.ts
return xe_output, lna_output