Coverage for src/galsbi/ucat/plugins/sample_galaxies_photo.py: 98%

169 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-10 11:12 +0000

1# Copyright (C) 2018 ETH Zurich, Institute for Particle Physics and Astrophysics 

2 

3""" 

4Created on Mar 5, 2018 

5author: Joerg Herbel 

6""" 

7 

8import os 

9import warnings 

10 

11import healpy as hp 

12import numpy as np 

13import PyCosmo 

14from astropy.coordinates import SkyCoord 

15from cosmic_toolbox import logger 

16from ivy.plugin.base_plugin import BasePlugin 

17from ufig import coordinate_util, io_util 

18 

19from galsbi.ucat import filters_util, galaxy_sampling_util, sed_templates_util, utils 

20from galsbi.ucat.filters_util import UseShortFilterNames 

21from galsbi.ucat.galaxy_population_models.galaxy_luminosity_function import ( 

22 initialize_luminosity_functions, 

23) 

24from galsbi.ucat.galaxy_population_models.galaxy_position import sample_position_uniform 

25from galsbi.ucat.galaxy_population_models.galaxy_sed import ( 

26 sample_template_coeff_lumfuncs, 

27) 

28 

# Module-level logger, named after this file.
LOGGER = logger.get_logger(__file__)

# Multiplied by the index of the healpix pixel in the sampling loop and added
# to the luminosity-function seeds.
SEED_OFFSET_LUMFUN = 123_491

# NOTE: this installs a process-wide filter with action "once" for all
# warning categories as soon as the module is imported.
warnings.filterwarnings("once")

32 

33 

class ExtinctionMapEvaluator:
    """
    Callable that returns extinction values (``excess_b_v``) for positions.

    If ``par.extinction_map_file_name`` is ``None`` no map is loaded and the
    evaluator returns zeros for every position.
    """

    def __init__(self, par):
        """
        Load the extinction map, if one is configured.

        :param par: parameter object; ``extinction_map_file_name`` is read, and
            ``maps_remote_dir`` is used as root path when a file name is set
        """
        map_name = par.extinction_map_file_name
        if map_name is None:
            self.extinction_map = None
            return

        map_path = io_util.get_abs_path(map_name, root_path=par.maps_remote_dir)
        # Only the first field is read, in NESTED ordering.
        self.extinction_map = hp.read_map(map_path, nest=True, field=0)

    def __call__(self, wcs, x, y):
        """
        Evaluate the extinction at the given positions.

        :param wcs: WCS used to convert ``x``/``y`` to RA/Dec; if ``None``,
            ``x`` and ``y`` are taken to be RA and Dec directly
        :param x: first coordinate (pixel x, or RA if ``wcs`` is ``None``)
        :param y: second coordinate (pixel y, or Dec if ``wcs`` is ``None``)
        :return: interpolated map values, or zeros shaped like ``x`` when no
            map is loaded
        """
        if self.extinction_map is None:
            return np.zeros_like(x)

        if wcs is None:
            ra, dec = x, y
        else:
            ra, dec = coordinate_util.xy2radec(wcs, x, y)

        # The map is interpolated at galactic coordinates, so convert from
        # ICRS (degrees) to galactic longitude / latitude first.
        galactic = SkyCoord(ra=ra, dec=dec, frame="icrs", unit="deg").galactic
        theta, phi = coordinate_util.radec2thetaphi(galactic.l.deg, galactic.b.deg)
        return hp.get_interp_val(self.extinction_map, theta=theta, phi=phi, nest=True)

69 

70 

def get_magnitude_calculator_direct(filter_names, par):
    """
    Build a magnitude calculator that uses the direct magnitude calculation.

    :param filter_names: short filter names; each one is looked up in
        ``par.filters_full_names``
    :param par: parameter object providing ``filters_full_names``,
        ``maps_remote_dir``, ``filters_file_name`` and ``templates_file_name``
    :return: ``MagCalculatorDirect`` wrapped in ``UseShortFilterNames``
    """

    full_names = [par.filters_full_names[name] for name in filter_names]

    # TODO: par should be the path to the filters file, either change it here or there
    filters_path = os.path.join(par.maps_remote_dir, par.filters_file_name)
    loaded_filters = filters_util.load_filters(
        filters_path, filter_names=full_names, lam_scale=1e-4
    )

    templates_path = os.path.join(par.maps_remote_dir, par.templates_file_name)
    loaded_templates = sed_templates_util.load_template_spectra(
        templates_path, lam_scale=1e-4, amp_scale=1e4
    )

    # Imported locally, as in the rest of this module.
    from galsbi.ucat.magnitude_calculator import MagCalculatorDirect

    calculator = MagCalculatorDirect(loaded_filters, loaded_templates)
    return UseShortFilterNames(calculator, par.filters_full_names)

93 

94 

def get_magnitude_calculator_table(filter_names, par):
    """
    Build a magnitude calculator that uses pre-computed tables.

    :param filter_names: short filter names; each one is looked up in
        ``par.filters_full_names``
    :param par: parameter object providing ``filters_full_names``,
        ``maps_remote_dir``, ``templates_int_tables_file_name`` and
        ``copy_template_int_tables_to_cwd``
    :return: ``MagCalculatorTable`` wrapped in ``UseShortFilterNames``
    """

    tables_path = os.path.join(par.maps_remote_dir, par.templates_int_tables_file_name)
    full_names = [par.filters_full_names[name] for name in filter_names]

    # Imported locally, as in the rest of this module.
    from galsbi.ucat.magnitude_calculator import MagCalculatorTable

    calculator = MagCalculatorTable(
        full_names,
        tables_path,
        copy_to_cwd=par.copy_template_int_tables_to_cwd,
    )
    return UseShortFilterNames(calculator, par.filters_full_names)

115 

116 

# Dispatch table: name of the magnitude calculation mode -> factory that
# builds the corresponding magnitude calculator from (filter_names, par).
MAGNITUDES_CALCULATOR = dict(
    direct=get_magnitude_calculator_direct,
    table=get_magnitude_calculator_table,
)

121 

122 

class Plugin(BasePlugin):
    """
    Generate a random catalog of galaxies with magnitudes in multiple bands.

    The catalog is stored in ``self.ctx.galaxies``; the healpix pixels used for
    sampling, their area and the number of galaxies are stored in
    ``self.ctx.pixels``, ``self.ctx.pixarea`` and ``self.ctx.numgalaxies``.
    """

    def check_n_gal_prior(self, par):
        """
        Check if the number of galaxies is inside the prior range,
        even before rendering, to remove extreme values.

        Does nothing if ``par`` has no ``galaxy_count_prior`` attribute.

        :param par: parameter object; ``galaxy_count_prior`` is a mapping with
            the keys ``band``, ``mag_max``, ``n_min`` and ``n_max``
        :raises galaxy_sampling_util.UCatNumGalError: if the count, divided by
            ``par.ngal_multiplier``, lies outside ``[n_min, n_max]``
        """

        if not hasattr(par, "galaxy_count_prior"):
            return

        prior = par.galaxy_count_prior
        app_mag = self.ctx.galaxies.int_magnitude_dict[prior["band"]]
        n_gal_ = np.count_nonzero(app_mag < prior["mag_max"])
        n_gal_ /= par.ngal_multiplier
        LOGGER.info(
            f"Allowed number galaxies with int_mag<{prior['mag_max']} per tile"
            f" [{prior['n_min']},{prior['n_max']}], computed number: {n_gal_}"
        )
        if (n_gal_ < prior["n_min"]) or (n_gal_ > prior["n_max"]):
            raise galaxy_sampling_util.UCatNumGalError("too many or too few galaxies")

    def check_max_mem_error(self, par, max_mem_hard_limit=10000):
        """
        Check if the catalog does not exceed allowed memory.
        Prevents job crashes on clusters.

        :param par: parameter object (not used by this check)
        :param max_mem_hard_limit: upper limit for the current memory usage
            (presumably in MB, as the variable names suggest)
        :raises galaxy_sampling_util.UCatNumGalError: if the current memory
            usage exceeds ``max_mem_hard_limit``
        """

        mem_mb_current = utils.memory_usage_psutil()
        if mem_mb_current > max_mem_hard_limit:
            raise galaxy_sampling_util.UCatNumGalError(
                "The sample_galaxies process is taking too much memory:"
                f" mem_mb_current={mem_mb_current},"
                f" max_mem_hard_limit={max_mem_hard_limit}"
            )

    def __call__(self):
        """
        Sample the galaxy catalog and store it in the context.

        :raises ValueError: if ``par.sampling_mode`` is neither ``wcs`` nor
            ``healpix``
        :raises galaxy_sampling_util.UCatNumGalError: if too many galaxies are
            sampled (and ``par.raise_max_num_gal_error`` is set), if the memory
            limit is exceeded, or if the galaxy count prior is violated
        """
        par = self.ctx.parameters

        # Cosmology
        cosmo = PyCosmo.build()
        cosmo.set(h=par.h, omega_m=par.omega_m)

        if par.sampling_mode == "wcs":
            LOGGER.info("Sampling galaxies based on RA/DEC and pixel scale")
            # Healpix pixelization
            w = coordinate_util.wcs_from_parameters(par)
            self.ctx.pixels = coordinate_util.get_healpix_pixels(
                par.nside_sampling, w, par.size_x, par.size_y
            )
            if len(self.ctx.pixels) < 15:
                LOGGER.warning(
                    f"Only {len(self.ctx.pixels)} healpy pixels in the footprint,"
                    " consider increasing the nside_sampling"
                )
        elif par.sampling_mode == "healpix":
            LOGGER.info("Sampling galaxies based on healpix pixels")
            self.ctx.pixels = coordinate_util.get_healpix_pixels_from_map(par)
            w = None
        else:
            raise ValueError(
                f"Unknown sampling mode: {par.sampling_mode}, must be wcs or healpix"
            )
        self.ctx.pixarea = hp.nside2pixarea(par.nside_sampling, degrees=False)

        # Magnitude calculator
        all_filters = np.unique(par.filters + [par.lum_fct_filter_band])

        # backward compatibility - check if full filter names are set
        if not hasattr(par, "filters_full_names"):
            par.filters_full_names = filters_util.get_default_full_filter_names(
                all_filters
            )
            warnings.warn(
                "setting filters to default, this will cause problems if you work"
                " with filters from different cameras in the same band",
                stacklevel=1,
            )

        # get magnitude calculators (reload cache to avoid memory leaks and save memory)
        mag_calc = MAGNITUDES_CALCULATOR[par.magnitude_calculation](
            filter_names=all_filters, par=par
        )
        n_templates = mag_calc.n_templates
        # Cut in z - M - plane & boundaries
        z_m_intp = galaxy_sampling_util.intp_z_m_cut(cosmo, mag_calc, par)

        # Initialize galaxy catalog
        self.ctx.galaxies = galaxy_sampling_util.Catalog()
        self.ctx.galaxies.columns = [
            "id",
            "z",
            "template_coeffs",
            "template_coeffs_abs",
            "abs_mag_lumfun",
            "galaxy_type",
            "excess_b_v",
        ]

        # Columns modified inside loop
        loop_cols = [
            "z",
            "template_coeffs",
            "template_coeffs_abs",
            "abs_mag_lumfun",
            "galaxy_type",
            "excess_b_v",
        ]
        # Positions are stored as pixel coordinates in wcs mode and as
        # sky coordinates in healpix mode.
        if w is not None:
            loop_cols += ["x", "y"]
            self.ctx.galaxies.columns += ["x", "y"]
        else:
            loop_cols += ["ra", "dec"]
            self.ctx.galaxies.columns += ["ra", "dec"]
        # Each loop column starts as a list of per-pixel arrays that is
        # concatenated after sampling.
        for c in loop_cols:
            setattr(self.ctx.galaxies, c, [])

        # set up luminosity functions
        lum_funcs = initialize_luminosity_functions(
            par, cosmo=cosmo, pixarea=self.ctx.pixarea, z_m_intp=z_m_intp
        )

        # Extinction
        extinction_eval = ExtinctionMapEvaluator(par)

        # Sample every galaxy type, healpix pixel by healpix pixel
        for g in par.galaxy_types:
            n_gal_type = 0
            n_gal_type_max = getattr(par, f"n_gal_max_{g}")
            max_reached = False

            for i in LOGGER.progressbar(
                range(len(self.ctx.pixels)),
                desc=f"getting {g:<4s} galaxies for healpix pixels",
                at_level="debug",
            ):
                pixel = self.ctx.pixels[i]

                # Sample absolute mag vs redshift from luminosity function
                abs_mag, z = lum_funcs[g].sample_z_mabs_and_apply_cut(
                    seed_ngal=par.seed
                    + pixel
                    + par.gal_num_seed_offset
                    + SEED_OFFSET_LUMFUN * i,
                    seed_lumfun=par.seed
                    + pixel
                    + par.gal_lum_fct_seed_offset
                    + SEED_OFFSET_LUMFUN * i,
                    n_gal_max=n_gal_type_max,
                )

                # Positions
                np.random.seed(par.seed + pixel + par.gal_dist_seed_offset)

                # x and y for wcs, ra and dec for healpix model
                x, y = sample_position_uniform(len(z), w, pixel, par.nside_sampling)

                # Make the catalog precision already here to avoid inconsistencies in
                # the selection
                x = x.astype(par.catalog_precision)
                y = y.astype(par.catalog_precision)

                (
                    template_coeffs,
                    template_coeffs_abs,
                    excess_b_v,
                    app_mag_ref,
                ) = compute_templates_extinction_appmag_for_galaxies(
                    galaxy_type=g,
                    par=par,
                    n_templates=n_templates,
                    cosmo=cosmo,
                    w=w,
                    redshifts=z,
                    absmags=abs_mag,
                    x_pixel=x,
                    y_pixel=y,
                    mag_calc=mag_calc,
                    extinction_eval=extinction_eval,
                )

                # Reject galaxies outside set magnitude range
                select_mag_range = (app_mag_ref >= par.gals_mag_min) & (
                    app_mag_ref <= par.gals_mag_max
                )
                if w is not None:
                    select_pos_range = in_pos(x, y, par)
                else:
                    select_pos_range = np.ones_like(x, dtype=bool)
                select = select_mag_range & select_pos_range
                n_gal = np.count_nonzero(select)
                n_gal_type += n_gal

                # store (x and y already have the catalog precision, and the
                # boolean selection returns a new array, so no second cast)
                if w is not None:
                    self.ctx.galaxies.x.append(x[select])
                    self.ctx.galaxies.y.append(y[select])
                else:
                    self.ctx.galaxies.ra.append(x[select])
                    self.ctx.galaxies.dec.append(y[select])
                self.ctx.galaxies.z.append(z[select].astype(par.catalog_precision))
                self.ctx.galaxies.template_coeffs.append(
                    template_coeffs[select].astype(par.catalog_precision)
                )
                self.ctx.galaxies.template_coeffs_abs.append(
                    template_coeffs_abs[select].astype(par.catalog_precision)
                )
                self.ctx.galaxies.abs_mag_lumfun.append(
                    abs_mag[select].astype(par.catalog_precision)
                )
                self.ctx.galaxies.galaxy_type.append(
                    np.ones(n_gal, dtype=np.ushort) * lum_funcs[g].galaxy_type
                )
                self.ctx.galaxies.excess_b_v.append(
                    excess_b_v[select].astype(par.catalog_precision)
                )

                # see if number of galaxies is OK
                if n_gal_type > n_gal_type_max * par.ngal_multiplier:
                    max_reached = True
                    if par.raise_max_num_gal_error:
                        raise galaxy_sampling_util.UCatNumGalError(
                            "exceeded number of"
                            f" {g} galaxies {n_gal_type}>{n_gal_type_max}"
                        )
                    else:
                        # Stop sampling this galaxy type, keep what we have
                        break

            LOGGER.info(
                f"lumfun={g} n_gals={n_gal_type} maximum number of galaxies"
                f" reached={max_reached} ({n_gal_type_max})"
            )

        # Common code shared between clustering and uniform position methods

        # check memory footprint
        self.check_max_mem_error(par)

        # Concatenate columns
        for c in loop_cols:
            setattr(self.ctx.galaxies, c, np.concatenate(getattr(self.ctx.galaxies, c)))

        # Calculate requested intrinsic apparent and absolute magnitudes
        self.ctx.galaxies.int_magnitude_dict = mag_calc(
            redshifts=self.ctx.galaxies.z,
            excess_b_v=self.ctx.galaxies.excess_b_v,
            coeffs=self.ctx.galaxies.template_coeffs,
            filter_names=par.filters,
        )

        self.ctx.galaxies.abs_magnitude_dict = mag_calc(
            redshifts=np.zeros_like(self.ctx.galaxies.z),
            excess_b_v=np.zeros_like(self.ctx.galaxies.excess_b_v),
            coeffs=self.ctx.galaxies.template_coeffs_abs,
            filter_names=par.filters,
        )

        # Raise error if the number of galaxies per tile is too high or too low
        self.check_n_gal_prior(par)

        # Set apparent (lensed) magnitudes, for now equal to intrinsic apparent
        # magnitudes
        self.ctx.galaxies.magnitude_dict = {
            band: mag.copy()
            for band, mag in self.ctx.galaxies.int_magnitude_dict.items()
        }

        # Number of galaxies and id
        self.ctx.numgalaxies = self.ctx.galaxies.z.size
        self.ctx.galaxies.id = np.arange(self.ctx.numgalaxies)

        # Backward compatibility: blue_red is 1 for blue galaxies (and any
        # type that is neither blue nor red) and 0 for red galaxies
        self.ctx.galaxies.blue_red = np.ones(len(self.ctx.galaxies.z), dtype=np.ushort)
        self.ctx.galaxies.blue_red[
            self.ctx.galaxies.galaxy_type == lum_funcs["blue"].galaxy_type
        ] = 1
        self.ctx.galaxies.blue_red[
            self.ctx.galaxies.galaxy_type == lum_funcs["red"].galaxy_type
        ] = 0

        LOGGER.info(
            f"galaxy counts n_total={self.ctx.numgalaxies}"
            f" mem_mb_current={utils.memory_usage_psutil():5.1f}"
        )

        # Best-effort release of the large tables held by the magnitude
        # calculator. Calculators without these attributes are expected, so a
        # missing attribute is ignored; any other error is not hidden.
        try:
            del mag_calc.func.templates_int_table_dict
            del mag_calc.func.z_grid
            del mag_calc.func.excess_b_v_grid
            del mag_calc.func
        except AttributeError:
            pass

    def __str__(self):
        return "sample gal photo"

432 

433 

def compute_templates_extinction_appmag_for_galaxies(
    galaxy_type,
    par,
    n_templates,
    cosmo,
    w,
    redshifts,
    absmags,
    x_pixel,
    y_pixel,
    mag_calc,
    extinction_eval,
):
    """
    Compute template coefficients, extinction and the apparent reference-band
    magnitude for a set of galaxies of one type.

    :param galaxy_type: key of the galaxy type used for sampling coefficients
    :param par: parameter object; ``lum_fct_filter_band`` and
        ``reference_band`` are read here
    :param n_templates: number of templates passed to the coefficient sampling
    :param cosmo: cosmology object providing ``background.dist_lum_a``
    :param w: WCS passed on to ``extinction_eval`` (may be ``None``)
    :param redshifts: redshifts of the galaxies
    :param absmags: absolute magnitudes drawn from the luminosity function
    :param x_pixel: first position coordinate passed to ``extinction_eval``
    :param y_pixel: second position coordinate passed to ``extinction_eval``
    :param mag_calc: magnitude calculator, called with ``redshifts``,
        ``excess_b_v``, ``coeffs`` and ``filter_names``
    :param extinction_eval: callable ``(w, x, y) -> excess_b_v``
    :return: tuple ``(template_coeffs, template_coeffs_abs, excess_b_v,
        app_mag_ref)``
    """
    template_coeffs_abs = sample_template_coeff_lumfuncs(
        par=par, redshift_z={galaxy_type: redshifts}, n_templates=n_templates
    )[galaxy_type]

    # Calculate absolute magnitudes according to coefficients and adjust
    # coefficients according to drawn magnitudes
    mag_z0 = mag_calc(
        redshifts=np.zeros_like(redshifts),
        excess_b_v=np.zeros_like(redshifts),
        coeffs=template_coeffs_abs,
        filter_names=[par.lum_fct_filter_band],
    )

    # Rescale in place so that the coefficients reproduce the drawn magnitudes
    template_coeffs_abs *= np.expand_dims(
        10 ** (0.4 * (mag_z0[par.lum_fct_filter_band] - absmags)), -1
    )

    # Transform to apparent coefficients
    lum_dist = galaxy_sampling_util.apply_pycosmo_distfun(
        cosmo.background.dist_lum_a, redshifts
    )
    # NOTE(review): 10e-6 == 1e-5, presumably 10 pc expressed in Mpc - confirm
    # against the units of lum_dist
    template_coeffs = template_coeffs_abs * np.expand_dims(
        (10e-6 / lum_dist) ** 2 / (1 + redshifts), -1
    )

    # The evaluator may return a scalar for a single galaxy. Promote it to a
    # 1-d array, but leave arrays untouched: wrapping an existing length-1
    # array in another array would give shape (1, 1) instead of (1,).
    excess_b_v = np.atleast_1d(extinction_eval(w, x_pixel, y_pixel))

    # Calculate apparent reference band magnitude
    app_mag_ref = mag_calc(
        redshifts=redshifts,
        excess_b_v=excess_b_v,
        coeffs=template_coeffs,
        filter_names=[par.reference_band],
    )[par.reference_band]

    return template_coeffs, template_coeffs_abs, excess_b_v, app_mag_ref

486 

487 

def in_pos(x, y, par):
    """
    Select positions strictly inside the image.

    :param x: x coordinates
    :param y: y coordinates
    :param par: parameter object providing ``size_x`` and ``size_y``
    :return: ``True`` where ``0 < x < par.size_x`` and ``0 < y < par.size_y``
    """
    inside_x = (x > 0) & (x < par.size_x)
    inside_y = (y > 0) & (y < par.size_y)
    return inside_x & inside_y

489 return (x > 0) & (x < par.size_x) & (y > 0) & (y < par.size_y)