Skip to content

Instantly share code, notes, and snippets.

@syrte
Last active May 22, 2023 12:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save syrte/14af7f57fb3226a30e0f9eb436b3af6d to your computer and use it in GitHub Desktop.
Save syrte/14af7f57fb3226a30e0f9eb436b3af6d to your computer and use it in GitHub Desktop.
# -*- coding:utf-8 -*-
from __future__ import division
from numpy import sqrt, exp, log, abs, diff, sin, cos, linspace, interp
from numpy import empty_like, atleast_1d, asarray
from numpy.testing import assert_almost_equal
from scipy.optimize import bisect as root
from numpy import e, pi
from scipy.integrate import romberg
def mid(x):
    '''Return the midpoints between consecutive elements of array `x`.'''
    upper, lower = x[1:], x[:-1]
    return (upper + lower) / 2.
__all__ = ['Cosmo']
class Cosmo:
    '''
    LCDM background quantities, linear growth and linear density statistics.

    Conventions used by the methods below: `a` is the scale factor
    (a=1 today), lengths are in Mpc/h, wavenumbers in h/Mpc, masses in
    Msol/h and ages in Gyr.
    '''

    def __init__(self, omegam0=0.2678, omegal0=None, omegab0=0.0444,
                 h0=0.71, sigma8=0.84, ns=0.93, Tcmb=2.73,
                 winfunc='tophat', transf='EH98', powspec=None,
                 romberg=True):
        '''
        Default values adopted from WMAP-1.

        omegal0: if None, a flat universe is assumed (1 - omegam0)
        sigma8: sigma for 8 Mpc/h
        winfunc: 'tophat', 'gauss' or 'sharpk'
        transf: 'BBKS86', 'EH98' or tuple (k, tk)
        powspec: None or tuple (k, pk)
            if None, powspec will be calculated according to `transf` and `ns`
            if not None, `transf` will be ignored.
        romberg: if True integrate sigma^2 with Romberg's method,
            otherwise use the fast midpoint sum.
        '''
        # A tabulated spectrum replaces the analytic `powspec` method on
        # this instance; the transfer function is still set up but unused.
        if powspec is not None:
            self._interp_powspec(*powspec)

        if transf == 'BBKS86':
            self.transf = self.transf_BBKS86
        elif transf == 'EH98':
            self.transf = self.transf_EH98
        else:
            self._interp_transf(*transf)

        assert winfunc in ['tophat', 'gauss', 'sharpk'], \
            "winfunc must be 'tophat', 'gauss' or 'sharpk'"
        # Bind the window-dependent helpers under generic names.
        self.winfunc = getattr(self, 'winfunc_%s' % winfunc)
        self.volume = getattr(self, 'volume_%s' % winfunc)
        self.R_M = getattr(self, 'R_M_%s' % winfunc)

        if romberg:
            self.sigma2_R = self.sigma2_R_romberg
        else:
            self.sigma2_R = self.sigma2_R_simple

        omegal0 = 1 - omegam0 if omegal0 is None else omegal0
        self.omegam0 = omegam0
        self.omegal0 = omegal0
        self.omegak0 = 1 - omegam0 - omegal0
        self.omegab0 = omegab0
        self.h0 = h0
        self.H0 = h0 * 100
        self.sigma8 = sigma8
        self.ns = ns
        self.Tcmb = Tcmb

        self.age0 = self.age(1)

        # Normalise the growth factor so that Dz(1) == 1.
        self._D0 = 1
        self._D0 = self.Dz(1)
        assert_almost_equal(self.Dz(1), 1)

        # Normalise the power spectrum so that sigma(8 Mpc/h) == sigma8.
        # The amplitude is calculated twice to make sure the integration
        # has converged.
        self._powspec_amp = 1e7
        self._powspec_amp *= sigma8**2 / self.sigma2_R(8)
        self._powspec_amp *= sigma8**2 / self.sigma2_R(8)
        assert_almost_equal(self.sigma2_R(8), sigma8**2)

    def omegam(self, a=1):
        '''Matter density parameter at scale factor `a`.'''
        return self.omegam0 / a**3 / self.Ez2(a)

    def omegal(self, a=1):
        '''Cosmological-constant density parameter at scale factor `a`.'''
        return self.omegal0 / self.Ez2(a)

    def redshift(self, a=1):
        '''Redshift z = 1/a - 1.'''
        return 1 / a - 1

    def Ez(self, a=1):
        '''Dimensionless expansion rate E = H(a)/H0.'''
        return sqrt(self.Ez2(a))

    def Ez2(self, a=1):
        '''E(a)**2 = omegal0 + omegak0/a**2 + omegam0/a**3.'''
        omegam0, omegal0, omegak0 = self.omegam0, self.omegal0, self.omegak0
        return omegal0 + omegak0 / a**2 + omegam0 / a**3

    def H(self, a=1):
        '''Hubble parameter at scale factor `a`, in physical unit (H0 = 100*h0)'''
        return self.H0 * self.Ez(a)

    def rho_cr(self, a=1):
        '''critical density, in unit (Msol/h)/(Mpc/h)**3'''
        H0 = 100          # h is absorbed into the units
        G = 43007.1e-13   # gravitational constant in matching units
        return 3 * H0**2 * self.Ez2(a) / (8 * pi * G)

    def rho_mean(self, a=1):
        '''mean density, in unit (Msol/h)/(Mpc/h)**3'''
        return self.rho_cr(a) * self.omegam(a)

    def delta_vir(self, a=1):
        '''Overdensity for virial object, cf. Bryan & Norman 1998'''
        x = self.omegam(a) - 1
        return 18.0 * pi**2 + 82.0 * x - 39.0 * x**2

    def rho_vir(self, a=1):
        '''virial density, in unit (Msol/h)/(Mpc/h)**3'''
        return self.delta_vir(a) * self.rho_cr(a)

    def gz(self, a):
        '''Growth suppression g, with D(z) ~ g(z)/(1 + z), Carroll et al. 1992'''
        omegam, omegal = self.omegam(a), self.omegal(a)
        return 2.5 * omegam / (omegam**(4 / 7.) - omegal +
                               (1 + omegam / 2.) * (1 + omegal / 70.))

    def Dz(self, a=1):
        '''linear growth factor D(z), normalised by `_D0`'''
        return self.gz(a) * a / self._D0

    def delta_col(self, a=1):
        '''linear overdensity for collapsed halo

        0.15 * (12*pi)**(2/3) * omegam(a)**p, with p = 0.0185 for a
        universe without cosmological constant and p = 0.0055 otherwise.
        '''
        omegam = self.omegam(a)
        # The original test `omegam0 == 1` selected the exponent only where
        # omegam(a) == 1 and the exponent is irrelevant; the exponent is
        # really decided by whether a cosmological constant is present.
        if self.omegal0 == 0:
            return 0.6 * (1.5 * pi)**(2 / 3.) * omegam**0.0185
        else:
            return 0.6 * (1.5 * pi)**(2 / 3.) * omegam**0.0055

    def age(self, a=1):
        '''Age of the universe at scale factor `a`, in Gyr.'''
        # from astropy import units as u
        # print(u.Mpc.to(u.km) * u.second.to(u.Gyr))  # 977.792
        unit = 977.7922216731284

        def integrand(x):
            # dt = da / (a * H(a))
            return 1 / self.H(x) / x

        t = romberg(integrand, 1e-30, a)
        return t * unit

    def lookback_time(self, a=1):
        '''Time elapsed between scale factor `a` and today, in Gyr.'''
        return self.age0 - self.age(a)

    def M_star(self, a=1):
        '''
        Characteristic mass defined by sigma(M*) = delta_c / D(t).
        M in Msol/h
        '''
        a_list = atleast_1d(a)
        m_list = empty_like(a_list, 'd')
        for i, s in enumerate(a_list):
            delta = self.delta_col(s) / self.Dz(s)

            def func(logm):
                return self.sigma_M(exp(logm), 1) - delta

            # bisection in log(M) over a very wide bracket
            m_list[i] = exp(root(func, -30, 40))
        return m_list.reshape(asarray(a).shape)

    def sigma_M(self, M, a=1):
        '''M in Msol/h'''
        return sqrt(self.sigma2_M(M, a))

    def sigma2_M(self, M, a=1):
        '''M in Msol/h'''
        R = self.R_M(M, a)
        return self.sigma2_R(R, a)

    def powspec(self, k, a=1):
        '''k in h/Mpc'''
        return self._powspec_amp * self.Dz(a)**2 * k**self.ns * self.transf(k)**2

    def sigma_R(self, R, a=1):
        '''R in Mpc/h'''
        return sqrt(self.sigma2_R(R, a))

    def _sigma2_R_dlogk(self, logk, R, a=1):
        '''Integrand of sigma^2(R) per unit log(k), without the 1/(2 pi^2).'''
        k = exp(logk)
        return self.powspec(k, a) * self.winfunc(k, R)**2 * k**3

    def sigma2_R_romberg(self, R, a=1):
        '''R in Mpc/h'''
        R_list = atleast_1d(R)
        sigma2_list = empty_like(R_list, 'd')
        for i, r in enumerate(R_list):
            sigma2_list[i] = romberg(self._sigma2_R_dlogk, -40., 50., (r, a)) / (2 * pi**2)
        return sigma2_list.reshape(asarray(R).shape)

    def sigma2_R_simple(self, R, a=1):
        '''R in Mpc/h, simple but fast integration'''
        logk = linspace(-40, 50, 9001)
        dlogk = diff(logk).reshape(-1, 1)
        # midpoint rule on a column vector so that R broadcasts along axis 1
        k = exp(mid(logk)).reshape(-1, 1)
        integrand = self.powspec(k, a) * self.winfunc(k, R)**2 * k**3 * dlogk
        return integrand.sum(0) / (2 * pi**2)

    def transf_BBKS86(self, k):
        '''k in h/Mpc'''
        return Tk_BBKS86(k, h=self.h0, omegam0=self.omegam0, omegab0=self.omegab0)

    def transf_EH98(self, k):
        '''k in h/Mpc'''
        return Tk_EH98(k, h=self.h0, omegam0=self.omegam0, omegab0=self.omegab0, Tcmb=self.Tcmb)

    def winfunc_tophat(self, k, R):
        '''Fourier transform of the real-space top-hat window, W(kR).

        The closed form 3 (sin x - x cos x) / x**3 loses all precision for
        small x (it evaluates to 0 or NaN instead of 1), so a Taylor
        series is used for |x| < 1e-2.
        '''
        from numpy import where

        x = asarray(k * R, dtype=float)
        small = abs(x) < 1e-2
        # Placeholder value keeps the closed form finite where it is unused.
        xs = where(small, 1.0, x)
        exact = 3 * (sin(xs) - xs * cos(xs)) / xs**3.
        series = 1 - x**2 / 10. + x**4 / 280.
        # [()] turns a 0-d result back into a scalar, leaves arrays alone.
        return where(small, series, exact)[()]

    def winfunc_gauss(self, k, R):
        '''Fourier transform of the Gaussian window, exp(-(kR)**2 / 2).'''
        x = k * R
        return exp(-x**2 / 2.)

    def winfunc_sharpk(self, k, R):
        '''Sharp-k window: 1 where |kR| <= 1, otherwise 0.'''
        return (abs(k * R) <= 1).astype('f')

    def volume_tophat(self, R):
        '''Volume associated with a top-hat window of radius R.'''
        return 4 * pi / 3. * R**3

    def volume_gauss(self, R):
        '''Volume associated with a Gaussian window of radius R.'''
        return (2 * pi)**1.5 * R**3

    def volume_sharpk(self, R):
        '''Volume associated with a sharp-k window of radius R.'''
        return 6 * pi**2 * R**3

    def R_M_tophat(self, M, a):
        '''Top-hat radius enclosing mass M at the mean density rho_mean(a).'''
        return (M / (4 * pi / 3 * self.rho_mean(a)))**(1 / 3.)

    def R_M_gauss(self, M, a):
        '''Gaussian-window radius for mass M at the mean density rho_mean(a).'''
        return (M / ((2 * pi)**1.5 * self.rho_mean(a)))**(1 / 3.)

    def R_M_sharpk(self, M, a):
        '''Sharp-k radius for mass M at the mean density rho_mean(a).'''
        return (M / (6 * pi**2 * self.rho_mean(a)))**(1 / 3.)

    def _interp_powspec(self, k, pk):
        '''Replace `self.powspec` by interpolation of the table (k, pk).

        The table is interpolated linearly in log(k) and is zero outside
        the tabulated range. The result is scaled by `_powspec_amp` and,
        like the analytic spectrum, by Dz(a)**2.
        '''
        logk = log(k)

        # A closure, not a function taking `self`: it is stored on the
        # instance and therefore never bound as a method.
        def func(k, a=1):
            table = interp(log(k), logk, pk, left=0, right=0)
            return table * self._powspec_amp * self.Dz(a)**2

        self.powspec = func

    def _interp_transf(self, k, tk):
        '''Replace `self.transf` by interpolation of the table (k, tk).

        Linear interpolation in log(k); 1 below the tabulated range and 0
        above it.
        '''
        logk = log(k)

        # A closure, not a function taking `self` (see _interp_powspec).
        def func(k):
            return interp(log(k), logk, tk, left=1, right=0)

        self.transf = func

    def mass_func(self, M, a=1):
        '''Press-Schechter mass function dn/dM.

        M in Msol/h; result in (Mpc/h)**-3 per (Msol/h), using
        rho_mean(a) as the density.
        '''
        sigma = self.sigma_M(M, a)
        delta = self.delta_col(a)

        # centred finite difference for -dln(sigma)/dM
        dM = M * 1e-3
        dlogs_dM = - (log(self.sigma_M(M + dM / 2, a))
                      - log(self.sigma_M(M - dM / 2, a))) / dM

        nu = delta / sigma
        dF_dlogs = 1 / sqrt(2 * pi) * nu * exp(-nu**2 / 2) * 2

        # rho_mean and M are already in Msol/h and Mpc/h, so no extra
        # unit conversion factor is needed.
        n = self.rho_mean(a) / M * dF_dlogs * dlogs_dM
        return n
def Tk_BBKS86(k, h=0.71, omegam0=0.2678, omegab0=0.0444, gamma=None):
    '''BBKS (1986) transfer function.

    k: wavenumber in h/Mpc, scalar or array
    gamma: shape parameter; if None it is computed from `omegam0`, `h`
        and `omegab0` as omegam0*h*exp(-omegab0*(1 + sqrt(2h)/omegam0)).

    Returns T(k), which tends to 1 as k -> 0 (and equals 1 at k == 0).
    '''
    from numpy import log1p, where

    if gamma is not None:
        gamma = gamma * 1.
    else:
        gamma = omegam0 * h * exp(-omegab0 * (1 + sqrt(2 * h) / omegam0))

    q = asarray(k, dtype=float) / gamma
    x = 2.34 * q

    # log(1 + x)/x rounds to 0 once x drops below machine epsilon and is
    # 0/0 at x == 0; log1p keeps full precision and the limit is 1.
    is_zero = x == 0
    safe_x = where(is_zero, 1.0, x)
    log_term = where(is_zero, 1.0, log1p(safe_x) / safe_x)

    poly = (1 + 3.89 * q + (16.1 * q)**2 + (5.46 * q)**3 + (6.71 * q)**4)**-0.25
    return log_term * poly
def Tk_EH98(k, h=0.71, omegam0=0.2678, omegab0=0.0444, Tcmb=2.73):
    '''Eisenstein & Hu (1998) fitting formula for the matter transfer function.

    k: wavenumber in h/Mpc (converted to 1/Mpc internally via k*h)
    h, omegam0, omegab0: Hubble parameter and density parameters
    Tcmb: CMB temperature in K

    Returns the baryon-fraction weighted sum of the baryon and CDM
    transfer functions.
    '''
    k = k * h
    f_b = omegab0 / omegam0
    omhh = omegam0 * h**2
    obhh = omhh * f_b
    theta = Tcmb / 2.7

    # ---- scales: equality, drag epoch, sound horizon, Silk damping ----
    z_eq = 2.50e4 * omhh * theta**(-4) - 1.
    k_eq = 0.0746 * omhh * theta**(-2)

    b1 = 0.313 * omhh**(-0.419) * (1 + 0.607 * omhh**0.674)
    b2 = 0.238 * omhh**0.223
    z_d = 1291 * omhh**(0.251) / (1 + 0.659 * omhh**0.828) * (1 + b1 * obhh**b2)

    r_scale = 31.5 * obhh * theta**(-4.) * 1000
    R_d = r_scale / (1 + z_d)
    R_eq = r_scale / (1 + z_eq)

    s = 2. / 3. / k_eq * sqrt(6. / R_eq) * \
        log((sqrt(1 + R_d) + sqrt(R_d + R_eq)) / (1 + sqrt(R_eq)))
    k_silk = 1.6 * obhh**(0.52) * omhh**(0.73) * (1e0 + (10.4 * omhh)**(-0.95))

    # ---- CDM fitting parameters ----
    alpha_c = ((46.9 * omhh)**0.670 * (1 + (32.1 * omhh)**(-0.532))) ** (-f_b) * \
        ((12.0 * omhh)**0.424 * (1 + (45.0 * omhh)**(-0.582))) ** (-f_b**3.)
    bc = 0.944 / (1 + (458. * omhh)**(-0.708))
    bc = bc * ((1 - f_b)**((0.395 * omhh)**(-0.0266)) - 1)
    beta_c = 1 / (1 + bc)

    # ---- baryon fitting parameters ----
    y = (1 + z_eq) / (1 + z_d)
    g_y = y * (-6 * sqrt(1 + y) +
               (2 + 3 * y) * log((sqrt(1 + y) + 1) / (sqrt(1 + y) - 1)))
    alpha_b = 2.07 * k_eq * s * (1 + R_d)**(-0.75) * g_y
    beta_b = 0.5 + f_b + (3 - 2 * f_b) * sqrt((17.2 * omhh)**2 + 1)
    beta_node = 8.41 * omhh**0.435

    # ---- transfer functions ----
    q = k / 13.41 / k_eq
    ks = k * s

    # CDM part: interpolate between the alpha_c = 1 and alpha_c forms
    weight = 1 / (1 + (ks / 5.4)**4)
    t_cdm = weight * _tf_pressureless(q, 1., beta_c) + \
        (1 - weight) * _tf_pressureless(q, alpha_c, beta_c)

    # baryon part: oscillating term with a shifted sound horizon
    s_tilde = s / (1. + (beta_node / ks)**3.)**(1 / 3.)
    t_b = _tf_pressureless(q, 1., 1.) / (1. + (ks / 5.2)**2.)
    t_b = t_b + alpha_b / (1. + (beta_b / ks)**3) * exp(-(k / k_silk)**(1.4))
    t_b = t_b * (sin(k * s_tilde) / (k * s_tilde))

    return f_b * t_b + (1 - f_b) * t_cdm
def _tf_pressureless(q, a, b):
    '''Return L / (L + C * q**2) with L = log(e + 1.8*b*q) and
    C = 14.2/a + 386/(1 + 69.9*q**1.08).'''
    log_term = log(e + 1.8 * b * q)
    coeff = 14.2 / a + 386 / (1 + 69.9 * q**1.08)
    return log_term / (log_term + coeff * q**2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment