Coverage for src / galsbi / ucat_sps / galaxy_population_models / galaxy_stellar_mass_function.py: 100%
165 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-12 16:31 +0000
1# Copyright (C) 2025 LMU Munich
2# Copyright (C) 2025 ETH Zurich, Institute for Particle Physics and Astrophysics
3# Author: Luca Tortorelli, Silvan Fischbacher
4# created: Apr 2025
6from collections import OrderedDict
8import numpy as np
9import scipy.integrate
10import scipy.optimize
11import scipy.special
12from cosmic_toolbox import logger
# Module-level logger obtained from cosmic_toolbox, named after this file's path.
LOGGER = logger.get_logger(__file__)
def find_closest_ind(grid, vals):
    """
    Find, for each value, the index of the nearest element of an ascending grid.

    :param grid: (array_like[n,]) grid values, sorted in ascending order.
    :param vals: (array_like or float) values to locate on the grid.
    :return ind: (ndarray of int, or a single int for scalar ``vals``) indices into
        ``grid`` of the closest grid points. When a value lies exactly halfway
        between two grid points, the upper index is returned.
    """
    grid = np.asarray(grid)
    vals_arr = np.asarray(vals)
    scalar_input = vals_arr.ndim == 0
    vals_arr = np.atleast_1d(vals_arr)

    # Insertion points; values beyond the last grid point map to the last index.
    ind = np.minimum(np.searchsorted(grid, vals_arr), grid.size - 1)
    # Left neighbour, clipped at 0 instead of relying on the grid[-1] wrap-around.
    left = np.maximum(ind - 1, 0)
    closer_left = np.fabs(grid[ind] - vals_arr) > np.fabs(grid[left] - vals_arr)
    ind = np.where(closer_left, left, ind)

    return ind[0] if scalar_input else ind
def single_stellar_mass_function_parametrisation(
    log10_stellar_masses, log_Mstar, phi_star, alpha
):
    """
    Evaluate a single Schechter stellar mass function per dex of stellar mass.

    With r = 10**(log10 M - log_Mstar), the returned value is
    ln(10) * exp(-r) * phi_star * r**(alpha + 1).

    :param log10_stellar_masses: (array_like[n_mass_bins,]) log10 of the stellar
        masses at which the function is evaluated.
    :param log_Mstar: (float) log10 of the characteristic mass M*. Below it the
        function behaves as a power law with slope alpha; above it the exponential
        cut-off dominates.
    :param phi_star: (float) normalisation (number density at M*).
    :param alpha: (float) faint-end slope.
    :return stellar_mass_function: (array_like[n_mass_bins,]) galaxy number density
        as a function of stellar mass.
    """
    mass_ratio = 10 ** (log10_stellar_masses - log_Mstar)
    cutoff = np.exp(-mass_ratio)
    return np.log(10) * cutoff * phi_star * mass_ratio ** (alpha + 1)
def double_stellar_mass_function_parametrisation(
    log10_stellar_masses,
    log_Mstar,
    phi_star_low_mass,
    phi_star_high_mass,
    alpha_low_mass,
    alpha_high_mass,
):
    """
    Evaluate a double Schechter stellar mass function per dex of stellar mass.

    Both components share the same characteristic mass. With
    r = 10**(log10 M - log_Mstar) the result is
    ln(10) * exp(-r) * (phi_low * r**(alpha_low + 1) + phi_high * r**(alpha_high + 1)).

    :param log10_stellar_masses: (array_like[n_mass_bins,]) log10 of the stellar
        masses at which the function is evaluated.
    :param log_Mstar: (float) log10 of the characteristic mass M*, common to both
        components.
    :param phi_star_low_mass: (float) normalisation of the low-mass component.
    :param phi_star_high_mass: (float) normalisation of the high-mass component.
    :param alpha_low_mass: (float) faint-end slope of the low-mass component.
    :param alpha_high_mass: (float) faint-end slope of the high-mass component.
    :return stellar_mass_function: (array_like[n_mass_bins,]) galaxy number density
        as a function of stellar mass.
    """
    mass_ratio = 10 ** (log10_stellar_masses - log_Mstar)
    low_mass_term = phi_star_low_mass * mass_ratio ** (alpha_low_mass + 1)
    high_mass_term = phi_star_high_mass * mass_ratio ** (alpha_high_mass + 1)
    return np.log(10) * np.exp(-mass_ratio) * (low_mass_term + high_mass_term)
def stellar_mass_function_m_star(
    z, parametrization, intercept, slope, quadratic_term=0
):
    """
    Redshift evolution of the characteristic stellar mass log10 M*(z).

    :param z: (float or array_like) redshift(s).
    :param parametrization: (str) one of
        ``"linexp"`` (intercept + slope * z),
        ``"logpower"`` (intercept + slope * log10(1 + z)) or
        ``"logquadratic_power"`` (logpower plus quadratic_term * log10(1 + z)**2).
    :param intercept: (float) value at z = 0.
    :param slope: (float) linear coefficient.
    :param quadratic_term: (float) quadratic coefficient; only used by
        ``"logquadratic_power"``.
    :return m_s: (float or array_like) log10 M* at the given redshift(s).
    :raises ValueError: for an unknown parametrization.
    """
    if parametrization == "linexp":
        return np.polyval((slope, intercept), z)
    if parametrization not in ("logpower", "logquadratic_power"):
        raise ValueError("unknown parametrization value")

    log_one_plus_z = np.log10(1 + z)
    m_s = intercept + slope * log_one_plus_z
    if parametrization == "logquadratic_power":
        m_s = m_s + quadratic_term * log_one_plus_z**2
    return m_s
def stellar_mass_function_phi_star(z, parametrization, amplitude, exp):
    """
    Redshift evolution of the Schechter normalisation phi*(z).

    :param z: (float or array_like) redshift(s).
    :param parametrization: (str) ``"linexp"`` gives amplitude * exp(exp * z);
        ``"logpower"`` and ``"logquadratic_power"`` give amplitude * (1 + z)**exp.
    :param amplitude: (float) normalisation at z = 0.
    :param exp: (float) evolution exponent.
    :return p_star: (float or array_like) phi* at the given redshift(s).
    :raises ValueError: for an unknown parametrization.
    """
    if parametrization == "linexp":
        return amplitude * np.exp(exp * z)
    if parametrization in ("logpower", "logquadratic_power"):
        return amplitude * (1 + z) ** exp
    raise ValueError("unknown parametrization value")
def upper_inc_gamma(a, x):
    """
    Non-regularised upper incomplete gamma function Gamma(a, x), including a <= 0.

    - a > 0: gamma(a) * gammaincc(a, x) (scipy's gammaincc is regularised).
    - a == 0: -Ei(-x), i.e. the exponential integral E1(x) for x > 0.
    - otherwise: the recurrence Gamma(a, x) = (Gamma(a + 1, x) - x**a * exp(-x)) / a,
      applied until the shape parameter becomes non-negative.

    :param a: (float) shape parameter; must be a scalar because it is used in
        ``if`` comparisons.
    :param x: (float or array_like) lower integration limit.
    :return uig: (float or array_like) Gamma(a, x).
    """
    if a == 0:
        return -scipy.special.expi(-x)
    if a > 0:
        return scipy.special.gamma(a) * scipy.special.gammaincc(a, x)
    # Negative shape parameter: step towards a >= 0 via the recurrence relation.
    return 1 / a * (upper_inc_gamma(a + 1, x) - x**a * np.exp(-x))
def initialize_stellar_mass_functions(
    par, pixarea, cosmo, z_sm_intp=None, mass_key=None
):
    """
    Build one StellarMassFunction per galaxy type listed in ``par.galaxy_types``.

    For a galaxy type ``g``, these per-type parameters are read from ``par``:
    ``sm_fct_m_star_{g}_{0,1,2}_{mass_key}``, ``sm_fct_phi_star_{g}_{amp,exp}_{mass_key}``
    and ``sm_fct_alpha_{g}_{mass_key}``. The grid settings ``sm_fct_z_res``,
    ``sm_fct_sm_res``, ``sm_fct_z_max`` and ``sm_fct_sm_min`` are shared by all types,
    as are ``sm_fct_parametrization``, ``ngal_multiplier`` and ``seed_ngal``.

    :param par: parameter container exposing the attributes listed above.
    :param pixarea: pixel area, passed to every StellarMassFunction.
    :param cosmo: cosmology object, passed to every StellarMassFunction.
    :param z_sm_intp: optional redshift -> stellar-mass-limit interpolator, passed
        to every StellarMassFunction.
    :param mass_key: suffix selecting which set of stellar mass function parameters
        to read from ``par``. Required.
    :return sm_funcs: (OrderedDict) galaxy type name -> StellarMassFunction, in the
        order of ``par.galaxy_types``.
    :raises KeyError: if ``mass_key`` is None.
    """
    if mass_key is None:
        raise KeyError(
            "mass_key must be given to select the stellar mass function parameters"
        )
    kw_smfun = dict(
        pixarea=pixarea,
        z_res=par.sm_fct_z_res,
        sm_res=par.sm_fct_sm_res,
        z_max=par.sm_fct_z_max,
        sm_min=par.sm_fct_sm_min,
        z_sm_intp=z_sm_intp,
        ngal_multiplier=par.ngal_multiplier,
        cosmo=cosmo,
    )
    if par.ngal_multiplier != 1:
        LOGGER.warning(f"ngal_multiplier is set to {par.ngal_multiplier}")

    sm_funcs = OrderedDict()
    # The enumeration index becomes the integer galaxy_type of each population.
    for i, g in enumerate(par.galaxy_types):
        sm_funcs[g] = StellarMassFunction(
            sm_fct_parametrization=par.sm_fct_parametrization,
            sm_m_star_0=getattr(par, f"sm_fct_m_star_{g}_0_{mass_key}"),
            sm_m_star_1=getattr(par, f"sm_fct_m_star_{g}_1_{mass_key}"),
            sm_m_star_2=getattr(par, f"sm_fct_m_star_{g}_2_{mass_key}"),
            sm_phi_star_amp=getattr(par, f"sm_fct_phi_star_{g}_amp_{mass_key}"),
            sm_phi_star_exp=getattr(par, f"sm_fct_phi_star_{g}_exp_{mass_key}"),
            sm_alpha=getattr(par, f"sm_fct_alpha_{g}_{mass_key}"),
            name=g,
            galaxy_type=i,
            seed_ngal=par.seed_ngal,
            **kw_smfun,
        )
    return sm_funcs
def maximum_redshift(
    z_sm_intp, sm_min, z_max, parametrization, alpha, sm_m_star_par, seed_ngal
):
    """
    Compute the maximum redshift up to which objects are sampled from the stellar
    mass function.

    Without an interpolator (``z_sm_intp is None``) the input ``z_max`` is returned
    unchanged. Otherwise the cutoff is the redshift at which the ratio
    Gamma(alpha + 1, M_lim(z) / M*) / Gamma(alpha + 1, M_min / M*) drops to 1e-5.
    Here M_lim(z) = 10**z_sm_intp(z) and M_min = 10**sm_min. This ratio is the part
    of the stellar mass function above the limit relative to the part above sm_min.
    If no root is found on [0, z_sm_intp.x[-1]], the last interpolation node is used.
    The result is min(z_max, cutoff) plus a uniform offset in [0, 1e-4). The offset
    is drawn after reseeding the global numpy RNG with ``seed_ngal``.

    :param z_sm_intp: callable z -> log10 stellar mass limit, or None. It must expose
        its redshift nodes as ``.x`` (e.g. a scipy interp1d) — assumed from the
        attribute access below.
    :param sm_min: (float) minimum log10 stellar mass that is sampled.
    :param z_max: (float) upper redshift bound requested by the configuration.
    :param parametrization: (str) redshift parametrisation of log10 M*.
    :param alpha: (float) faint-end slope of the Schechter function.
    :param sm_m_star_par: (tuple) parameters of the log10 M*(z) parametrisation.
    :param seed_ngal: (int) seed for the global numpy RNG.
    :return z_max: (float) maximum redshift to sample.
    """
    if z_sm_intp is None:
        return z_max

    def _sm_tail_fraction_excess(z):
        sm_m_s = stellar_mass_function_m_star(z, parametrization, *sm_m_star_par)
        # Schechter arguments are mass ratios M / M* = 10**(log10 M - log10 M*).
        # NumGalCalculator and RedshiftSMSampler use the same variable. The masses are
        # log10 quantities, so no 0.4 (magnitude) factor belongs here.
        tail_at_limit = upper_inc_gamma(alpha + 1, 10 ** (z_sm_intp(z) - sm_m_s))
        tail_at_min = upper_inc_gamma(alpha + 1, 10 ** (sm_min - sm_m_s))
        return tail_at_limit / tail_at_min - 1e-5

    try:
        z_max_cutoff = scipy.optimize.brentq(
            _sm_tail_fraction_excess, 0, z_sm_intp.x[-1]
        )
    except ValueError:
        # brentq raises ValueError when there is no sign change on the interval.
        z_max_cutoff = z_sm_intp.x[-1]

    z_max = min(z_max, z_max_cutoff)
    np.random.seed(seed_ngal)
    z_max += np.random.uniform(0, 0.0001)

    return z_max
class NumGalCalculator:
    """
    Expected and Poisson-drawn galaxy counts for one stellar mass function.

    The stellar-mass integral of the Schechter function has a closed form in terms
    of the upper incomplete gamma function. The remaining redshift integral of
    phi*(z) * comoving volume element * Gamma(alpha + 1, 10**(sm_min - log10 M*(z)))
    is evaluated numerically with scipy.integrate.quad over [0, z_max]. See also
    docs/jupyter_notebooks/sample_redshift_magnitude.ipynb.
    """

    def __init__(
        self,
        z_max,
        sm_min,
        parametrization,
        sm_alpha,
        sm_m_star_par,
        sm_phi_star_par,
        cosmo,
        pixarea,
        ngal_multiplier=1,
    ):
        density_args = (
            sm_min,
            parametrization,
            sm_alpha,
            sm_m_star_par,
            sm_phi_star_par,
            cosmo,
        )
        integral = scipy.integrate.quad(
            self._redshift_density, 0, z_max, args=density_args
        )[0]
        # Mean count in the pixel, rounded to an integer before any Poisson draw.
        self.n_gal_mean = int(round(integral * pixarea * ngal_multiplier))

    def __call__(self):
        """Return a Poisson draw around n_gal_mean, or n_gal_mean itself if <= 0."""
        if self.n_gal_mean <= 0:
            return self.n_gal_mean
        return np.random.poisson(self.n_gal_mean)

    def _redshift_density(
        self,
        z,
        sm_min,
        parametrization,
        sm_alpha,
        par_sm_m_star,
        par_sm_phi_star,
        cosmo,
    ):
        """Integrand dN/dz per unit solid angle above sm_min at redshift z."""
        m_star = stellar_mass_function_m_star(z, parametrization, *par_sm_m_star)
        phi_star = stellar_mass_function_phi_star(
            z, parametrization, *par_sm_phi_star
        )
        # E(z) built from omega_m and omega_l only.
        hubble_e = np.sqrt(
            cosmo.params.omega_m * (1 + z) ** 3 + cosmo.params.omega_l
        )
        hubble_distance = cosmo.params.c / cosmo.params.H0
        transverse_distance = cosmo.background.dist_trans_a(a=1 / (1 + z))
        schechter_tail = upper_inc_gamma(sm_alpha + 1, 10 ** (sm_min - m_star))
        return (
            phi_star
            * hubble_distance
            * transverse_distance**2
            / hubble_e
            * schechter_tail
        )
class RedshiftSMSampler:
    """
    Joint sampler of redshifts and log10 stellar masses from a Schechter SMF.

    Redshifts are drawn from a discrete grid weighted by the stellar-mass-integrated
    density. For each drawn redshift, log10 M - log10 M* is then drawn by inverse
    transform sampling. A uniform variate in [0, Gamma(alpha + 1, 10**(sm_min - M*)))
    is matched to the closest entry of a precomputed incomplete-gamma table. See also
    docs/jupyter_notebooks/sample_redshift_magnitude.ipynb and
    docs/jupyter_notebooks/test_self_consistency.ipynb.
    """

    def __init__(
        self,
        z_res,
        z_max,
        sm_res,
        sm_min,
        parametrization,
        sm_alpha,
        sm_m_star_par,
        sm_phi_star_par,
        cosmo,
    ):
        # Configuration kept on the instance for use in __call__.
        self.z_res, self.z_max = z_res, z_max
        self.sm_res, self.sm_min = sm_res, sm_min
        self.parametrization = parametrization
        self.sm_alpha = sm_alpha
        self.sm_m_star_par = sm_m_star_par
        self.sm_phi_star_par = sm_phi_star_par
        self.cosmo = cosmo

        self._setup_redshift_grid(
            z_max, z_res, sm_min, sm_alpha, sm_m_star_par, sm_phi_star_par, cosmo
        )
        self._setup_mag_grid(
            z_max, sm_min, sm_res, parametrization, sm_alpha, sm_m_star_par
        )

    def __call__(self, n_samples):
        """Draw ``n_samples`` (redshift, log10 stellar mass) pairs."""
        z = np.random.choice(self.z_grid, size=n_samples, replace=True, p=self.nz_grid)

        m_star = stellar_mass_function_m_star(
            z, self.parametrization, *self.sm_m_star_par
        )
        # Upper bound of the incomplete-gamma variate for each redshift.
        tail_max = upper_inc_gamma(self.sm_alpha + 1, 10 ** (self.sm_min - m_star))
        uniform_tail = np.random.uniform(low=0, high=tail_max, size=n_samples)
        # Table lookup gives M - M*; adding the redshift-dependent M* yields M.
        offsets = self.m_s__m__grid[find_closest_ind(self.uig_grid, uniform_tail)]
        return z, m_star + offsets

    def _setup_redshift_grid(
        self, z_max, z_res, sm_min, sm_alpha, sm_m_star_par, sm_phi_star_par, cosmo
    ):
        """Build ``z_grid`` on [z_res, z_max] and its normalised weights ``nz_grid``."""
        n_z = int(round((z_max - z_res) / z_res)) + 1
        z_grid = np.linspace(z_res, z_max, num=n_z)
        self.z_grid = z_grid

        hubble_e = np.sqrt(
            cosmo.params.omega_m * (1 + z_grid) ** 3 + cosmo.params.omega_l
        )
        hubble_distance = cosmo.params.c / cosmo.params.H0
        transverse_distance = cosmo.background.dist_trans_a(a=1 / (1 + z_grid))
        volume = hubble_distance * transverse_distance**2 / hubble_e

        m_star = stellar_mass_function_m_star(
            z_grid, self.parametrization, *sm_m_star_par
        )
        phi_star = stellar_mass_function_phi_star(
            z_grid, self.parametrization, *sm_phi_star_par
        )

        weights = (
            volume
            * phi_star
            * upper_inc_gamma(sm_alpha + 1, 10 ** (sm_min - m_star))
        )
        self.nz_grid = weights / np.sum(weights)

    def _setup_mag_grid(
        self, z_max, sm_min, sm_res, parametrization, sm_alpha, sm_m_star_par
    ):
        """
        Build the table of M - M* values (``m_s__m__grid``, descending) and the
        matching incomplete-gamma values (``uig_grid``). Despite the name, this is
        a stellar-mass grid, not a magnitude grid.
        """

        def m_star_at(z):
            return stellar_mass_function_m_star(z, parametrization, *sm_m_star_par)

        search_bounds = (0, z_max)
        sm_m_s_max = -scipy.optimize.minimize_scalar(
            lambda z: -m_star_at(z), bounds=search_bounds, method="bounded"
        ).fun
        sm_m_s_min = scipy.optimize.minimize_scalar(
            m_star_at, bounds=search_bounds, method="bounded"
        ).fun

        # Raise the upper mass bound in 0.1 dex steps until the Schechter tail
        # evaluates to zero at the smallest M*.
        sm_max = sm_min
        while upper_inc_gamma(sm_alpha + 1, 10 ** (sm_max - sm_m_s_min)) > 0:
            sm_max += 0.1

        n_sm = int(round((sm_max - sm_m_s_min - sm_min + sm_m_s_max) / sm_res)) + 1
        self.m_s__m__grid = np.linspace(
            sm_max - sm_m_s_min, sm_min - sm_m_s_max, num=n_sm
        )
        self.uig_grid = upper_inc_gamma(sm_alpha + 1, 10**self.m_s__m__grid)
class StellarMassFunction:
    """
    Stellar mass function of a single galaxy population.

    On construction it determines the maximum redshift to sample
    (``maximum_redshift``), the expected galaxy count in the pixel
    (``NumGalCalculator``) and a joint redshift / stellar-mass sampler
    (``RedshiftSMSampler``).
    """

    def __init__(
        self,
        name,
        sm_fct_parametrization,
        sm_m_star_0,
        sm_m_star_1,
        sm_m_star_2,
        sm_phi_star_amp,
        sm_phi_star_exp,
        sm_alpha,
        cosmo,
        pixarea,
        galaxy_type,
        seed_ngal,
        z_res=0.001,
        sm_res=0.01,
        z_max=3,
        sm_min=8,
        z_sm_intp=None,
        ngal_multiplier=1,
    ):
        """
        :param name: population label, stored as ``self.name``.
        :param sm_fct_parametrization: (str) redshift parametrisation of M* and phi*.
        :param sm_m_star_0: intercept of log10 M*(z).
        :param sm_m_star_1: slope of log10 M*(z).
        :param sm_m_star_2: quadratic term of log10 M*(z).
        :param sm_phi_star_amp: amplitude of phi*(z).
        :param sm_phi_star_exp: exponent of phi*(z).
        :param sm_alpha: faint-end slope.
        :param cosmo: cosmology object used for the comoving volume element.
        :param pixarea: pixel area multiplying the integrated number density.
        :param galaxy_type: integer identifier stored on the instance.
        :param seed_ngal: seed used by ``maximum_redshift``.
        :param z_res: redshift grid resolution.
        :param sm_res: log10 stellar mass grid resolution.
        :param z_max: maximum redshift before any interpolator-based cutoff.
        :param sm_min: minimum log10 stellar mass.
        :param z_sm_intp: optional redshift -> log10 stellar mass limit interpolator.
        :param ngal_multiplier: factor applied to the expected galaxy count.
        """
        self.parametrization = sm_fct_parametrization
        self.sm_m_star_0 = sm_m_star_0
        self.sm_m_star_1 = sm_m_star_1
        self.sm_m_star_2 = sm_m_star_2
        self.sm_phi_star_amp = sm_phi_star_amp
        self.sm_phi_star_exp = sm_phi_star_exp
        self.sm_m_star_par = sm_m_star_0, sm_m_star_1, sm_m_star_2
        self.sm_phi_star_par = sm_phi_star_amp, sm_phi_star_exp
        self.sm_alpha = sm_alpha
        self.z_sm_intp = z_sm_intp
        self.sm_min = sm_min
        self.cosmo = cosmo
        self.pixarea = pixarea
        self.z_res = z_res
        self.name = name
        self.galaxy_type = galaxy_type

        self.z_max = maximum_redshift(
            z_sm_intp=z_sm_intp,
            sm_min=self.sm_min,
            z_max=z_max,
            parametrization=self.parametrization,
            alpha=self.sm_alpha,
            sm_m_star_par=self.sm_m_star_par,
            seed_ngal=seed_ngal,
        )

        self.n_gal_calc = NumGalCalculator(
            z_max=self.z_max,
            sm_min=sm_min,
            parametrization=sm_fct_parametrization,
            sm_alpha=sm_alpha,
            sm_m_star_par=self.sm_m_star_par,
            sm_phi_star_par=self.sm_phi_star_par,
            cosmo=cosmo,
            pixarea=pixarea,
            ngal_multiplier=ngal_multiplier,
        )

        self.z_sm_sampler = RedshiftSMSampler(
            z_res=z_res,
            z_max=self.z_max,
            parametrization=sm_fct_parametrization,
            sm_res=sm_res,
            sm_min=sm_min,
            sm_alpha=sm_alpha,
            sm_m_star_par=self.sm_m_star_par,
            sm_phi_star_par=self.sm_phi_star_par,
            cosmo=cosmo,
        )

    def sample_z_sm_and_apply_cut(self, seed_ngal, n_gal_max=np.inf, size_chunk=10000):
        """
        Draw stellar masses and redshifts in chunks and apply the optional z-sm cut.

        Chunking limits peak memory. The global numpy RNG is reseeded with
        ``seed_ngal``, a galaxy count is drawn, and that many objects are sampled
        in chunks of at most ``size_chunk``. If ``z_sm_intp`` is set, only objects
        with ``sm < z_sm_intp(z)`` are kept. Sampling stops early once more than
        ``n_gal_max`` objects have been kept.

        :param seed_ngal: (int) seed for the global numpy RNG.
        :param n_gal_max: (int or float) early-stopping threshold on kept objects.
        :param size_chunk: (int) maximum number of objects drawn per chunk.
        :return sm, z: (ndarray, ndarray) log10 stellar masses and redshifts.
        """
        np.random.seed(seed_ngal)
        n_gal = self.n_gal_calc()
        if n_gal <= 0:
            return np.array([]), np.array([])
        n_chunks = int(np.ceil(float(n_gal) / float(size_chunk)))
        list_sm = []
        list_z = []
        n_gal_final = 0
        for ic in range(n_chunks):
            # The last chunk holds the remainder. Using n_gal % size_chunk would draw
            # zero objects whenever n_gal is an exact multiple of size_chunk.
            n_gal_sample = min(size_chunk, n_gal - ic * size_chunk)
            z_chunk, sm_chunk = self.z_sm_sampler(n_gal_sample)

            # Apply cut in z - M plane.
            # NOTE(review): this keeps masses *below* z_sm_intp(z), whereas
            # maximum_redshift measures the part of the mass function *above*
            # z_sm_intp(z). Confirm which direction is intended.
            if self.z_sm_intp is not None:
                select = sm_chunk < self.z_sm_intp(z_chunk)
                z_chunk = z_chunk[select]
                sm_chunk = sm_chunk[select]

            n_gal_final += len(z_chunk)
            list_z.append(z_chunk)
            list_sm.append(sm_chunk)

            if n_gal_final > n_gal_max:
                break

        z = np.hstack(list_z)
        sm = np.hstack(list_sm)
        return sm, z