Coverage for src / galsbi / ucat_sps / galaxy_population_models / galaxy_stellar_mass_function.py: 100%

165 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-12 16:31 +0000

1# Copyright (C) 2025 LMU Munich 

2# Copyright (C) 2025 ETH Zurich, Institute for Particle Physics and Astrophysics 

3# Author: Luca Tortorelli, Silvan Fischbacher 

4# created: Apr 2025 

5 

6from collections import OrderedDict 

7 

8import numpy as np 

9import scipy.integrate 

10import scipy.optimize 

11import scipy.special 

12from cosmic_toolbox import logger 

13 

# Module-level logger obtained from cosmic_toolbox, keyed on this file's path.
LOGGER = logger.get_logger(__file__)

15 

16 

def find_closest_ind(grid, vals):
    """
    Return, for each value in ``vals``, the index of the closest entry of ``grid``.

    :param grid: (array_like) one-dimensional grid sorted in ascending order, as
        required by ``np.searchsorted``.
    :param vals: (array_like or scalar) values to locate on the grid.
    :return: (ndarray of int) indices into ``grid`` with the same shape as
        ``vals``. When a value lies exactly halfway between two grid points the
        upper index is returned.
    """
    grid = np.asarray(grid)
    vals = np.asarray(vals)
    # First grid point >= value, clipped to a valid index for values above the grid.
    upper = np.minimum(np.searchsorted(grid, vals), grid.size - 1)
    # Neighbour below; clamp at 0 instead of silently wrapping around to grid[-1].
    lower = np.maximum(upper - 1, 0)
    # Strict comparison: ties resolve to the upper neighbour.
    closer_below = np.fabs(grid[upper] - vals) > np.fabs(grid[lower] - vals)
    return np.where(closer_below, lower, upper)

22 

23 

def single_stellar_mass_function_parametrisation(
    log10_stellar_masses, log_Mstar, phi_star, alpha
):
    """
    Evaluate a single Schechter stellar mass function per unit log10 mass.

    With ``x = 10 ** (log10_stellar_masses - log_Mstar)`` the returned value is
    ``ln(10) * phi_star * x ** (alpha + 1) * exp(-x)``.

    :param log10_stellar_masses: (array_like[n_mass_bins,]) log10 of the stellar
        masses at which the SMF is evaluated.
    :param log_Mstar: (float) log10 of the knee M*, where the power law of slope
        ``alpha`` turns over into an exponential cut-off at higher masses.
    :param phi_star: (float) normalisation (characteristic number density).
    :param alpha: (float) faint-end (low-mass) slope.
    :return stellar_mass_function: (array_like[n_mass_bins,]) the number density
        of galaxies as a function of their stellar mass.
    """
    mass_ratio = 10 ** (log10_stellar_masses - log_Mstar)
    cutoff = np.exp(-mass_ratio)
    power_law = mass_ratio ** (alpha + 1)
    return np.log(10) * cutoff * phi_star * power_law

48 

49 

def double_stellar_mass_function_parametrisation(
    log10_stellar_masses,
    log_Mstar,
    phi_star_low_mass,
    phi_star_high_mass,
    alpha_low_mass,
    alpha_high_mass,
):
    """
    Evaluate a double Schechter stellar mass function per unit log10 mass.

    Both components share the same knee. With
    ``x = 10 ** (log10_stellar_masses - log_Mstar)`` the returned value is
    ``ln(10) * exp(-x) * (phi_low * x ** (alpha_low + 1)
    + phi_high * x ** (alpha_high + 1))``.

    :param log10_stellar_masses: (array_like[n_mass_bins,]) log10 of the stellar
        masses at which the SMF is evaluated.
    :param log_Mstar: (float) log10 of the knee M*, shared by the low- and
        high-mass components.
    :param phi_star_low_mass: (float) normalisation of the low-mass component.
    :param phi_star_high_mass: (float) normalisation of the high-mass component.
    :param alpha_low_mass: (float) faint-end slope of the low-mass component.
    :param alpha_high_mass: (float) faint-end slope of the high-mass component.
    :return stellar_mass_function: (array_like[n_mass_bins,]) the number density
        of galaxies as a function of their stellar mass.
    """
    mass_ratio = 10 ** (log10_stellar_masses - log_Mstar)
    low_mass_term = phi_star_low_mass * mass_ratio ** (alpha_low_mass + 1)
    high_mass_term = phi_star_high_mass * mass_ratio ** (alpha_high_mass + 1)
    return np.log(10) * np.exp(-mass_ratio) * (low_mass_term + high_mass_term)

88 

89 

def stellar_mass_function_m_star(
    z, parametrization, intercept, slope, quadratic_term=0
):
    """
    Evaluate log10(M*), the knee of the stellar mass function, at redshift ``z``.

    Supported parametrizations:

    * ``"linexp"``: ``intercept + slope * z``
    * ``"logpower"``: ``intercept + slope * log10(1 + z)``
    * ``"logquadratic_power"``:
      ``intercept + slope * log10(1 + z) + quadratic_term * log10(1 + z) ** 2``

    :param z: (float or array_like) redshift(s).
    :param parametrization: (str) one of the names above.
    :param intercept: (float) value at z = 0.
    :param slope: (float) linear coefficient.
    :param quadratic_term: (float) quadratic coefficient, only used by
        ``"logquadratic_power"``.
    :return: log10(M*) with the shape of ``z``.
    :raises ValueError: if ``parametrization`` is not supported.
    """
    if parametrization == "linexp":
        return np.polyval((slope, intercept), z)
    if parametrization not in ("logpower", "logquadratic_power"):
        raise ValueError("unknown parametrization value")
    log_one_plus_z = np.log10(1 + z)
    m_s = intercept + slope * log_one_plus_z
    if parametrization == "logquadratic_power":
        m_s = m_s + quadratic_term * log_one_plus_z**2
    return m_s

106 

107 

def stellar_mass_function_phi_star(z, parametrization, amplitude, exp):
    """
    Evaluate the Schechter normalisation phi* at redshift ``z``.

    * ``"linexp"``: ``amplitude * exp(exp * z)``
    * ``"logpower"`` / ``"logquadratic_power"``: ``amplitude * (1 + z) ** exp``

    :param z: (float or array_like) redshift(s).
    :param parametrization: (str) one of the names above.
    :param amplitude: (float) value at z = 0.
    :param exp: (float) redshift-evolution exponent.
    :return: phi* with the shape of ``z``.
    :raises ValueError: if ``parametrization`` is not supported.
    """
    if parametrization == "linexp":
        growth = np.exp(exp * z)
    elif parametrization in ("logpower", "logquadratic_power"):
        growth = (1 + z) ** exp
    else:
        raise ValueError("unknown parametrization value")
    return amplitude * growth

116 

117 

def upper_inc_gamma(a, x):
    """
    Upper incomplete gamma function Gamma(a, x) for a real order ``a``.

    ``scipy.special.gammaincc`` is the regularised function and only covers
    ``a > 0``, so three cases are handled:

    * ``a > 0``: ``gamma(a) * gammaincc(a, x)``;
    * ``a == 0``: ``-expi(-x)``, i.e. the exponential integral E1(x);
    * ``a < 0``: the recurrence
      ``Gamma(a, x) = (Gamma(a + 1, x) - x**a * exp(-x)) / a``, stepping the
      order up by one until it is non-negative.

    The recurrence is unrolled iteratively so that very negative orders do not
    hit Python's recursion limit.

    :param a: (float) order of the function.
    :param x: (float or array_like) lower limit of the integral.
    :return: Gamma(a, x), broadcast like ``x``.
    :raises ValueError: if ``a`` is NaN or -inf, for which the recurrence would
        never reach a non-negative order.
    """
    if np.isnan(a) or np.isneginf(a):
        raise ValueError(f"order a must be finite or +inf, got {a}")

    # Orders visited while stepping up to a non-negative one.
    negative_orders = []
    order = a
    while order < 0:
        negative_orders.append(order)
        order = order + 1

    if order > 0:
        result = scipy.special.gamma(order) * scipy.special.gammaincc(order, x)
    else:
        result = -scipy.special.expi(-x)

    # Unwind the recurrence from the largest negative order back down to ``a``.
    for order in reversed(negative_orders):
        result = 1 / order * (result - x**order * np.exp(-x))

    return result

129 

130 

def initialize_stellar_mass_functions(
    par, pixarea, cosmo, z_sm_intp=None, mass_key=None
):
    """
    Build one StellarMassFunction per galaxy type listed in ``par.galaxy_types``.

    For each galaxy type ``g`` the Schechter parameters are read from attributes
    of ``par`` named ``sm_fct_m_star_{g}_{0,1,2}_{mass_key}``,
    ``sm_fct_phi_star_{g}_{amp,exp}_{mass_key}`` and
    ``sm_fct_alpha_{g}_{mass_key}``. Grid and sampling settings
    (``sm_fct_z_res``, ``sm_fct_sm_res``, ``sm_fct_z_max``, ``sm_fct_sm_min``,
    ``ngal_multiplier``, ``seed_ngal``, ``sm_fct_parametrization``) are shared by
    all types.

    :param par: parameter container exposing the attributes listed above.
    :param pixarea: sky area forwarded to every StellarMassFunction.
    :param cosmo: cosmology object forwarded to every StellarMassFunction.
    :param z_sm_intp: optional redshift-dependent log10 stellar-mass limit,
        forwarded to every StellarMassFunction.
    :param mass_key: suffix selecting which set of ``sm_fct_*`` parameters to use.
    :return: OrderedDict mapping galaxy type -> StellarMassFunction, in the order
        of ``par.galaxy_types``.
    :raises KeyError: if ``mass_key`` is not given.
    """
    if mass_key is None:
        raise KeyError("mass_key must be provided to select the sm_fct_* parameters")

    shared_kwargs = dict(
        pixarea=pixarea,
        z_res=par.sm_fct_z_res,
        sm_res=par.sm_fct_sm_res,
        z_max=par.sm_fct_z_max,
        sm_min=par.sm_fct_sm_min,
        z_sm_intp=z_sm_intp,
        ngal_multiplier=par.ngal_multiplier,
        cosmo=cosmo,
    )
    if par.ngal_multiplier != 1:
        LOGGER.warning(f"ngal_multiplier is set to {par.ngal_multiplier}")

    sm_funcs = OrderedDict()
    # The enumeration index becomes the integer galaxy_type of each function.
    for i, g in enumerate(par.galaxy_types):
        sm_funcs[g] = StellarMassFunction(
            sm_fct_parametrization=par.sm_fct_parametrization,
            sm_m_star_0=getattr(par, f"sm_fct_m_star_{g}_0_{mass_key}"),
            sm_m_star_1=getattr(par, f"sm_fct_m_star_{g}_1_{mass_key}"),
            sm_m_star_2=getattr(par, f"sm_fct_m_star_{g}_2_{mass_key}"),
            sm_phi_star_amp=getattr(par, f"sm_fct_phi_star_{g}_amp_{mass_key}"),
            sm_phi_star_exp=getattr(par, f"sm_fct_phi_star_{g}_exp_{mass_key}"),
            sm_alpha=getattr(par, f"sm_fct_alpha_{g}_{mass_key}"),
            name=g,
            galaxy_type=i,
            seed_ngal=par.seed_ngal,
            **shared_kwargs,
        )
    return sm_funcs

166 

167 

def maximum_redshift(
    z_sm_intp, sm_min, z_max, parametrization, alpha, sm_m_star_par, seed_ngal
):
    """
    Compute the maximum redshift up to which objects are sampled from the stellar
    mass function.

    If ``z_sm_intp`` is None, ``z_max`` is returned unchanged. Otherwise the cutoff
    is the redshift at which the fraction of galaxies (among those above
    ``sm_min``) whose log10 stellar mass exceeds ``z_sm_intp(z)`` drops to 1e-5.
    That fraction is the ratio of upper incomplete gamma functions of the Schechter
    function. The root is searched with ``scipy.optimize.brentq`` on
    ``[0, z_sm_intp.x[-1]]``. If ``brentq`` raises ``ValueError`` (e.g. no sign
    change on that interval), ``z_sm_intp.x[-1]`` is used instead. The result is
    ``min(z_max, cutoff)`` plus a uniform jitter in [0, 1e-4), drawn after seeding
    NumPy's global RNG with ``seed_ngal``.

    :param z_sm_intp: None, or a callable giving a log10 stellar-mass limit as a
        function of redshift and exposing its redshift nodes as ``.x`` (as
        ``scipy.interpolate.interp1d`` does).
    :param sm_min: (float) minimum log10 stellar mass that is sampled.
    :param z_max: (float) requested maximum redshift.
    :param parametrization: (str) parametrization of log10(M*)(z).
    :param alpha: (float) Schechter slope.
    :param sm_m_star_par: (tuple) parameters of log10(M*)(z).
    :param seed_ngal: (int) seed for NumPy's global random state.
    :return: (float) maximum redshift.
    """
    if z_sm_intp is None:
        return z_max

    def cond_sm_cdf_lim(z):
        # Fraction of galaxies above sm_min that also lie above the mass limit,
        # minus the 1e-5 threshold; its root is the cutoff redshift.
        sm_m_s = stellar_mass_function_m_star(z, parametrization, *sm_m_star_par)
        # Masses are log10 values, so M / M* = 10 ** (log10 M - log10 M*). This is
        # the same argument NumGalCalculator and RedshiftSMSampler pass to
        # upper_inc_gamma; the 0.4 factor of magnitude-based luminosity functions
        # does not apply to stellar masses.
        cdf_lim = (
            upper_inc_gamma(alpha + 1, 10 ** (z_sm_intp(z) - sm_m_s))
            / upper_inc_gamma(alpha + 1, 10 ** (sm_min - sm_m_s))
            - 1e-5
        )
        return cdf_lim

    try:
        z_max_cutoff = scipy.optimize.brentq(cond_sm_cdf_lim, 0, z_sm_intp.x[-1])
    except ValueError:
        z_max_cutoff = z_sm_intp.x[-1]

    z_max = min(z_max, z_max_cutoff)
    np.random.seed(seed_ngal)
    z_max += np.random.uniform(0, 0.0001)

    return z_max

202 

203 

class NumGalCalculator:
    """
    Expected galaxy count from a Schechter stellar mass function, plus Poisson
    draws around it.

    The stellar-mass integral above ``sm_min`` is analytic (an upper incomplete
    gamma function). The redshift integral from 0 to ``z_max`` is done
    numerically with ``scipy.integrate.quad``. See also
    docs/jupyter_notebooks/sample_redshift_magnitude.ipynb.
    """

    def __init__(
        self,
        z_max,
        sm_min,
        parametrization,
        sm_alpha,
        sm_m_star_par,
        sm_phi_star_par,
        cosmo,
        pixarea,
        ngal_multiplier=1,
    ):
        density_args = (
            sm_min,
            parametrization,
            sm_alpha,
            sm_m_star_par,
            sm_phi_star_par,
            cosmo,
        )
        integral = scipy.integrate.quad(
            self._redshift_density, 0, z_max, args=density_args
        )[0]
        # Scale the integrated density by the area and the user multiplier,
        # then round to an integer mean count.
        expected_count = integral * pixarea * ngal_multiplier
        self.n_gal_mean = int(round(expected_count))

    def __call__(self):
        """
        Draw a galaxy count.

        Returns a Poisson draw with mean ``n_gal_mean`` when that mean is
        positive; otherwise ``n_gal_mean`` is returned as is.
        """
        mean = self.n_gal_mean
        if mean <= 0:
            return mean
        return np.random.poisson(mean)

    def _redshift_density(
        self,
        z,
        sm_min,
        parametrization,
        sm_alpha,
        par_sm_m_star,
        par_sm_phi_star,
        cosmo,
    ):
        """
        Integrand in redshift for the expected galaxy count.

        Returns ``phi*(z) * d_h * d_m(z)**2 / E(z) * Gamma(alpha + 1,
        10 ** (sm_min - log10 M*(z)))``. Here ``E(z) = sqrt(omega_m (1+z)^3 +
        omega_l)`` and ``d_h = c / H0``. ``d_m`` comes from
        ``cosmo.background.dist_trans_a`` (presumably the transverse comoving
        distance -- confirm against the cosmology package).
        """
        log_m_star = stellar_mass_function_m_star(z, parametrization, *par_sm_m_star)
        phi_star = stellar_mass_function_phi_star(
            z, parametrization, *par_sm_phi_star
        )
        params = cosmo.params
        e_of_z = np.sqrt(params.omega_m * (1 + z) ** 3 + params.omega_l)
        hubble_distance = params.c / params.H0
        transverse_distance = cosmo.background.dist_trans_a(a=1 / (1 + z))
        n_above_sm_min = upper_inc_gamma(sm_alpha + 1, 10 ** (sm_min - log_m_star))
        return (
            phi_star
            * hubble_distance
            * transverse_distance**2
            / e_of_z
            * n_above_sm_min
        )

274 

275 

class RedshiftSMSampler:
    """
    Samples redshifts and log10 stellar masses from the galaxy stellar mass function.

    Sampling has two stages. First, redshifts are drawn from a discrete grid
    weighted by the stellar mass function integrated over stellar masses above
    ``sm_min``. Then, for every drawn redshift, a stellar mass is drawn from the
    conditional distribution at that redshift. This is inverse-transform sampling:
    a uniform number is mapped through a tabulated upper incomplete gamma function
    onto ``log10(M) - log10(M*(z))``. See also
    docs/jupyter_notebooks/sample_redshift_magnitude.ipynb and
    docs/jupyter_notebooks/test_self_consistency.ipynb.
    """

    def __init__(
        self,
        z_res,
        z_max,
        sm_res,
        sm_min,
        parametrization,
        sm_alpha,
        sm_m_star_par,
        sm_phi_star_par,
        cosmo,
    ):
        self.z_res = z_res
        self.z_max = z_max
        self.sm_res = sm_res
        self.sm_min = sm_min
        self.parametrization = parametrization
        self.sm_alpha = sm_alpha
        self.sm_m_star_par = sm_m_star_par
        self.sm_phi_star_par = sm_phi_star_par
        self.cosmo = cosmo

        # Both setup steps read self.parametrization, so they run after the
        # attributes above are stored.
        self._setup_redshift_grid(
            z_max, z_res, sm_min, sm_alpha, sm_m_star_par, sm_phi_star_par, cosmo
        )
        self._setup_mag_grid(
            z_max, sm_min, sm_res, parametrization, sm_alpha, sm_m_star_par
        )

    def __call__(self, n_samples):
        """
        Draw ``n_samples`` (redshift, log10 stellar mass) pairs.

        :return: tuple ``(z, sm)`` of arrays of length ``n_samples``.
        """
        redshifts = np.random.choice(
            self.z_grid, size=n_samples, replace=True, p=self.nz_grid
        )
        log_m_star = stellar_mass_function_m_star(
            redshifts, self.parametrization, *self.sm_m_star_par
        )
        # Upper end of the uniform draw: the tail integral above sm_min at each z.
        tail_above_sm_min = upper_inc_gamma(
            self.sm_alpha + 1, 10 ** (self.sm_min - log_m_star)
        )
        uniform_draws = np.random.uniform(
            low=0, high=tail_above_sm_min, size=n_samples
        )
        # Look up log10(M) - log10(M*) for each draw, then shift back by M*(z).
        nearest = find_closest_ind(self.uig_grid, uniform_draws)
        return redshifts, log_m_star + self.m_s__m__grid[nearest]

    def _setup_redshift_grid(
        self, z_max, z_res, sm_min, sm_alpha, sm_m_star_par, sm_phi_star_par, cosmo
    ):
        """
        Tabulate the normalised redshift distribution on ``self.z_grid``.

        The grid runs from ``z_res`` to ``z_max``. Each weight is
        ``d_h * d_m**2 / E(z) * phi*(z) * Gamma(alpha + 1,
        10 ** (sm_min - log10 M*(z)))``. The weights are normalised to sum to one
        so they can be used as ``p`` in ``np.random.choice``.
        """
        n_steps = int(round((z_max - z_res) / z_res))
        self.z_grid = np.linspace(z_res, z_max, num=n_steps + 1)
        one_plus_z = 1 + self.z_grid

        params = cosmo.params
        e_of_z = np.sqrt(params.omega_m * one_plus_z**3 + params.omega_l)
        hubble_distance = params.c / params.H0
        transverse_distance = cosmo.background.dist_trans_a(a=1 / one_plus_z)
        volume_weight = hubble_distance * transverse_distance**2 / e_of_z

        log_m_star = stellar_mass_function_m_star(
            self.z_grid, self.parametrization, *sm_m_star_par
        )
        phi_star = stellar_mass_function_phi_star(
            self.z_grid, self.parametrization, *sm_phi_star_par
        )

        weights = (
            volume_weight
            * phi_star
            * upper_inc_gamma(sm_alpha + 1, 10 ** (sm_min - log_m_star))
        )
        self.nz_grid = weights / np.sum(weights)

    def _setup_mag_grid(
        self, z_max, sm_min, sm_res, parametrization, sm_alpha, sm_m_star_par
    ):
        """
        Tabulate the lookup used to sample stellar masses.

        Despite its name, this works with log10 stellar masses, not magnitudes.
        ``self.m_s__m__grid`` holds ``log10(M) - log10(M*)`` values. They run from
        ``sm_max - min(M*)`` down to ``sm_min - max(M*)``, where the extrema of M*
        over ``[0, z_max]`` come from a bounded scalar minimiser.
        ``self.uig_grid`` holds ``Gamma(alpha + 1, 10 ** m_s__m__grid)``. That
        grid is increasing because the mass grid decreases and Gamma(a, x)
        decreases in x.
        """

        def log_m_star(z):
            return stellar_mass_function_m_star(z, parametrization, *sm_m_star_par)

        search = {"bounds": (0, z_max), "method": "bounded"}
        highest_m_star = -scipy.optimize.minimize_scalar(
            lambda z: -log_m_star(z), **search
        ).fun
        lowest_m_star = scipy.optimize.minimize_scalar(log_m_star, **search).fun

        # Raise the upper mass bound in 0.1 steps until the tail integral at the
        # smallest M* is no longer positive (it has underflowed to zero).
        sm_max = sm_min
        while upper_inc_gamma(sm_alpha + 1, 10 ** (sm_max - lowest_m_star)) > 0:
            sm_max += 0.1

        n_points = (
            int(round((sm_max - lowest_m_star - sm_min + highest_m_star) / sm_res)) + 1
        )
        self.m_s__m__grid = np.linspace(
            sm_max - lowest_m_star,
            sm_min - highest_m_star,
            num=n_points,
        )
        self.uig_grid = upper_inc_gamma(sm_alpha + 1, 10**self.m_s__m__grid)

381 

382 

class StellarMassFunction:
    """
    Schechter stellar mass function for a single galaxy type.

    On construction it determines the maximum sampling redshift
    (``maximum_redshift``), computes the expected galaxy count
    (``NumGalCalculator``) and prepares a redshift / stellar-mass sampler
    (``RedshiftSMSampler``).
    """

    def __init__(
        self,
        name,
        sm_fct_parametrization,
        sm_m_star_0,
        sm_m_star_1,
        sm_m_star_2,
        sm_phi_star_amp,
        sm_phi_star_exp,
        sm_alpha,
        cosmo,
        pixarea,
        galaxy_type,
        seed_ngal,
        z_res=0.001,
        sm_res=0.01,
        z_max=3,
        sm_min=8,
        z_sm_intp=None,
        ngal_multiplier=1,
    ):
        self.parametrization = sm_fct_parametrization
        self.sm_m_star_0 = sm_m_star_0
        self.sm_m_star_1 = sm_m_star_1
        self.sm_m_star_2 = sm_m_star_2
        self.sm_phi_star_amp = sm_phi_star_amp
        self.sm_phi_star_exp = sm_phi_star_exp
        # Parameter tuples in the positional order expected by
        # stellar_mass_function_m_star / stellar_mass_function_phi_star.
        self.sm_m_star_par = sm_m_star_0, sm_m_star_1, sm_m_star_2
        self.sm_phi_star_par = sm_phi_star_amp, sm_phi_star_exp
        self.sm_alpha = sm_alpha
        self.z_sm_intp = z_sm_intp
        self.sm_min = sm_min
        self.cosmo = cosmo
        self.pixarea = pixarea
        self.z_res = z_res
        self.name = name
        self.galaxy_type = galaxy_type

        self.z_max = maximum_redshift(
            z_sm_intp=z_sm_intp,
            sm_min=self.sm_min,
            z_max=z_max,
            parametrization=self.parametrization,
            alpha=self.sm_alpha,
            sm_m_star_par=self.sm_m_star_par,
            seed_ngal=seed_ngal,
        )

        self.n_gal_calc = NumGalCalculator(
            z_max=self.z_max,
            sm_min=sm_min,
            parametrization=sm_fct_parametrization,
            sm_alpha=sm_alpha,
            sm_m_star_par=self.sm_m_star_par,
            sm_phi_star_par=self.sm_phi_star_par,
            cosmo=cosmo,
            pixarea=pixarea,
            ngal_multiplier=ngal_multiplier,
        )

        self.z_sm_sampler = RedshiftSMSampler(
            z_res=z_res,
            z_max=self.z_max,
            parametrization=sm_fct_parametrization,
            sm_res=sm_res,
            sm_min=sm_min,
            sm_alpha=sm_alpha,
            sm_m_star_par=self.sm_m_star_par,
            sm_phi_star_par=self.sm_phi_star_par,
            cosmo=cosmo,
        )

    def sample_z_sm_and_apply_cut(self, seed_ngal, n_gal_max=np.inf, size_chunk=10000):
        """
        Sample log10 stellar masses and redshifts in chunks and apply the z-mass cut.

        The method seeds NumPy's global RNG with ``seed_ngal`` and draws the total
        galaxy count from ``self.n_gal_calc``. It then samples in chunks of at most
        ``size_chunk`` objects to limit memory use. If ``self.z_sm_intp`` is set,
        only objects with ``sm < z_sm_intp(z)`` are kept. Sampling stops early once
        more than ``n_gal_max`` objects have survived the cut.

        :param seed_ngal: (int) seed for NumPy's global random state.
        :param n_gal_max: (float) stop drawing further chunks once the number of
            kept objects exceeds this value.
        :param size_chunk: (int) maximum number of objects drawn per chunk.
        :return: tuple ``(sm, z)`` of 1D arrays -- stellar masses first.
        """
        np.random.seed(seed_ngal)
        n_gal = self.n_gal_calc()
        if n_gal <= 0:
            return np.array([]), np.array([])

        n_chunks = int(np.ceil(float(n_gal) / float(size_chunk)))
        list_sm = []
        list_z = []
        n_gal_final = 0
        for ic in range(n_chunks):
            # Remaining objects capped at the chunk size, so the final chunk is
            # not empty when n_gal is an exact multiple of size_chunk.
            n_gal_sample = min(size_chunk, n_gal - ic * size_chunk)
            z_chunk, sm_chunk = self.z_sm_sampler(n_gal_sample)

            # Apply cut in the z - stellar mass plane.
            if self.z_sm_intp is not None:
                # NOTE(review): this keeps objects with log10 mass *below*
                # z_sm_intp(z), while maximum_redshift's criterion uses the
                # fraction of objects *above* that limit -- confirm the intended
                # direction of the cut.
                select = sm_chunk < self.z_sm_intp(z_chunk)
                z_chunk = z_chunk[select]
                sm_chunk = sm_chunk[select]

            n_gal_final += len(z_chunk)
            list_z.append(z_chunk)
            list_sm.append(sm_chunk)

            if n_gal_final > n_gal_max:
                break

        z = np.hstack(list_z)
        sm = np.hstack(list_sm)
        return sm, z