Coverage for src/galsbi/ucat/plugins/sample_galaxies_photo.py: 98%
183 statements
« prev ^ index » next coverage.py v7.10.5, created at 2025-08-26 10:41 +0000
« prev ^ index » next coverage.py v7.10.5, created at 2025-08-26 10:41 +0000
1# Copyright (C) 2018 ETH Zurich, Institute for Particle Physics and Astrophysics
3"""
4Created on Mar 5, 2018
5author: Joerg Herbel
6"""
8import os
9import warnings
11import healpy as hp
12import numpy as np
13import PyCosmo
14from astropy.coordinates import SkyCoord
15from cosmic_toolbox import logger
16from ivy.plugin.base_plugin import BasePlugin
17from ufig import coordinate_util, io_util
19from galsbi.ucat import (
20 filters_util,
21 galaxy_sampling_util,
22 sed_templates_util,
23 spectrum_util,
24 utils,
25)
26from galsbi.ucat.filters_util import UseShortFilterNames
27from galsbi.ucat.galaxy_population_models.galaxy_luminosity_function import (
28 initialize_luminosity_functions,
29)
30from galsbi.ucat.galaxy_population_models.galaxy_position import sample_position_uniform
31from galsbi.ucat.galaxy_population_models.galaxy_sed import (
32 sample_template_coeff_lumfuncs,
33)
# Module-level logger, created by passing this module's `__file__` to
# cosmic_toolbox's `logger.get_logger`.
35LOGGER = logger.get_logger(__file__)
# Integer stride that Plugin.__call__ multiplies by the pixel loop index and
# adds to the seeds handed to the luminosity-function sampler.
36SEED_OFFSET_LUMFUN = 123491
# Import-time side effect: installs a warnings filter with action "once".
37warnings.filterwarnings("once")
class ExtinctionMapEvaluator:
    """
    Callable that returns an extinction value (``excess_b_v``) per position.

    If ``par.extinction_map_file_name`` is None no map is loaded and every
    position evaluates to zero.
    """

    def __init__(self, par):
        """
        Load the extinction map named in the parameters, if any.

        :param par: parameter object providing ``extinction_map_file_name``
            and ``maps_remote_dir``
        """
        map_name = par.extinction_map_file_name
        if map_name is None:
            self.extinction_map = None
            return

        map_path = io_util.get_abs_path(map_name, root_path=par.maps_remote_dir)
        # First field of the file, read with nest=True
        self.extinction_map = hp.read_map(map_path, nest=True, field=0)

    def __call__(self, wcs, x, y):
        """
        Evaluate the extinction map at the given positions.

        :param wcs: WCS used to convert ``x``, ``y`` to RA/Dec, or None when
            ``x``, ``y`` already are RA/Dec in degrees
        :param x: pixel x coordinates (or RA if ``wcs`` is None)
        :param y: pixel y coordinates (or Dec if ``wcs`` is None)
        :return: interpolated map values, or zeros shaped like ``x`` when no
            map was loaded
        """
        if self.extinction_map is None:
            return np.zeros_like(x)

        if wcs is None:
            ra, dec = x, y
        else:
            ra, dec = coordinate_util.xy2radec(wcs, x, y)

        # The map is sampled at galactic longitude / latitude
        galactic = SkyCoord(ra=ra, dec=dec, frame="icrs", unit="deg").galactic
        theta, phi = coordinate_util.radec2thetaphi(galactic.l.deg, galactic.b.deg)
        return hp.get_interp_val(
            self.extinction_map, theta=theta, phi=phi, nest=True
        )
def get_magnitude_calculator_direct(filter_names, par):
    """
    Build the magnitude calculator that works directly on filters and SED
    templates.

    :param filter_names: short filter names, each looked up in
        ``par.filters_full_names``
    :param par: parameter object providing ``filters_full_names``,
        ``maps_remote_dir``, ``filters_file_name`` and ``templates_file_name``
    :return: a ``MagCalculatorDirect`` wrapped in ``UseShortFilterNames``
    """
    full_names = []
    for short_name in filter_names:
        full_names.append(par.filters_full_names[short_name])

    # TODO: par should be the path to the filters file, either change it here or there
    filters_path = os.path.join(par.maps_remote_dir, par.filters_file_name)
    filters = filters_util.load_filters(
        filters_path,
        filter_names=full_names,
        lam_scale=1e-4,
    )

    # SED templates are loaded with the following units:
    # Lambda: micrometer
    # SED: erg/s/m2/Å
    templates_path = os.path.join(par.maps_remote_dir, par.templates_file_name)
    sed_templates = sed_templates_util.load_template_spectra(
        templates_path,
        lam_scale=1e-4,
        amp_scale=1e4,
    )

    # Imported at call time rather than at module import
    from galsbi.ucat.magnitude_calculator import MagCalculatorDirect

    calculator = MagCalculatorDirect(filters, sed_templates)
    return UseShortFilterNames(calculator, par.filters_full_names)
def get_magnitude_calculator_table(filter_names, par):
    """
    Build the magnitude calculator that reads pre-computed integration tables.

    :param filter_names: short filter names, each looked up in
        ``par.filters_full_names``
    :param par: parameter object providing ``filters_full_names``,
        ``maps_remote_dir``, ``templates_int_tables_file_name`` and
        ``copy_template_int_tables_to_cwd``
    :return: a ``MagCalculatorTable`` wrapped in ``UseShortFilterNames``
    """
    tables_path = os.path.join(
        par.maps_remote_dir, par.templates_int_tables_file_name
    )

    full_names = []
    for short_name in filter_names:
        full_names.append(par.filters_full_names[short_name])

    # Imported at call time rather than at module import
    from galsbi.ucat.magnitude_calculator import MagCalculatorTable

    calculator = MagCalculatorTable(
        full_names,
        tables_path,
        copy_to_cwd=par.copy_template_int_tables_to_cwd,
    )
    return UseShortFilterNames(calculator, par.filters_full_names)
# Dispatch table from the value of `par.magnitude_calculation` to the factory
# building the matching magnitude calculator. Plugin.__call__ looks the factory
# up here and calls it with the keyword arguments `filter_names` and `par`.
126MAGNITUDES_CALCULATOR = {
127 "direct": get_magnitude_calculator_direct,
128 "table": get_magnitude_calculator_table,
129}
# Pipeline plugin. Calling an instance reads `self.ctx.parameters` and fills
# `self.ctx.pixels`, `self.ctx.pixarea`, `self.ctx.galaxies` and
# `self.ctx.numgalaxies` (plus `self.ctx.restframe_wavelength_for_SED` when
# `par.save_SEDs` is set).
132class Plugin(BasePlugin):
133 """
134 Generate a random catalog of galaxies with magnitudes in multiple bands.
135 """
 # Only active when `par` has a `galaxy_count_prior` attribute, from which the
 # keys "band", "mag_max", "n_min" and "n_max" are read. Reads
 # `self.ctx.galaxies.int_magnitude_dict`, so it must run after that is set.
 # Raises galaxy_sampling_util.UCatNumGalError if the count is out of range.
137 def check_n_gal_prior(self, par):
138 """
139 Check if the number of galaxies is inside the prior range,
140 even before rendering, to remove extreme values.
141 """
143 if hasattr(par, "galaxy_count_prior"):
144 app_mag = self.ctx.galaxies.int_magnitude_dict[
145 par.galaxy_count_prior["band"]
146 ]
147 n_gal_ = np.count_nonzero(app_mag < par.galaxy_count_prior["mag_max"])
 # Rescale the raw count by par.ngal_multiplier before comparing to the prior.
148 n_gal_ /= par.ngal_multiplier
 # The parenthesised `"...".format` is the bound str.format method; it is
 # called with the four values that follow.
149 LOGGER.info(
150 (
151 "Allowed number galaxies with int_mag<{} per tile [{},{}],"
152 " computed number: {}".format
153 )(
154 par.galaxy_count_prior["mag_max"],
155 par.galaxy_count_prior["n_min"],
156 par.galaxy_count_prior["n_max"],
157 n_gal_,
158 )
159 )
160 if (n_gal_ < par.galaxy_count_prior["n_min"]) or (
161 n_gal_ > par.galaxy_count_prior["n_max"]
162 ):
163 raise galaxy_sampling_util.UCatNumGalError(
164 "too many or too few galaxies"
165 )
 # Compares utils.memory_usage_psutil() with par.max_memlimit_gal_catalog and
 # raises galaxy_sampling_util.UCatNumGalError when the limit is exceeded.
 # Both values are presumably in MB, judging by the variable name - TODO confirm.
167 def check_max_mem_error(self, par):
168 """
169 Check if the catalog does not exceed allowed memory.
170 Prevents job crashes on clusters.
171 """
173 mem_mb_current = utils.memory_usage_psutil()
174 if mem_mb_current > par.max_memlimit_gal_catalog:
175 raise galaxy_sampling_util.UCatNumGalError(
176 "The sample_galaxies process is taking too much memory:"
177 f" mem_mb_current={mem_mb_current},"
178 f" max_mem_hard_limit={par.max_memlimit_gal_catalog}"
179 )
 # Main entry point. In order: select the healpix pixels to sample in, build
 # the magnitude calculator, sample galaxies per galaxy type and per pixel,
 # apply magnitude / position cuts, concatenate the per-pixel chunks into
 # catalog columns and compute the magnitudes in all requested filters.
 # Raises ValueError for an unknown par.sampling_mode and
 # galaxy_sampling_util.UCatNumGalError from the count / memory checks.
181 def __call__(self):
182 par = self.ctx.parameters
184 # Cosmology
185 cosmo = PyCosmo.build()
186 cosmo.set(h=par.h, omega_m=par.omega_m)
 # `w` is the WCS in "wcs" mode and None in "healpix" mode; further down it
 # decides whether positions are stored as x/y or as ra/dec.
188 if par.sampling_mode == "wcs":
189 LOGGER.info("Sampling galaxies based on RA/DEC and pixel scale")
190 # Healpix pixelization
191 w = coordinate_util.wcs_from_parameters(par)
192 self.ctx.pixels = coordinate_util.get_healpix_pixels(
193 par.nside_sampling, w, par.size_x, par.size_y
194 )
195 if len(self.ctx.pixels) < 15:
196 LOGGER.warning(
197 f"Only {len(self.ctx.pixels)} healpy pixels in the footprint,"
198 " consider increasing the nside_sampling"
199 )
200 elif par.sampling_mode == "healpix":
201 LOGGER.info("Sampling galaxies based on healpix pixels")
202 self.ctx.pixels = coordinate_util.get_healpix_pixels_from_map(par)
203 w = None
204 else:
205 raise ValueError(
206 f"Unknown sampling mode: {par.sampling_mode}, must be wcs or healpix"
207 )
 # Area of one healpix pixel at nside_sampling; with degrees=False healpy
 # returns it in square radians.
208 self.ctx.pixarea = hp.nside2pixarea(par.nside_sampling, degrees=False)
210 # Magnitude calculator
 # Requested filters plus the luminosity-function band, de-duplicated by
 # np.unique. Assumes par.filters is a list so that `+` concatenates - TODO confirm.
211 all_filters = np.unique(par.filters + [par.lum_fct_filter_band])
213 # backward compatibility - check if full filter names are set
214 if not hasattr(par, "filters_full_names"):
215 par.filters_full_names = filters_util.get_default_full_filter_names(
216 all_filters
217 )
218 warnings.warn(
219 "setting filters to default, this will cause problems if you work"
220 " with filters from different cameras in the same band",
221 stacklevel=1,
222 )
224 # get magnitude calculators (reload cache to avoid memory leaks and save memory)
225 mag_calc = MAGNITUDES_CALCULATOR[par.magnitude_calculation](
226 filter_names=all_filters, par=par
227 )
228 n_templates = mag_calc.n_templates
229 # Cut in z - M - plane & boundaries
230 z_m_intp = galaxy_sampling_util.intp_z_m_cut(cosmo, mag_calc, par)
232 # Initialize galaxy catalog
233 self.ctx.galaxies = galaxy_sampling_util.Catalog()
234 self.ctx.galaxies.columns = [
235 "id",
236 "z",
237 "template_coeffs",
238 "template_coeffs_abs",
239 "abs_mag_lumfun",
240 "galaxy_type",
241 "excess_b_v",
242 ]
244 # Columns modified inside loop
245 loop_cols = [
246 "z",
247 "template_coeffs",
248 "template_coeffs_abs",
249 "abs_mag_lumfun",
250 "galaxy_type",
251 "excess_b_v",
252 ]
 # Position columns are x/y when a WCS is available, ra/dec otherwise.
253 if w is not None:
254 loop_cols += ["x", "y"]
255 self.ctx.galaxies.columns += ["x", "y"]
256 else:
257 loop_cols += ["ra", "dec"]
258 self.ctx.galaxies.columns += ["ra", "dec"]
 # Every loop column starts as an empty list that collects one chunk per
 # pixel; the chunks are concatenated into arrays after sampling.
259 for c in loop_cols:
260 setattr(self.ctx.galaxies, c, [])
262 # set up luminosity functions
263 lum_funcs = initialize_luminosity_functions(
264 par, cosmo=cosmo, pixarea=self.ctx.pixarea, z_m_intp=z_m_intp
265 )
267 # Extinction
268 extinction_eval = ExtinctionMapEvaluator(par)
270 # Helper function to compute templates, extinction
271 for g in par.galaxy_types:
272 n_gal_type = 0
 # The per-type cap is read from the parameter named n_gal_max_<type>.
273 n_gal_type_max = getattr(par, f"n_gal_max_{g}")
274 max_reached = False
276 for i in LOGGER.progressbar(
277 range(len(self.ctx.pixels)),
278 desc=f"getting {g:<4s} galaxies for healpix pixels",
279 at_level="debug",
280 ):
281 # Sample absolute mag vs redshift from luminosity function
 # Both seeds combine the global seed, the healpix pixel index, a
 # purpose-specific offset and SEED_OFFSET_LUMFUN times the loop index.
282 abs_mag, z = lum_funcs[g].sample_z_mabs_and_apply_cut(
283 seed_ngal=par.seed
284 + self.ctx.pixels[i]
285 + par.gal_num_seed_offset
286 + SEED_OFFSET_LUMFUN * i,
287 seed_lumfun=par.seed
288 + self.ctx.pixels[i]
289 + par.gal_lum_fct_seed_offset
290 + SEED_OFFSET_LUMFUN * i,
291 n_gal_max=n_gal_type_max,
292 )
294 # Positions
 # NOTE: this reseeds numpy's global random state.
295 np.random.seed(par.seed + self.ctx.pixels[i] + par.gal_dist_seed_offset)
297 # x and y for wcs, ra and dec for healpix model
298 x, y = sample_position_uniform(
299 len(z), w, self.ctx.pixels[i], par.nside_sampling
300 )
302 # Make the catalog precision already here to avoid inconsistencies in
303 # the selection
304 x = x.astype(par.catalog_precision)
305 y = y.astype(par.catalog_precision)
307 (
308 template_coeffs,
309 template_coeffs_abs,
310 excess_b_v,
311 app_mag_ref,
312 ) = compute_templates_extinction_appmag_for_galaxies(
313 galaxy_type=g,
314 par=par,
315 n_templates=n_templates,
316 cosmo=cosmo,
317 w=w,
318 redshifts=z,
319 absmags=abs_mag,
320 x_pixel=x,
321 y_pixel=y,
322 mag_calc=mag_calc,
323 extinction_eval=extinction_eval,
324 )
326 # Reject galaxies outside set magnitude range
327 select_mag_range = (app_mag_ref >= par.gals_mag_min) & (
328 app_mag_ref <= par.gals_mag_max
329 )
 # The footprint cut is only applied in wcs mode; in healpix mode every
 # sampled position is kept.
330 if w is not None:
331 select_pos_range = in_pos(x, y, par)
332 else:
333 select_pos_range = np.ones_like(x, dtype=bool)
334 select = select_mag_range & select_pos_range
335 n_gal = np.count_nonzero(select)
336 n_gal_type += n_gal
338 # store
 # x and y were already cast to par.catalog_precision above; the cast is
 # repeated here on the selected subset.
339 if w is not None:
340 self.ctx.galaxies.x.append(x[select].astype(par.catalog_precision))
341 self.ctx.galaxies.y.append(y[select].astype(par.catalog_precision))
342 else:
343 self.ctx.galaxies.ra.append(x[select].astype(par.catalog_precision))
344 self.ctx.galaxies.dec.append(
345 y[select].astype(par.catalog_precision)
346 )
347 self.ctx.galaxies.z.append(z[select].astype(par.catalog_precision))
348 self.ctx.galaxies.template_coeffs.append(
349 template_coeffs[select].astype(par.catalog_precision)
350 )
351 self.ctx.galaxies.template_coeffs_abs.append(
352 template_coeffs_abs[select].astype(par.catalog_precision)
353 )
354 self.ctx.galaxies.abs_mag_lumfun.append(
355 abs_mag[select].astype(par.catalog_precision)
356 )
 # Every selected galaxy is tagged with the type code of its luminosity
 # function, stored as unsigned short.
357 self.ctx.galaxies.galaxy_type.append(
358 np.ones(n_gal, dtype=np.ushort) * lum_funcs[g].galaxy_type
359 )
360 self.ctx.galaxies.excess_b_v.append(
361 excess_b_v[select].astype(par.catalog_precision)
362 )
364 # see if number of galaxies is OK
 # NOTE(review): the limit tested is n_gal_type_max * par.ngal_multiplier,
 # while the error message below prints the unscaled n_gal_type_max.
365 if n_gal_type > n_gal_type_max * par.ngal_multiplier:
366 max_reached = True
367 if par.raise_max_num_gal_error:
368 raise galaxy_sampling_util.UCatNumGalError(
369 "exceeded number of"
370 f" {g} galaxies {n_gal_type}>{n_gal_type_max}"
371 )
372 else:
373 break
375 LOGGER.info(
376 f"lumfun={g} n_gals={n_gal_type} maximum number of galaxies"
377 f" reached={max_reached} ({n_gal_type_max})"
378 )
380 # check memory footprint
381 self.check_max_mem_error(par)
383 # Concatenate columns
384 for c in loop_cols:
385 setattr(self.ctx.galaxies, c, np.concatenate(getattr(self.ctx.galaxies, c)))
387 # Calculate requested intrinsic apparent and absolute magnitudes
388 self.ctx.galaxies.int_magnitude_dict = mag_calc(
389 redshifts=self.ctx.galaxies.z,
390 excess_b_v=self.ctx.galaxies.excess_b_v,
391 coeffs=self.ctx.galaxies.template_coeffs,
392 filter_names=par.filters,
393 )
 # Absolute magnitudes: same calculator, evaluated at zero redshift and zero
 # extinction with the absolute template coefficients.
395 self.ctx.galaxies.abs_magnitude_dict = mag_calc(
396 redshifts=np.zeros_like(self.ctx.galaxies.z),
397 excess_b_v=np.zeros_like(self.ctx.galaxies.excess_b_v),
398 coeffs=self.ctx.galaxies.template_coeffs_abs,
399 filter_names=par.filters,
400 )
402 # Raise error is the number of galaxies per tile is too high or too low
403 self.check_n_gal_prior(par)
405 # Set apparent (lensed) magnitudes, for now equal to intrinsic apparent
406 # magnitudes
407 self.ctx.galaxies.magnitude_dict = dict()
408 for band, mag in self.ctx.galaxies.int_magnitude_dict.items():
409 self.ctx.galaxies.magnitude_dict[band] = mag.copy()
411 # Number of galaxies and id
412 self.ctx.numgalaxies = self.ctx.galaxies.z.size
413 self.ctx.galaxies.id = np.arange(self.ctx.numgalaxies)
415 # Backward compatibility
 # blue_red starts as all ones, so the "blue" assignment does not change any
 # value; only galaxies of the "red" type are set to 0.
 # NOTE(review): requires both "blue" and "red" keys in lum_funcs.
416 self.ctx.galaxies.blue_red = np.ones(len(self.ctx.galaxies.z), dtype=np.ushort)
417 self.ctx.galaxies.blue_red[
418 self.ctx.galaxies.galaxy_type == lum_funcs["blue"].galaxy_type
419 ] = 1
420 self.ctx.galaxies.blue_red[
421 self.ctx.galaxies.galaxy_type == lum_funcs["red"].galaxy_type
422 ] = 0
424 LOGGER.info(
425 f"galaxy counts n_total={self.ctx.numgalaxies}"
426 f" mem_mb_current={utils.memory_usage_psutil():5.1f}"
427 )
429 if par.save_SEDs:
430 # Store SEDs in the catalog
431 restframe_wavelength, seds = get_seds(par, self.ctx.galaxies)
432 self.ctx.restframe_wavelength_for_SED = restframe_wavelength
433 self.ctx.galaxies.sed = seds
 # Best-effort release of attributes held by the magnitude calculator; the
 # deletions stop at the first one that fails and the error is discarded.
 # NOTE(review): `except Exception: pass` also hides unrelated errors;
 # presumably the attributes do not exist for every calculator type - confirm.
434 try:
435 del mag_calc.func.templates_int_table_dict
436 del mag_calc.func.z_grid
437 del mag_calc.func.excess_b_v_grid
438 del mag_calc.func
439 except Exception:
440 pass
441 # profile.print_stats(output_unit=1)
 # Short human-readable label for this plugin.
443 def __str__(self):
444 return "sample gal photo"
def compute_templates_extinction_appmag_for_galaxies(
    galaxy_type,
    par,
    n_templates,
    cosmo,
    w,
    redshifts,
    absmags,
    x_pixel,
    y_pixel,
    mag_calc,
    extinction_eval,
):
    """
    Compute template coefficients, extinction and the reference-band apparent
    magnitude for a set of sampled galaxies.

    :param galaxy_type: key of the galaxy population, used to address the
        result of ``sample_template_coeff_lumfuncs``
    :param par: parameter object; ``lum_fct_filter_band`` and
        ``reference_band`` are read here
    :param n_templates: number of SED templates
    :param cosmo: PyCosmo instance providing ``background.dist_lum_a``
    :param w: WCS handed to ``extinction_eval`` (may be None)
    :param redshifts: redshift of every galaxy
    :param absmags: absolute magnitude of every galaxy in
        ``par.lum_fct_filter_band``
    :param x_pixel: first position coordinate handed to ``extinction_eval``
    :param y_pixel: second position coordinate handed to ``extinction_eval``
    :param mag_calc: magnitude calculator, called with ``redshifts``,
        ``excess_b_v``, ``coeffs`` and ``filter_names``
    :param extinction_eval: callable ``(w, x, y) -> excess_b_v``
    :return: tuple ``(template_coeffs, template_coeffs_abs, excess_b_v,
        app_mag_ref)``; ``excess_b_v`` always has at least one dimension
    """
    template_coeffs_abs = sample_template_coeff_lumfuncs(
        par=par, redshift_z={galaxy_type: redshifts}, n_templates=n_templates
    )[galaxy_type]

    # Calculate absolute magnitudes according to coefficients and adjust
    # coefficients according to drawn magnitudes
    mag_z0 = mag_calc(
        redshifts=np.zeros_like(redshifts),
        excess_b_v=np.zeros_like(redshifts),
        coeffs=template_coeffs_abs,
        filter_names=[par.lum_fct_filter_band],
    )

    # In-place rescaling, one factor per galaxy broadcast over the last axis
    template_coeffs_abs *= np.expand_dims(
        10 ** (0.4 * (mag_z0[par.lum_fct_filter_band] - absmags)), -1
    )

    # Transform to apparent coefficients
    lum_dist = galaxy_sampling_util.apply_pycosmo_distfun(
        cosmo.background.dist_lum_a, redshifts
    )
    # 10e-6 is presumably 10 pc expressed in Mpc (i.e. lum_dist in Mpc) -
    # TODO confirm against PyCosmo's units
    template_coeffs = template_coeffs_abs * np.expand_dims(
        (10e-6 / lum_dist) ** 2 / (1 + redshifts), -1
    )

    # The evaluator may hand back a scalar for a single galaxy. Promote
    # scalars to 1-d, but leave arrays untouched: wrapping an existing
    # length-1 array in another list would add a spurious second axis.
    excess_b_v = np.atleast_1d(extinction_eval(w, x_pixel, y_pixel))

    # Calculate apparent reference band magnitude
    app_mag_ref = mag_calc(
        redshifts=redshifts,
        excess_b_v=excess_b_v,
        coeffs=template_coeffs,
        filter_names=[par.reference_band],
    )[par.reference_band]

    return template_coeffs, template_coeffs_abs, excess_b_v, app_mag_ref
def in_pos(x, y, par):
    """
    Select positions lying strictly inside the image.

    :param x: x coordinates
    :param y: y coordinates
    :param par: parameter object providing ``size_x`` and ``size_y``
    :return: mask that is True where ``0 < x < par.size_x`` and
        ``0 < y < par.size_y``
    """
    width, height = par.size_x, par.size_y
    inside_x = (x > 0) & (x < width)
    inside_y = (y > 0) & (y < height)
    return inside_x & inside_y
def get_seds(par, galaxies):
    """
    Get SEDs for galaxies in the catalog

    :param par: parameter object; ``par.filters`` is used to build the direct
        magnitude calculator
    :param galaxies: catalog providing the per-galaxy arrays ``z``,
        ``template_coeffs`` and ``excess_b_v``
    :return: tuple ``(wavelength, seds)``: the template wavelength grid
        multiplied by 1e4 (Angstrom) and one row of reddened spectrum per
        galaxy, in erg/s/cm2/Å. For an empty catalog ``seds`` has zero rows.
    """
    direct_mag_calc = get_magnitude_calculator_direct(filter_names=par.filters, par=par)

    # Loop invariants: template grid, amplitudes and extinction spline
    lam_rest_in_mu_m = direct_mag_calc.sed_templates["lam"]  # in micrometer
    templates_amp = direct_mag_calc.sed_templates["amp"]
    extinction_spline = direct_mag_calc.extinction_spline

    seds = []
    for i in LOGGER.progressbar(
        range(galaxies.z.size),
        desc="getting SEDs",
        at_level="debug",
    ):
        spec = spectrum_util.construct_reddened_spectrum(
            lam_obs=lam_rest_in_mu_m * (1 + galaxies.z[i]),  # in micrometer
            templates_amp=templates_amp,
            coeff=galaxies.template_coeffs[i],
            excess_b_v=galaxies.excess_b_v[i],
            extinction_spline=extinction_spline,
        ).flatten()  # in erg/s/m2/Å
        # save in angstrom and erg/s/cm2/Å
        seds.append(spec / 1e4)

    if seds:
        seds = np.vstack(seds)
    else:
        # np.vstack raises ValueError on an empty list, so build the empty
        # result explicitly. Assumes one flux value per template wavelength
        # in each spectrum - TODO confirm.
        seds = np.empty((0, np.size(lam_rest_in_mu_m)))

    return lam_rest_in_mu_m * 1e4, seds