Coverage for src/galsbi/ucat/plugins/sample_galaxies_photo.py: 98%

183 statements  

« prev     ^ index     » next       coverage.py v7.10.5, created at 2025-08-26 10:41 +0000

1# Copyright (C) 2018 ETH Zurich, Institute for Particle Physics and Astrophysics 

2 

3""" 

4Created on Mar 5, 2018 

5author: Joerg Herbel 

6""" 

7 

8import os 

9import warnings 

10 

11import healpy as hp 

12import numpy as np 

13import PyCosmo 

14from astropy.coordinates import SkyCoord 

15from cosmic_toolbox import logger 

16from ivy.plugin.base_plugin import BasePlugin 

17from ufig import coordinate_util, io_util 

18 

19from galsbi.ucat import ( 

20 filters_util, 

21 galaxy_sampling_util, 

22 sed_templates_util, 

23 spectrum_util, 

24 utils, 

25) 

26from galsbi.ucat.filters_util import UseShortFilterNames 

27from galsbi.ucat.galaxy_population_models.galaxy_luminosity_function import ( 

28 initialize_luminosity_functions, 

29) 

30from galsbi.ucat.galaxy_population_models.galaxy_position import sample_position_uniform 

31from galsbi.ucat.galaxy_population_models.galaxy_sed import ( 

32 sample_template_coeff_lumfuncs, 

33) 

34 

# Module-level logger for this plugin.
LOGGER = logger.get_logger(__file__)

# Stride used when building the luminosity-function seeds in Plugin.__call__:
# it is multiplied by the pixel loop index and added to the seed, so every
# healpix pixel in the loop gets a different seed.
SEED_OFFSET_LUMFUN = 123_491

# NOTE: this modifies the process-wide warnings filter at import time; with
# "once", only the first occurrence of each matching warning is shown.
warnings.filterwarnings("once")

38 

39 

class ExtinctionMapEvaluator:
    """
    Evaluate the colour excess (``excess_b_v``) at galaxy positions.

    If ``par.extinction_map_file_name`` is set, the first field of that
    healpix map is read in NESTED ordering (path resolved relative to
    ``par.maps_remote_dir``). Otherwise every position gets zero extinction.
    """

    def __init__(self, par):
        map_file = par.extinction_map_file_name
        if map_file is None:
            self.extinction_map = None
            return

        map_path = io_util.get_abs_path(map_file, root_path=par.maps_remote_dir)
        self.extinction_map = hp.read_map(map_path, nest=True, field=0)

    def __call__(self, wcs, x, y):
        """
        Return extinction values at the given positions.

        :param wcs: WCS used to convert pixel coordinates (x, y) to RA/Dec, or
            None when x and y already are RA/Dec in degrees.
        :param x: x pixel coordinates, or RA in degrees if ``wcs`` is None.
        :param y: y pixel coordinates, or Dec in degrees if ``wcs`` is None.
        :return: map values interpolated at the positions (galactic
            coordinates), or ``np.zeros_like(x)`` when no map is loaded.
        """
        if self.extinction_map is None:
            return np.zeros_like(x)

        if wcs is None:
            ra, dec = x, y
        else:
            ra, dec = coordinate_util.xy2radec(wcs, x, y)

        # The map is indexed in galactic coordinates
        galactic = SkyCoord(ra=ra, dec=dec, frame="icrs", unit="deg").galactic
        theta, phi = coordinate_util.radec2thetaphi(galactic.l.deg, galactic.b.deg)
        return hp.get_interp_val(self.extinction_map, theta=theta, phi=phi, nest=True)

75 

76 

def get_magnitude_calculator_direct(filter_names, par):
    """
    Build a magnitude calculator that integrates the SED templates directly.

    :param filter_names: short filter names; each is mapped to its full name
        through ``par.filters_full_names``.
    :param par: parameters providing ``maps_remote_dir``,
        ``filters_file_name``, ``templates_file_name`` and
        ``filters_full_names``.
    :return: a ``MagCalculatorDirect`` wrapped in ``UseShortFilterNames``
        together with ``par.filters_full_names``.
    """
    remote_dir = par.maps_remote_dir
    full_names = [par.filters_full_names[name] for name in filter_names]

    # TODO: par should be the path to the filters file, either change it here or there
    filters = filters_util.load_filters(
        os.path.join(remote_dir, par.filters_file_name),
        filter_names=full_names,
        lam_scale=1e-4,
    )

    # Templates are loaded with the units:
    #   lambda in micrometer, SED in erg/s/m2/Å
    templates = sed_templates_util.load_template_spectra(
        os.path.join(remote_dir, par.templates_file_name),
        lam_scale=1e-4,
        amp_scale=1e4,
    )

    from galsbi.ucat.magnitude_calculator import MagCalculatorDirect

    calculator = MagCalculatorDirect(filters, templates)
    return UseShortFilterNames(calculator, par.filters_full_names)

102 

103 

def get_magnitude_calculator_table(filter_names, par):
    """
    Build a magnitude calculator backed by pre-computed integration tables.

    :param filter_names: short filter names; each is mapped to its full name
        through ``par.filters_full_names``.
    :param par: parameters providing ``maps_remote_dir``,
        ``templates_int_tables_file_name``, ``filters_full_names`` and
        ``copy_template_int_tables_to_cwd``.
    :return: a ``MagCalculatorTable`` wrapped in ``UseShortFilterNames``
        together with ``par.filters_full_names``.
    """
    table_path = os.path.join(par.maps_remote_dir, par.templates_int_tables_file_name)
    full_names = [par.filters_full_names[name] for name in filter_names]

    from galsbi.ucat.magnitude_calculator import MagCalculatorTable

    calculator = MagCalculatorTable(
        full_names,
        table_path,
        copy_to_cwd=par.copy_template_int_tables_to_cwd,
    )
    return UseShortFilterNames(calculator, par.filters_full_names)

124 

125 

# Maps the value of ``par.magnitude_calculation`` to the factory building the
# magnitude calculator; factories are called as ``factory(filter_names=..., par=...)``.
MAGNITUDES_CALCULATOR = dict(
    direct=get_magnitude_calculator_direct,
    table=get_magnitude_calculator_table,
)

130 

131 

class Plugin(BasePlugin):
    """
    Generate a random catalog of galaxies with magnitudes in multiple bands.

    Sampling runs for every galaxy type in ``par.galaxy_types`` and every
    healpix pixel of the sampling footprint:

    - absolute magnitudes and redshifts are drawn from the luminosity
      function;
    - positions are drawn with ``sample_position_uniform`` for that pixel;
    - template coefficients and extinction are assigned.

    Galaxies outside ``[par.gals_mag_min, par.gals_mag_max]`` in
    ``par.reference_band`` are rejected. In ``"wcs"`` mode, galaxies outside
    the image are rejected too. The catalog is stored in ``self.ctx.galaxies``.
    """

    def check_n_gal_prior(self, par):
        """
        Check if the number of galaxies is inside the prior range,
        even before rendering, to remove extreme values.

        Only runs when ``par.galaxy_count_prior`` is set and not None. The
        check proceeds as follows:

        - count galaxies with intrinsic magnitude below
          ``galaxy_count_prior["mag_max"]`` in
          ``galaxy_count_prior["band"]``;
        - divide the count by ``par.ngal_multiplier``;
        - compare the result with ``[n_min, n_max]``.

        :param par: ctx.parameters
        :raises galaxy_sampling_util.UCatNumGalError: if the rescaled count is
            outside the allowed range.
        """
        prior = getattr(par, "galaxy_count_prior", None)
        if prior is None:
            return

        app_mag = self.ctx.galaxies.int_magnitude_dict[prior["band"]]
        n_gal_ = np.count_nonzero(app_mag < prior["mag_max"]) / par.ngal_multiplier
        LOGGER.info(
            f"Allowed number galaxies with int_mag<{prior['mag_max']} per tile"
            f" [{prior['n_min']},{prior['n_max']}], computed number: {n_gal_}"
        )
        if (n_gal_ < prior["n_min"]) or (n_gal_ > prior["n_max"]):
            raise galaxy_sampling_util.UCatNumGalError("too many or too few galaxies")

    def check_max_mem_error(self, par):
        """
        Check if the catalog does not exceed allowed memory.
        Prevents job crashes on clusters.

        :param par: ctx.parameters, providing ``max_memlimit_gal_catalog``
            (compared with ``utils.memory_usage_psutil()``).
        :raises galaxy_sampling_util.UCatNumGalError: if the current memory
            usage exceeds the limit.
        """
        mem_mb_current = utils.memory_usage_psutil()
        if mem_mb_current > par.max_memlimit_gal_catalog:
            raise galaxy_sampling_util.UCatNumGalError(
                "The sample_galaxies process is taking too much memory:"
                f" mem_mb_current={mem_mb_current},"
                f" max_mem_hard_limit={par.max_memlimit_gal_catalog}"
            )

    def _init_sampling_footprint(self, par):
        """
        Set ``self.ctx.pixels`` and ``self.ctx.pixarea`` for the sampling mode.

        :param par: ctx.parameters
        :return: the WCS built from ``par`` in ``"wcs"`` mode, None in
            ``"healpix"`` mode.
        :raises ValueError: for any other ``par.sampling_mode``.
        """
        if par.sampling_mode == "wcs":
            LOGGER.info("Sampling galaxies based on RA/DEC and pixel scale")
            # Healpix pixelization of the image footprint
            w = coordinate_util.wcs_from_parameters(par)
            self.ctx.pixels = coordinate_util.get_healpix_pixels(
                par.nside_sampling, w, par.size_x, par.size_y
            )
            if len(self.ctx.pixels) < 15:
                LOGGER.warning(
                    f"Only {len(self.ctx.pixels)} healpy pixels in the footprint,"
                    " consider increasing the nside_sampling"
                )
        elif par.sampling_mode == "healpix":
            LOGGER.info("Sampling galaxies based on healpix pixels")
            self.ctx.pixels = coordinate_util.get_healpix_pixels_from_map(par)
            w = None
        else:
            raise ValueError(
                f"Unknown sampling mode: {par.sampling_mode}, must be wcs or healpix"
            )
        self.ctx.pixarea = hp.nside2pixarea(par.nside_sampling, degrees=False)
        return w

    def _set_blue_red(self, lum_funcs):
        """
        Set the legacy ``blue_red`` column.

        The column is 1 for blue galaxies, 0 for red galaxies, and 1 for any
        other type. A type missing from ``lum_funcs`` is skipped instead of
        raising ``KeyError``, so runs that sample only some types still work.
        """
        galaxies = self.ctx.galaxies
        galaxies.blue_red = np.ones(len(galaxies.z), dtype=np.ushort)
        for type_name, flag in (("blue", 1), ("red", 0)):
            try:
                type_id = lum_funcs[type_name].galaxy_type
            except KeyError:
                continue
            galaxies.blue_red[galaxies.galaxy_type == type_id] = flag

    @staticmethod
    def _release_magnitude_calculator(mag_calc):
        """
        Best-effort removal of the references held by ``mag_calc.func`` and of
        ``mag_calc.func`` itself, so the data can be garbage collected.

        Each attribute is removed independently. A calculator that lacks one
        of them still has the remaining references dropped (the original
        single ``try`` block stopped at the first missing attribute).
        """
        func = getattr(mag_calc, "func", None)
        if func is None:
            return
        for attr in ("templates_int_table_dict", "z_grid", "excess_b_v_grid"):
            try:
                delattr(func, attr)
            except AttributeError:
                pass
        try:
            del mag_calc.func
        except AttributeError:
            pass

    def __call__(self):
        par = self.ctx.parameters

        # Cosmology
        cosmo = PyCosmo.build()
        cosmo.set(h=par.h, omega_m=par.omega_m)

        # Healpix pixels to sample from; w is None in healpix mode
        w = self._init_sampling_footprint(par)

        # Magnitude calculator
        all_filters = np.unique(par.filters + [par.lum_fct_filter_band])

        # backward compatibility - check if full filter names are set
        if not hasattr(par, "filters_full_names"):
            par.filters_full_names = filters_util.get_default_full_filter_names(
                all_filters
            )
            warnings.warn(
                "setting filters to default, this will cause problems if you work"
                " with filters from different cameras in the same band",
                stacklevel=1,
            )

        # get magnitude calculators (reload cache to avoid memory leaks and save memory)
        mag_calc = MAGNITUDES_CALCULATOR[par.magnitude_calculation](
            filter_names=all_filters, par=par
        )
        n_templates = mag_calc.n_templates
        # Cut in z - M - plane & boundaries
        z_m_intp = galaxy_sampling_util.intp_z_m_cut(cosmo, mag_calc, par)

        # Position columns: pixel coordinates in wcs mode, RA/Dec in healpix mode
        pos_cols = ["x", "y"] if w is not None else ["ra", "dec"]
        # Columns filled per pixel inside the loop and concatenated afterwards
        loop_cols = [
            "z",
            "template_coeffs",
            "template_coeffs_abs",
            "abs_mag_lumfun",
            "galaxy_type",
            "excess_b_v",
        ] + pos_cols

        # Initialize galaxy catalog
        self.ctx.galaxies = galaxy_sampling_util.Catalog()
        galaxies = self.ctx.galaxies
        galaxies.columns = ["id"] + loop_cols
        for c in loop_cols:
            setattr(galaxies, c, [])

        # set up luminosity functions
        lum_funcs = initialize_luminosity_functions(
            par, cosmo=cosmo, pixarea=self.ctx.pixarea, z_m_intp=z_m_intp
        )

        # Extinction
        extinction_eval = ExtinctionMapEvaluator(par)

        prec = par.catalog_precision
        for g in par.galaxy_types:
            n_gal_type = 0
            n_gal_type_max = getattr(par, f"n_gal_max_{g}")
            max_reached = False

            for i in LOGGER.progressbar(
                range(len(self.ctx.pixels)),
                desc=f"getting {g:<4s} galaxies for healpix pixels",
                at_level="debug",
            ):
                pixel = self.ctx.pixels[i]

                # Sample absolute mag vs redshift from luminosity function.
                # Seeds depend on the pixel id and on the loop index i.
                abs_mag, z = lum_funcs[g].sample_z_mabs_and_apply_cut(
                    seed_ngal=par.seed
                    + pixel
                    + par.gal_num_seed_offset
                    + SEED_OFFSET_LUMFUN * i,
                    seed_lumfun=par.seed
                    + pixel
                    + par.gal_lum_fct_seed_offset
                    + SEED_OFFSET_LUMFUN * i,
                    n_gal_max=n_gal_type_max,
                )

                # Positions: x and y for wcs, ra and dec for healpix mode
                np.random.seed(par.seed + pixel + par.gal_dist_seed_offset)
                x, y = sample_position_uniform(len(z), w, pixel, par.nside_sampling)

                # Make the catalog precision already here to avoid
                # inconsistencies in the selection
                x = x.astype(prec)
                y = y.astype(prec)

                (
                    template_coeffs,
                    template_coeffs_abs,
                    excess_b_v,
                    app_mag_ref,
                ) = compute_templates_extinction_appmag_for_galaxies(
                    galaxy_type=g,
                    par=par,
                    n_templates=n_templates,
                    cosmo=cosmo,
                    w=w,
                    redshifts=z,
                    absmags=abs_mag,
                    x_pixel=x,
                    y_pixel=y,
                    mag_calc=mag_calc,
                    extinction_eval=extinction_eval,
                )

                # Reject galaxies outside the magnitude range and, in wcs
                # mode, outside the image
                select = (app_mag_ref >= par.gals_mag_min) & (
                    app_mag_ref <= par.gals_mag_max
                )
                if w is not None:
                    select = select & in_pos(x, y, par)
                n_gal = np.count_nonzero(select)
                n_gal_type += n_gal

                # store (x and y are already in catalog precision)
                if w is not None:
                    galaxies.x.append(x[select])
                    galaxies.y.append(y[select])
                else:
                    galaxies.ra.append(x[select])
                    galaxies.dec.append(y[select])
                galaxies.z.append(z[select].astype(prec))
                galaxies.template_coeffs.append(template_coeffs[select].astype(prec))
                galaxies.template_coeffs_abs.append(
                    template_coeffs_abs[select].astype(prec)
                )
                galaxies.abs_mag_lumfun.append(abs_mag[select].astype(prec))
                galaxies.galaxy_type.append(
                    np.ones(n_gal, dtype=np.ushort) * lum_funcs[g].galaxy_type
                )
                galaxies.excess_b_v.append(excess_b_v[select].astype(prec))

                # see if number of galaxies is OK
                if n_gal_type > n_gal_type_max * par.ngal_multiplier:
                    max_reached = True
                    if par.raise_max_num_gal_error:
                        raise galaxy_sampling_util.UCatNumGalError(
                            "exceeded number of"
                            f" {g} galaxies {n_gal_type}>{n_gal_type_max}"
                        )
                    break

            LOGGER.info(
                f"lumfun={g} n_gals={n_gal_type} maximum number of galaxies"
                f" reached={max_reached} ({n_gal_type_max})"
            )

            # check memory footprint
            self.check_max_mem_error(par)

        # Concatenate columns
        for c in loop_cols:
            setattr(galaxies, c, np.concatenate(getattr(galaxies, c)))

        # Calculate requested intrinsic apparent and absolute magnitudes
        galaxies.int_magnitude_dict = mag_calc(
            redshifts=galaxies.z,
            excess_b_v=galaxies.excess_b_v,
            coeffs=galaxies.template_coeffs,
            filter_names=par.filters,
        )
        galaxies.abs_magnitude_dict = mag_calc(
            redshifts=np.zeros_like(galaxies.z),
            excess_b_v=np.zeros_like(galaxies.excess_b_v),
            coeffs=galaxies.template_coeffs_abs,
            filter_names=par.filters,
        )

        # Raise error if the number of galaxies per tile is too high or too low
        self.check_n_gal_prior(par)

        # Set apparent (lensed) magnitudes, for now equal to intrinsic apparent
        # magnitudes
        galaxies.magnitude_dict = {
            band: mag.copy() for band, mag in galaxies.int_magnitude_dict.items()
        }

        # Number of galaxies and id
        self.ctx.numgalaxies = galaxies.z.size
        galaxies.id = np.arange(self.ctx.numgalaxies)

        # Backward compatibility
        self._set_blue_red(lum_funcs)

        LOGGER.info(
            f"galaxy counts n_total={self.ctx.numgalaxies}"
            f" mem_mb_current={utils.memory_usage_psutil():5.1f}"
        )

        if par.save_SEDs:
            # Store SEDs in the catalog
            restframe_wavelength, seds = get_seds(par, galaxies)
            self.ctx.restframe_wavelength_for_SED = restframe_wavelength
            galaxies.sed = seds

        self._release_magnitude_calculator(mag_calc)

    def __str__(self):
        return "sample gal photo"

445 

446 

def compute_templates_extinction_appmag_for_galaxies(
    galaxy_type,
    par,
    n_templates,
    cosmo,
    w,
    redshifts,
    absmags,
    x_pixel,
    y_pixel,
    mag_calc,
    extinction_eval,
):
    """
    Sample template coefficients for galaxies of one type and derive the
    quantities needed for the magnitude selection.

    :param galaxy_type: galaxy population name (key passed to
        ``sample_template_coeff_lumfuncs``).
    :param par: ctx.parameters
    :param n_templates: number of SED templates.
    :param cosmo: PyCosmo cosmology, used for the luminosity distance.
    :param w: WCS forwarded to ``extinction_eval``; None when x/y are RA/Dec.
    :param redshifts: galaxy redshifts.
    :param absmags: absolute magnitudes in ``par.lum_fct_filter_band``.
    :param x_pixel: x coordinates (or RA when ``w`` is None).
    :param y_pixel: y coordinates (or Dec when ``w`` is None).
    :param mag_calc: magnitude calculator.
    :param extinction_eval: callable ``(w, x, y) -> excess_b_v``.
    :return: tuple ``(template_coeffs, template_coeffs_abs, excess_b_v,
        app_mag_ref)``; ``excess_b_v`` is always at least 1-d and
        ``app_mag_ref`` is the apparent magnitude in ``par.reference_band``.
    """
    template_coeffs_abs = sample_template_coeff_lumfuncs(
        par=par, redshift_z={galaxy_type: redshifts}, n_templates=n_templates
    )[galaxy_type]

    # Calculate absolute magnitudes according to coefficients and adjust
    # coefficients so the z=0 magnitude in the luminosity-function band equals
    # the drawn absolute magnitude
    mag_z0 = mag_calc(
        redshifts=np.zeros_like(redshifts),
        excess_b_v=np.zeros_like(redshifts),
        coeffs=template_coeffs_abs,
        filter_names=[par.lum_fct_filter_band],
    )
    template_coeffs_abs *= np.expand_dims(
        10 ** (0.4 * (mag_z0[par.lum_fct_filter_band] - absmags)), -1
    )

    # Transform to apparent coefficients. 10e-6 == 1e-5, i.e. 10 pc if the
    # luminosity distance is in Mpc (assumed for PyCosmo's dist_lum_a)
    lum_dist = galaxy_sampling_util.apply_pycosmo_distfun(
        cosmo.background.dist_lum_a, redshifts
    )
    template_coeffs = template_coeffs_abs * np.expand_dims(
        (10e-6 / lum_dist) ** 2 / (1 + redshifts), -1
    )

    # Make sure excess_b_v is a 1-d array even for a single galaxy. The former
    # ``np.array([excess_b_v])`` wrap turned an already 1-element array (e.g.
    # the zeros returned without an extinction map) into shape (1, 1), which
    # broke the later selection and concatenation.
    excess_b_v = np.atleast_1d(extinction_eval(w, x_pixel, y_pixel))

    # Calculate apparent reference band magnitude
    app_mag_ref = mag_calc(
        redshifts=redshifts,
        excess_b_v=excess_b_v,
        coeffs=template_coeffs,
        filter_names=[par.reference_band],
    )[par.reference_band]

    return template_coeffs, template_coeffs_abs, excess_b_v, app_mag_ref

499 

500 

def in_pos(x, y, par):
    """
    Return a boolean mask of positions strictly inside the image.

    A position is inside when ``0 < x < par.size_x`` and ``0 < y < par.size_y``
    (both bounds exclusive).
    """
    inside_x = (x > 0) & (x < par.size_x)
    inside_y = (y > 0) & (y < par.size_y)
    return inside_x & inside_y

503 

504 

def get_seds(par, galaxies):
    """
    Get SEDs for galaxies in the catalog.

    Each galaxy's spectrum is built from the direct magnitude calculator's
    templates, as follows:

    - the template wavelengths are redshifted by ``(1 + z)``;
    - the templates are combined with the galaxy's ``template_coeffs``;
    - the result is reddened with its ``excess_b_v``.

    :param par: ctx.parameters (used to load filters and templates).
    :param galaxies: catalog providing ``z``, ``template_coeffs`` and
        ``excess_b_v``.
    :return: tuple ``(restframe_wavelength, seds)``:

        - ``restframe_wavelength`` is the template wavelength grid in
          Angstrom;
        - ``seds`` stacks one spectrum per galaxy (catalog order) in
          erg/s/cm2/Å;
        - for an empty catalog ``seds`` has zero rows instead of raising.
    """
    direct_mag_calc = get_magnitude_calculator_direct(filter_names=par.filters, par=par)
    # Loop-invariant template data
    sed_templates = direct_mag_calc.sed_templates
    lam_in_mu_m = sed_templates["lam"]  # rest-frame wavelengths in micrometer
    templates_amp = sed_templates["amp"]
    extinction_spline = direct_mag_calc.extinction_spline

    n_obj = galaxies.z.size
    seds = []
    for i in LOGGER.progressbar(
        range(n_obj),
        desc="getting SEDs",
        at_level="debug",
    ):
        lam_obs_in_mu_m = lam_in_mu_m * (1 + galaxies.z[i])  # in micrometer
        spec = spectrum_util.construct_reddened_spectrum(
            lam_obs=lam_obs_in_mu_m,
            templates_amp=templates_amp,
            coeff=galaxies.template_coeffs[i],
            excess_b_v=galaxies.excess_b_v[i],
            extinction_spline=extinction_spline,
        ).flatten()  # in erg/s/m2/Å
        # save in erg/s/cm2/Å
        seds.append(spec / 1e4)

    if seds:
        seds = np.vstack(seds)
    else:
        # np.vstack raises on an empty list. Return zero rows instead; this
        # assumes one SED value per template wavelength (TODO confirm).
        seds = np.empty((0, np.size(lam_in_mu_m)))
    # wavelengths converted from micrometer to Angstrom
    return lam_in_mu_m * 1e4, seds