Coverage for src / galsbi / ucat_sps / plugins / sample_galaxies_sps.py: 94%

161 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-17 09:47 +0000

1# Copyright (C) 2025 LMU Munich 

2# Copyright (C) 2025 ETH Zurich, Institute for Particle Physics and Astrophysics 

3# Author: Luca Tortorelli, Silvan Fischbacher 

4# created: Apr 2025 

5 

6import warnings 

7 

8import healpy as hp 

9import numpy as np 

10import PyCosmo 

11from cosmic_toolbox import logger 

12from ivy.plugin.base_plugin import BasePlugin 

13from ufig import coordinate_util 

14 

15from galsbi.ucat import filters_util, galaxy_sampling_util, utils 

16from galsbi.ucat.galaxy_population_models.galaxy_position import sample_position_uniform 

17from galsbi.ucat.galaxy_population_models.galaxy_shape import ( 

18 sample_ellipticities_for_galaxy_type, 

19) 

20from galsbi.ucat.plugins.sample_galaxies_photo import ExtinctionMapEvaluator, in_pos 

21from galsbi.ucat_sps.galaxy_population_models.galaxy_agn import ( 

22 sample_ratio_agn_to_galaxy_bolometric_luminosity, 

23) 

24from galsbi.ucat_sps.galaxy_population_models.galaxy_dust_attenuation import ( 

25 sample_dust_attenuation_parameters, 

26) 

27from galsbi.ucat_sps.galaxy_population_models.galaxy_dust_emission import ( 

28 sample_Dale2014_dust_emission_parameters, 

29) 

30from galsbi.ucat_sps.galaxy_population_models.galaxy_gas_ionization import ( 

31 sample_gas_ionization_Kashino19, 

32) 

33from galsbi.ucat_sps.galaxy_population_models.galaxy_light_profile import ( 

34 sample_sersic_for_galaxy_type, 

35) 

36from galsbi.ucat_sps.galaxy_population_models.galaxy_metallicity_history import ( 

37 generate_gas_metallicity_history_snorm_trunc, 

38 sample_gas_metallicity, 

39) 

40from galsbi.ucat_sps.galaxy_population_models.galaxy_size import ( 

41 sample_r50_from_stellar_mass, 

42) 

43from galsbi.ucat_sps.galaxy_population_models.galaxy_star_formation_history import ( 

44 compute_surviving_stellar_mass, 

45 sample_sfh_snorm_trunc, 

46) 

47from galsbi.ucat_sps.galaxy_population_models.galaxy_stellar_mass_function import ( 

48 initialize_stellar_mass_functions, 

49) 

50from galsbi.ucat_sps.galaxy_population_models.galaxy_velocity_dispersion import ( 

51 sample_velocity_dispersion_zahid, 

52) 

53from galsbi.ucat_sps.sps_sed_generator import MagnitudeGenerator 

54 

# Module-level logger for this plugin.
LOGGER = logger.get_logger(__file__)
# NOTE(review): calling filterwarnings without a ``module`` argument modifies the
# process-wide warnings filter at import time, so every matching warning in the
# whole process (not only this module) is shown just once. Confirm this global
# side effect is intended.
warnings.filterwarnings("once")

57 

58 

def check_max_mem_error(max_mem_hard_limit=10000):
    """
    Abort sampling when the process memory exceeds a hard limit.

    Failing early with a catchable error prevents job crashes on clusters,
    where exceeding the memory allocation would otherwise kill the job.

    :param max_mem_hard_limit: threshold compared against
        ``utils.memory_usage_psutil()`` (reported as megabytes in the error
        message).
    :raises galaxy_sampling_util.UCatNumGalError: if the current memory usage is
        strictly larger than ``max_mem_hard_limit``.
    """
    usage_mb = utils.memory_usage_psutil()
    if usage_mb > max_mem_hard_limit:
        message = (
            "The sample_galaxies process is taking too much memory: "
            f"mem_mb_current={usage_mb}, max_mem_hard_limit={max_mem_hard_limit}"
        )
        raise galaxy_sampling_util.UCatNumGalError(message)

72 

73 

class Plugin(BasePlugin):
    """
    Generate a catalog of galaxies with SPS computed magnitudes in multiple bands.

    For every galaxy type in ``par.galaxy_types`` and every healpix pixel of the
    sampling footprint, stellar masses and redshifts are drawn from a low-mass
    and a high-mass stellar mass function. Star formation history, gas
    metallicity, ionization, dust attenuation/emission, size, ellipticity,
    Sersic index, velocity dispersion, AGN fraction, position and Milky Way
    extinction are then sampled per galaxy. After all pixels are processed,
    magnitudes for the whole catalog are computed in one call to the SPS
    ``MagnitudeGenerator``.

    Results are written to ``self.ctx``: ``pixels``, ``pixarea``, ``galaxies``
    (a ``galaxy_sampling_util.Catalog``), ``numgalaxies`` and, when SEDs are
    saved with the prospect generator, ``restframe_wavelength_for_SED``.
    """

    def __call__(self):
        """
        Sample the galaxy catalog and compute its SPS magnitudes.

        Raises ValueError if ``par.sampling_mode`` is neither "wcs" nor
        "healpix". Raises galaxy_sampling_util.UCatNumGalError if the process
        exceeds the memory limit, or if more than
        ``n_gal_max_<type> * par.ngal_multiplier`` galaxies of one type are
        sampled while ``par.raise_max_num_gal_error`` is set (otherwise
        sampling of that type stops early).
        """
        par = self.ctx.parameters

        # Cosmology
        cosmo = PyCosmo.build()
        cosmo.set(h=par.h, omega_m=par.omega_m)

        # Healpix pixels to sample. ``w`` is the WCS in "wcs" mode and None in
        # "healpix" mode; it decides below whether positions are stored as
        # x/y image coordinates or as ra/dec.
        if par.sampling_mode == "wcs":
            LOGGER.info("Sampling galaxies based on RA/DEC and pixel scale")
            w = coordinate_util.wcs_from_parameters(par)
            self.ctx.pixels = coordinate_util.get_healpix_pixels(
                par.nside_sampling, w, par.size_x, par.size_y
            )
            if len(self.ctx.pixels) < 15:
                LOGGER.warning(
                    f"Only {len(self.ctx.pixels)} healpy pixels in the footprint,"
                    " consider increasing the nside_sampling"
                )
        elif par.sampling_mode == "healpix":
            LOGGER.info("Sampling galaxies based on healpix pixels")
            self.ctx.pixels = coordinate_util.get_healpix_pixels_from_map(par)
            w = None
        else:
            raise ValueError(
                f"Unknown sampling mode: {par.sampling_mode}, must be wcs or healpix"
            )
        self.ctx.pixarea = hp.nside2pixarea(par.nside_sampling, degrees=False)

        # backward compatibility - check if full filter names are set
        if not hasattr(par, "filters_full_names"):
            all_filters = np.unique(par.filters + [par.lum_fct_filter_band])
            par.filters_full_names = filters_util.get_default_full_filter_names(
                all_filters
            )
            warnings.warn(
                "setting filters to default, this will cause problems if you work"
                " with filters from different cameras in the same band",
                stacklevel=1,
            )

        # Stellar mass functions (low-mass and high-mass component), both
        # indexed by galaxy type below.
        sm_funcs_lowmass = initialize_stellar_mass_functions(
            par,
            cosmo=cosmo,
            pixarea=self.ctx.pixarea,
            z_sm_intp=None,
            mass_key="lowmass",
        )
        sm_funcs_highmass = initialize_stellar_mass_functions(
            par,
            cosmo=cosmo,
            pixarea=self.ctx.pixarea,
            z_sm_intp=None,
            mass_key="highmass",
        )

        # SPS magnitude generator and Milky Way extinction map
        sps_sed_generator = MagnitudeGenerator(par)
        extinction_eval = ExtinctionMapEvaluator(par)

        # Columns filled pixel by pixel in the loop below (as lists of per-pixel
        # arrays) and concatenated once all pixels have been processed.
        pos_cols = ["x", "y"] if w is not None else ["ra", "dec"]
        loop_cols = [
            "z",
            "log10_stellar_mass",
            "log10_surviv_stellar_mass",
            "star_formation_history",
            "star_formation_rate",
            "mSFR",
            "mpeak",
            "mperiod",
            "mskew",
            "gas_metallicity",
            "metallicity_history",
            "gas_ionization",
            "dust_attenuation",
            "dust_emission",
            "int_r50",
            "r50_phys",
            "int_e1",
            "int_e2",
            "sersic_n",
            "sigma_vel",
            "fagn",
            "galaxy_type",
            "excess_b_v",
            *pos_cols,
        ]

        # Initialize galaxy catalog; "id" is assigned after the loop.
        self.ctx.galaxies = galaxy_sampling_util.Catalog()
        self.ctx.galaxies.columns = ["id"] + loop_cols
        for c in loop_cols:
            setattr(self.ctx.galaxies, c, [])

        def pixel_seed(pixel, i, offset):
            # Seed for quantity-specific sampling in the i-th pixel.
            # NOTE(review): the ``offset * i`` term vanishes for i == 0, so every
            # quantity drawn with this seed in the first pixel shares one seed;
            # the seed also does not depend on the galaxy type. Kept as-is so
            # existing catalogs remain reproducible.
            return par.seed + pixel + par.gal_num_seed_offset + offset * i

        # Loop over galaxy types (one stellar mass function pair per type)
        for g in par.galaxy_types:
            n_gal_type = 0
            n_gal_type_max = getattr(par, f"n_gal_max_{g}")
            max_reached = False

            for i in LOGGER.progressbar(
                range(len(self.ctx.pixels)),
                desc=f"getting {g:<4s} galaxies for healpix pixels",
                at_level="debug",
            ):
                pixel = self.ctx.pixels[i]

                # Sample stellar masses and redshifts from the double stellar
                # mass function. NOTE(review): both components use the same seed.
                smf_seed = pixel_seed(pixel, i, par.SEED_OFFSET_SMFUN)
                smf_low = sm_funcs_lowmass[g]
                sm_lowmass, z_lowmass = smf_low.sample_z_sm_and_apply_cut(
                    seed_ngal=smf_seed, n_gal_max=n_gal_type_max
                )
                smf_high = sm_funcs_highmass[g]
                sm_highmass, z_highmass = smf_high.sample_z_sm_and_apply_cut(
                    seed_ngal=smf_seed, n_gal_max=n_gal_type_max
                )

                # log10(Mass/Msun)
                log10_stellar_masses = np.concatenate((sm_lowmass, sm_highmass))
                z = np.concatenate((z_lowmass, z_highmass))
                ngal = len(log10_stellar_masses)

                # Star formation history
                (
                    sfhs_msun_yr,
                    sfrs_msun_yr,
                    mSFR,
                    mpeak,
                    mperiod,
                    mskew,
                ) = sample_sfh_snorm_trunc(
                    log10_stellar_masses,
                    z,
                    par,
                    cosmo,
                    galaxy_type=g,
                    seed=pixel_seed(pixel, i, par.SEED_OFFSET_SFH),
                )

                # Final gas metallicity and its history
                Zgas_final = sample_gas_metallicity(
                    log10_stellar_masses,
                    z,
                    sfrs_msun_yr,
                    par,
                    cosmo,
                    galaxy_type=g,
                    seed=pixel_seed(pixel, i, par.SEED_OFFSET_ZGAS),
                )
                Z_gas = generate_gas_metallicity_history_snorm_trunc(
                    Zgas_final,
                    sfhs_msun_yr,
                    par,
                    Zgas_init=par.Zgas_init,
                )

                # Surviving stellar mass
                log10_surviv_stellar_masses = compute_surviving_stellar_mass(
                    z,
                    log10_stellar_masses,
                    mSFR,
                    mpeak,
                    mperiod,
                    mskew,
                    Zgas_final,
                    par,
                )

                # Gas ionization
                logU = sample_gas_ionization_Kashino19(
                    Zgas_final,
                    sfrs_msun_yr,
                    log10_stellar_masses,
                    par,
                    seed=pixel_seed(pixel, i, par.SEED_OFFSET_LOGU),
                )

                # Dust attenuation (internal to the observed galaxy), using the
                # specific star formation rate SFR / M*
                log_specific_star_formation_rate = np.log10(
                    sfrs_msun_yr / 10**log10_stellar_masses
                )
                dust_attenuation_params = sample_dust_attenuation_parameters(
                    par,
                    log10_stellar_masses,
                    log_specific_star_formation_rate,
                    z,
                    galaxy_type=g,
                    seed=pixel_seed(pixel, i, par.SEED_OFFSET_DUSTATT),
                )

                # Dust emission
                dust_emission_params = sample_Dale2014_dust_emission_parameters(
                    par,
                    ngal,
                    galaxy_type=g,
                    seed=pixel_seed(pixel, i, par.SEED_OFFSET_DUSTEM),
                )

                # Intrinsic size (the second return value, the size in arcsec,
                # is not stored)
                int_r50, _, r50_phys = sample_r50_from_stellar_mass(
                    z,
                    log10_stellar_masses,
                    cosmo,
                    par,
                    galaxy_type=g,
                    seed=pixel_seed(pixel, i, par.SEED_OFFSET_SIZE),
                )

                # Intrinsic ellipticity (global numpy RNG seeded beforehand)
                np.random.seed(par.seed + pixel + par.gal_ellipticities_seed_offset)
                int_e1, int_e2 = sample_ellipticities_for_galaxy_type(
                    n_gal=ngal, galaxy_type=g, par=par
                )

                # Sersic index (global numpy RNG seeded beforehand)
                np.random.seed(par.seed + pixel + par.gal_sersic_seed_offset)
                sersic_n = sample_sersic_for_galaxy_type(
                    n_gal=ngal,
                    galaxy_type=g,
                    par=par,
                    log10_stellar_mass=log10_stellar_masses,
                )

                # Velocity dispersion
                vel_disp = sample_velocity_dispersion_zahid(
                    par,
                    log10_stellar_masses,
                    seed=pixel_seed(pixel, i, par.SEED_OFFSET_VELDISP),
                )

                # Ratio of AGN to galaxy bolometric luminosity
                fagn = sample_ratio_agn_to_galaxy_bolometric_luminosity(
                    par,
                    z,
                    log10_stellar_masses,
                    galaxy_type=g,
                    seed=pixel_seed(pixel, i, par.SEED_OFFSET_AGN),
                )

                # Positions: x and y for wcs, ra and dec for the healpix mode
                np.random.seed(par.seed + pixel + par.gal_dist_seed_offset)
                x, y = sample_position_uniform(len(z), w, pixel, par.nside_sampling)
                # Make the catalog precision already here to avoid
                # inconsistencies in the selection
                x = x.astype(par.catalog_precision)
                y = y.astype(par.catalog_precision)

                # Milky Way extinction at the galaxy positions
                excess_b_v = extinction_eval(w, x, y)
                if len(z) == 1:
                    # presumably a single position yields a scalar; wrap it so
                    # that it can be indexed with the selection mask
                    excess_b_v = np.array([excess_b_v])

                # Select galaxies within the image boundaries (wcs mode only)
                if w is not None:
                    select = in_pos(x, y, par)
                else:
                    select = np.ones_like(x, dtype=bool)
                n_selected = np.count_nonzero(select)
                n_gal_type += n_selected

                # Store selected galaxies, cast to the catalog precision
                for column, values in (
                    (pos_cols[0], x),
                    (pos_cols[1], y),
                    ("z", z),
                    ("excess_b_v", excess_b_v),
                    ("log10_stellar_mass", log10_stellar_masses),
                    ("gas_metallicity", Zgas_final),
                    ("gas_ionization", logU),
                    ("dust_attenuation", dust_attenuation_params),
                    ("dust_emission", dust_emission_params),
                    ("int_r50", int_r50),
                    ("r50_phys", r50_phys),
                    ("int_e1", int_e1),
                    ("int_e2", int_e2),
                    ("sersic_n", sersic_n),
                    ("sigma_vel", vel_disp),
                    ("fagn", fagn),
                ):
                    getattr(self.ctx.galaxies, column).append(
                        np.asarray(values)[select].astype(par.catalog_precision)
                    )
                # Store selected galaxies at the dtype returned by the samplers
                for column, values in (
                    ("star_formation_history", sfhs_msun_yr),
                    ("star_formation_rate", sfrs_msun_yr),
                    ("mSFR", mSFR),
                    ("mpeak", mpeak),
                    ("mperiod", mperiod),
                    ("mskew", mskew),
                    ("metallicity_history", Z_gas),
                    ("log10_surviv_stellar_mass", log10_surviv_stellar_masses),
                ):
                    getattr(self.ctx.galaxies, column).append(
                        np.asarray(values)[select]
                    )
                self.ctx.galaxies.galaxy_type.append(
                    np.ones(n_selected, dtype=np.ushort)
                    * sm_funcs_lowmass[g].galaxy_type
                )

                # check memory footprint
                check_max_mem_error()

                # see if number of galaxies is OK
                if n_gal_type > n_gal_type_max * par.ngal_multiplier:
                    max_reached = True
                    if par.raise_max_num_gal_error:
                        raise galaxy_sampling_util.UCatNumGalError(
                            f"exceeded number of {g} galaxies "
                            f"{n_gal_type}>{n_gal_type_max}"
                        )
                    break

            LOGGER.info(
                f"smfun={g} n_gals={n_gal_type} maximum number of galaxies "
                f"reached={max_reached} ({n_gal_type_max})"
            )

        # Concatenate the per-pixel chunks of every column
        for c in loop_cols:
            setattr(self.ctx.galaxies, c, np.concatenate(getattr(self.ctx.galaxies, c)))

        LOGGER.info(f"Computing magnitudes with {par.sed_generator}")
        output_sed_generator = sps_sed_generator(
            par,
            z=self.ctx.galaxies.z,
            mSFR=self.ctx.galaxies.mSFR,
            mpeak=self.ctx.galaxies.mpeak,
            mperiod=self.ctx.galaxies.mperiod,
            mskew=self.ctx.galaxies.mskew,
            Zgas_final=self.ctx.galaxies.gas_metallicity,
            logU=self.ctx.galaxies.gas_ionization,
            fagn=self.ctx.galaxies.fagn,
            dust_attenuation_params=self.ctx.galaxies.dust_attenuation,
            excess_b_v=self.ctx.galaxies.excess_b_v,
            dust_emission_params=self.ctx.galaxies.dust_emission,
            vel_disp=self.ctx.galaxies.sigma_vel,
            ssp_library_filepath=par.ssp_library_filepath,
            filter_names=par.filters_full_names_prospect,
            cosmology=cosmo,
            log10_stellar_mass=self.ctx.galaxies.log10_stellar_mass,
        )

        # Output index 0 holds absolute and index 1 intrinsic magnitudes, with
        # column j belonging to par.filters[j].
        # TODO: maybe save mags before and after mw extinction
        abs_mags = output_sed_generator[0]
        int_mags = output_sed_generator[1]
        self.ctx.galaxies.int_magnitude_dict = {}
        self.ctx.galaxies.abs_magnitude_dict = {}
        for j, key in enumerate(par.filters):
            self.ctx.galaxies.int_magnitude_dict[key] = int_mags[:, j]
            self.ctx.galaxies.abs_magnitude_dict[key] = abs_mags[:, j]

        # Set apparent (lensed) magnitudes, for now equal to intrinsic magnitudes
        self.ctx.galaxies.magnitude_dict = {
            band: mag.copy()
            for band, mag in self.ctx.galaxies.int_magnitude_dict.items()
        }

        # Number of galaxies and id
        self.ctx.numgalaxies = self.ctx.galaxies.z.size
        self.ctx.galaxies.id = np.arange(self.ctx.numgalaxies)

        # Backward compatibility: blue_red flag (1 = blue, 0 = red).
        # NOTE(review): looks up "blue" and "red" unconditionally, so presumably
        # both must be among par.galaxy_types — confirm.
        galaxy_type = self.ctx.galaxies.galaxy_type
        self.ctx.galaxies.blue_red = np.ones(len(self.ctx.galaxies.z), dtype=np.ushort)
        self.ctx.galaxies.blue_red[
            galaxy_type == sm_funcs_lowmass["blue"].galaxy_type
        ] = 1
        self.ctx.galaxies.blue_red[
            galaxy_type == sm_funcs_lowmass["red"].galaxy_type
        ] = 0

        # Logical (not bitwise) "and": a truthy non-bool save_SEDs such as 2
        # would otherwise evaluate 2 & True == 0 and silently skip the SEDs.
        if par.save_SEDs and par.sed_generator.lower() == "prospect":
            # Store SEDs in the catalog
            self.ctx.restframe_wavelength_for_SED = output_sed_generator[2]
            self.ctx.galaxies.sed = output_sed_generator[3]

        LOGGER.info(
            f"galaxy counts n_total={self.ctx.numgalaxies} "
            f"mem_mb_current={utils.memory_usage_psutil():5.1f}"
        )

    def __str__(self):
        return "sample gal photo prospect"