Coverage for src/galsbi/ucat/plugins/sample_galaxies_photo.py: 98%
169 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-10 11:12 +0000
1# Copyright (C) 2018 ETH Zurich, Institute for Particle Physics and Astrophysics
3"""
4Created on Mar 5, 2018
5author: Joerg Herbel
6"""
8import os
9import warnings
11import healpy as hp
12import numpy as np
13import PyCosmo
14from astropy.coordinates import SkyCoord
15from cosmic_toolbox import logger
16from ivy.plugin.base_plugin import BasePlugin
17from ufig import coordinate_util, io_util
19from galsbi.ucat import filters_util, galaxy_sampling_util, sed_templates_util, utils
20from galsbi.ucat.filters_util import UseShortFilterNames
21from galsbi.ucat.galaxy_population_models.galaxy_luminosity_function import (
22 initialize_luminosity_functions,
23)
24from galsbi.ucat.galaxy_population_models.galaxy_position import sample_position_uniform
25from galsbi.ucat.galaxy_population_models.galaxy_sed import (
26 sample_template_coeff_lumfuncs,
27)
LOGGER = logger.get_logger(__file__)

# Per-pixel stride added to the luminosity-function seeds: the pixel index
# is multiplied by this value when the seeds are derived in Plugin.__call__.
SEED_OFFSET_LUMFUN = 123_491

# Show each distinct warning only once. Note that this modifies the
# process-wide warnings filter as a side effect of importing this module.
warnings.filterwarnings("once")
class ExtinctionMapEvaluator:
    """
    Callable returning the colour excess E(B-V) at given positions.

    The values are interpolated from a HEALPix map (nested ordering, first
    field) that is indexed by Galactic longitude/latitude. If no map file is
    configured, zero extinction is returned for every position.
    """

    def __init__(self, par):
        """
        :param par: parameters; reads ``extinction_map_file_name`` and, if that
            is not None, ``maps_remote_dir`` (used as root to resolve the path)
        """
        self.extinction_map = None
        map_name = par.extinction_map_file_name
        if map_name is None:
            return
        map_path = io_util.get_abs_path(map_name, root_path=par.maps_remote_dir)
        self.extinction_map = hp.read_map(map_path, nest=True, field=0)

    def __call__(self, wcs, x, y):
        """
        Evaluate E(B-V) at the given positions.

        :param wcs: WCS used to convert pixel coordinates to RA/Dec, or None if
            ``x``/``y`` are already RA/Dec in degrees
        :param x: x pixel coordinates (or RA in degrees when ``wcs`` is None)
        :param y: y pixel coordinates (or Dec in degrees when ``wcs`` is None)
        :return: interpolated E(B-V) values, or ``np.zeros_like(x)`` when no
            extinction map is loaded
        """
        if self.extinction_map is None:
            return np.zeros_like(x)

        if wcs is None:
            ra, dec = x, y
        else:
            ra, dec = coordinate_util.xy2radec(wcs, x, y)

        # The map is indexed in Galactic coordinates, so transform ICRS first.
        galactic = SkyCoord(ra=ra, dec=dec, frame="icrs", unit="deg").galactic
        theta, phi = coordinate_util.radec2thetaphi(galactic.l.deg, galactic.b.deg)
        return hp.get_interp_val(self.extinction_map, theta=theta, phi=phi, nest=True)
def get_magnitude_calculator_direct(filter_names, par):
    """
    Build a magnitude calculator that integrates template spectra directly.

    The filter curves and the SED templates are loaded from files under
    ``par.maps_remote_dir``. The resulting calculator is wrapped so that it can
    be called with short filter names.

    :param filter_names: short filter names to load; each must be a key of
        ``par.filters_full_names``
    :param par: parameters; reads ``filters_full_names``, ``maps_remote_dir``,
        ``filters_file_name`` and ``templates_file_name``
    :return: ``UseShortFilterNames`` wrapping a ``MagCalculatorDirect``
    """
    full_names = [par.filters_full_names[name] for name in filter_names]

    # TODO: par should be the path to the filters file, either change it here or there
    filters_path = os.path.join(par.maps_remote_dir, par.filters_file_name)
    # NOTE(review): lam_scale / amp_scale presumably convert the file units to
    # the units used by the calculator — confirm against the loaders.
    filters = filters_util.load_filters(
        filters_path, filter_names=full_names, lam_scale=1e-4
    )

    templates_path = os.path.join(par.maps_remote_dir, par.templates_file_name)
    sed_templates = sed_templates_util.load_template_spectra(
        templates_path, lam_scale=1e-4, amp_scale=1e4
    )

    from galsbi.ucat.magnitude_calculator import MagCalculatorDirect

    direct_calculator = MagCalculatorDirect(filters, sed_templates)
    return UseShortFilterNames(direct_calculator, par.filters_full_names)
def get_magnitude_calculator_table(filter_names, par):
    """
    Build a magnitude calculator based on pre-computed integration tables.

    The table file is located at ``par.templates_int_tables_file_name`` under
    ``par.maps_remote_dir``. The resulting calculator is wrapped so that it can
    be called with short filter names.

    :param filter_names: short filter names to load; each must be a key of
        ``par.filters_full_names``
    :param par: parameters; reads ``maps_remote_dir``,
        ``templates_int_tables_file_name``, ``filters_full_names`` and
        ``copy_template_int_tables_to_cwd``
    :return: ``UseShortFilterNames`` wrapping a ``MagCalculatorTable``
    """
    tables_path = os.path.join(
        par.maps_remote_dir, par.templates_int_tables_file_name
    )
    full_names = [par.filters_full_names[name] for name in filter_names]

    from galsbi.ucat.magnitude_calculator import MagCalculatorTable

    table_calculator = MagCalculatorTable(
        full_names,
        tables_path,
        copy_to_cwd=par.copy_template_int_tables_to_cwd,
    )
    return UseShortFilterNames(table_calculator, par.filters_full_names)
# Dispatch table: maps the value of ``par.magnitude_calculation`` to the
# factory building the corresponding (short-filter-name) magnitude calculator.
MAGNITUDES_CALCULATOR = dict(
    direct=get_magnitude_calculator_direct,
    table=get_magnitude_calculator_table,
)
class Plugin(BasePlugin):
    """
    Generate a random catalog of galaxies with magnitudes in multiple bands.

    For every galaxy type in ``par.galaxy_types`` and every HEALPix pixel of
    the sampling footprint, redshifts and absolute magnitudes are drawn from
    the luminosity function and positions are drawn within the pixel. Template
    coefficients and extinction are then assigned. Galaxies outside the
    reference-band magnitude range (and, in ``wcs`` mode, outside the image)
    are rejected. The catalog is stored in ``self.ctx.galaxies``.
    """

    def check_n_gal_prior(self, par):
        """
        Check if the number of galaxies is inside the prior range,
        even before rendering, to remove extreme values.

        Only active when ``par`` has a ``galaxy_count_prior`` attribute. That
        attribute is indexed with the keys ``band``, ``mag_max``, ``n_min`` and
        ``n_max``. The number of galaxies with intrinsic magnitude below
        ``mag_max`` in ``band`` is divided by ``par.ngal_multiplier`` before the
        comparison.

        :param par: ctx parameters
        :raises galaxy_sampling_util.UCatNumGalError: if the rescaled count is
            outside ``[n_min, n_max]``
        """
        if not hasattr(par, "galaxy_count_prior"):
            return

        prior = par.galaxy_count_prior
        app_mag = self.ctx.galaxies.int_magnitude_dict[prior["band"]]
        n_gal_ = np.count_nonzero(app_mag < prior["mag_max"])
        n_gal_ /= par.ngal_multiplier
        LOGGER.info(
            f"Allowed number galaxies with int_mag<{prior['mag_max']} per tile"
            f" [{prior['n_min']},{prior['n_max']}],"
            f" computed number: {n_gal_}"
        )
        if (n_gal_ < prior["n_min"]) or (n_gal_ > prior["n_max"]):
            raise galaxy_sampling_util.UCatNumGalError(
                "too many or too few galaxies"
            )

    def check_max_mem_error(self, par, max_mem_hard_limit=10000):
        """
        Check if the catalog does not exceed allowed memory.
        Prevents job crashes on clusters.

        :param par: ctx parameters (currently unused)
        :param max_mem_hard_limit: limit compared against
            ``utils.memory_usage_psutil()`` (presumably in MB, judging by the
            variable naming — confirm)
        :raises galaxy_sampling_util.UCatNumGalError: if the current memory
            usage exceeds ``max_mem_hard_limit``
        """
        mem_mb_current = utils.memory_usage_psutil()
        if mem_mb_current > max_mem_hard_limit:
            raise galaxy_sampling_util.UCatNumGalError(
                "The sample_galaxies process is taking too much memory:"
                f" mem_mb_current={mem_mb_current},"
                f" max_mem_hard_limit={max_mem_hard_limit}"
            )

    def __call__(self):
        """
        Sample the galaxy catalog and store it in the context.

        Sets ``self.ctx.pixels``, ``self.ctx.pixarea``, ``self.ctx.galaxies``
        and ``self.ctx.numgalaxies``.

        :raises ValueError: if ``par.sampling_mode`` is neither ``"wcs"`` nor
            ``"healpix"``
        :raises galaxy_sampling_util.UCatNumGalError: if the number of galaxies
            of a type exceeds its maximum and ``par.raise_max_num_gal_error``
            is set, if the memory limit is exceeded, or if the galaxy count
            prior is violated
        """
        par = self.ctx.parameters

        # Cosmology
        cosmo = PyCosmo.build()
        cosmo.set(h=par.h, omega_m=par.omega_m)

        if par.sampling_mode == "wcs":
            LOGGER.info("Sampling galaxies based on RA/DEC and pixel scale")
            # Healpix pixelization
            w = coordinate_util.wcs_from_parameters(par)
            self.ctx.pixels = coordinate_util.get_healpix_pixels(
                par.nside_sampling, w, par.size_x, par.size_y
            )
            if len(self.ctx.pixels) < 15:
                LOGGER.warning(
                    f"Only {len(self.ctx.pixels)} healpy pixels in the footprint,"
                    " consider increasing the nside_sampling"
                )
        elif par.sampling_mode == "healpix":
            LOGGER.info("Sampling galaxies based on healpix pixels")
            self.ctx.pixels = coordinate_util.get_healpix_pixels_from_map(par)
            # Without a WCS, positions are sampled and stored as ra/dec.
            w = None
        else:
            raise ValueError(
                f"Unknown sampling mode: {par.sampling_mode}, must be wcs or healpix"
            )
        self.ctx.pixarea = hp.nside2pixarea(par.nside_sampling, degrees=False)

        # Magnitude calculator: observed bands plus the luminosity-function band.
        # NOTE(review): par.reference_band is also passed to mag_calc (in
        # compute_templates_extinction_appmag_for_galaxies) but is not added
        # here; this presumably relies on it being one of par.filters — confirm.
        all_filters = np.unique(par.filters + [par.lum_fct_filter_band])

        # backward compatibility - check if full filter names are set
        if not hasattr(par, "filters_full_names"):
            par.filters_full_names = filters_util.get_default_full_filter_names(
                all_filters
            )
            warnings.warn(
                "setting filters to default, this will cause problems if you work"
                " with filters from different cameras in the same band",
                stacklevel=1,
            )

        # get magnitude calculators (reload cache to avoid memory leaks and save memory)
        mag_calc = MAGNITUDES_CALCULATOR[par.magnitude_calculation](
            filter_names=all_filters, par=par
        )
        n_templates = mag_calc.n_templates
        # Cut in z - M - plane & boundaries
        z_m_intp = galaxy_sampling_util.intp_z_m_cut(cosmo, mag_calc, par)

        # Initialize galaxy catalog
        self.ctx.galaxies = galaxy_sampling_util.Catalog()
        self.ctx.galaxies.columns = [
            "id",
            "z",
            "template_coeffs",
            "template_coeffs_abs",
            "abs_mag_lumfun",
            "galaxy_type",
            "excess_b_v",
        ]

        # Columns filled per pixel inside the loop; each holds a list of
        # per-pixel arrays that is concatenated after the loop.
        loop_cols = [
            "z",
            "template_coeffs",
            "template_coeffs_abs",
            "abs_mag_lumfun",
            "galaxy_type",
            "excess_b_v",
        ]
        if w is not None:
            loop_cols += ["x", "y"]
            self.ctx.galaxies.columns += ["x", "y"]
        else:
            loop_cols += ["ra", "dec"]
            self.ctx.galaxies.columns += ["ra", "dec"]
        for c in loop_cols:
            setattr(self.ctx.galaxies, c, [])

        # set up luminosity functions
        lum_funcs = initialize_luminosity_functions(
            par, cosmo=cosmo, pixarea=self.ctx.pixarea, z_m_intp=z_m_intp
        )

        # Extinction
        extinction_eval = ExtinctionMapEvaluator(par)

        for g in par.galaxy_types:
            n_gal_type = 0
            n_gal_type_max = getattr(par, f"n_gal_max_{g}")
            max_reached = False

            for i in LOGGER.progressbar(
                range(len(self.ctx.pixels)),
                desc=f"getting {g:<4s} galaxies for healpix pixels",
                at_level="debug",
            ):
                pixel = self.ctx.pixels[i]

                # Sample absolute mag vs redshift from luminosity function.
                # Seeds combine the global seed, the pixel id, a per-purpose
                # offset and the pixel index.
                abs_mag, z = lum_funcs[g].sample_z_mabs_and_apply_cut(
                    seed_ngal=par.seed
                    + pixel
                    + par.gal_num_seed_offset
                    + SEED_OFFSET_LUMFUN * i,
                    seed_lumfun=par.seed
                    + pixel
                    + par.gal_lum_fct_seed_offset
                    + SEED_OFFSET_LUMFUN * i,
                    n_gal_max=n_gal_type_max,
                )

                # Positions. This seeds NumPy's global RNG. The draws below
                # (positions and, presumably, the template coefficients) consume
                # it in this order, so keep the statement order for
                # reproducibility.
                np.random.seed(par.seed + pixel + par.gal_dist_seed_offset)

                # x and y for wcs, ra and dec for healpix model
                x, y = sample_position_uniform(len(z), w, pixel, par.nside_sampling)

                # Make the catalog precision already here to avoid inconsistencies in
                # the selection
                x = x.astype(par.catalog_precision)
                y = y.astype(par.catalog_precision)

                (
                    template_coeffs,
                    template_coeffs_abs,
                    excess_b_v,
                    app_mag_ref,
                ) = compute_templates_extinction_appmag_for_galaxies(
                    galaxy_type=g,
                    par=par,
                    n_templates=n_templates,
                    cosmo=cosmo,
                    w=w,
                    redshifts=z,
                    absmags=abs_mag,
                    x_pixel=x,
                    y_pixel=y,
                    mag_calc=mag_calc,
                    extinction_eval=extinction_eval,
                )

                # Reject galaxies outside set magnitude range
                select_mag_range = (app_mag_ref >= par.gals_mag_min) & (
                    app_mag_ref <= par.gals_mag_max
                )
                if w is not None:
                    select_pos_range = in_pos(x, y, par)
                else:
                    select_pos_range = np.ones_like(x, dtype=bool)
                select = select_mag_range & select_pos_range
                n_gal = np.count_nonzero(select)
                n_gal_type += n_gal

                # store
                precision = par.catalog_precision
                if w is not None:
                    self.ctx.galaxies.x.append(x[select].astype(precision))
                    self.ctx.galaxies.y.append(y[select].astype(precision))
                else:
                    self.ctx.galaxies.ra.append(x[select].astype(precision))
                    self.ctx.galaxies.dec.append(y[select].astype(precision))
                self.ctx.galaxies.z.append(z[select].astype(precision))
                self.ctx.galaxies.template_coeffs.append(
                    template_coeffs[select].astype(precision)
                )
                self.ctx.galaxies.template_coeffs_abs.append(
                    template_coeffs_abs[select].astype(precision)
                )
                self.ctx.galaxies.abs_mag_lumfun.append(
                    abs_mag[select].astype(precision)
                )
                self.ctx.galaxies.galaxy_type.append(
                    np.ones(n_gal, dtype=np.ushort) * lum_funcs[g].galaxy_type
                )
                self.ctx.galaxies.excess_b_v.append(
                    excess_b_v[select].astype(precision)
                )

                # see if number of galaxies is OK
                n_gal_threshold = n_gal_type_max * par.ngal_multiplier
                if n_gal_type > n_gal_threshold:
                    max_reached = True
                    if par.raise_max_num_gal_error:
                        # Report the threshold actually tested (including
                        # ngal_multiplier), not only n_gal_max_<type>.
                        raise galaxy_sampling_util.UCatNumGalError(
                            "exceeded number of"
                            f" {g} galaxies {n_gal_type}>{n_gal_threshold}"
                            f" (n_gal_max_{g}={n_gal_type_max},"
                            f" ngal_multiplier={par.ngal_multiplier})"
                        )
                    # Stop sampling this galaxy type; continue with the next.
                    break

            LOGGER.info(
                f"lumfun={g} n_gals={n_gal_type} maximum number of galaxies"
                f" reached={max_reached} ({n_gal_type_max})"
            )

        # Common code shared between clustering and uniform position methods

        # check memory footprint
        self.check_max_mem_error(par)

        # Concatenate columns
        for c in loop_cols:
            setattr(self.ctx.galaxies, c, np.concatenate(getattr(self.ctx.galaxies, c)))

        # Calculate requested intrinsic apparent and absolute magnitudes
        self.ctx.galaxies.int_magnitude_dict = mag_calc(
            redshifts=self.ctx.galaxies.z,
            excess_b_v=self.ctx.galaxies.excess_b_v,
            coeffs=self.ctx.galaxies.template_coeffs,
            filter_names=par.filters,
        )

        self.ctx.galaxies.abs_magnitude_dict = mag_calc(
            redshifts=np.zeros_like(self.ctx.galaxies.z),
            excess_b_v=np.zeros_like(self.ctx.galaxies.excess_b_v),
            coeffs=self.ctx.galaxies.template_coeffs_abs,
            filter_names=par.filters,
        )

        # Raise error if the number of galaxies per tile is too high or too low
        self.check_n_gal_prior(par)

        # Set apparent (lensed) magnitudes, for now equal to intrinsic apparent
        # magnitudes
        self.ctx.galaxies.magnitude_dict = dict()
        for band, mag in self.ctx.galaxies.int_magnitude_dict.items():
            self.ctx.galaxies.magnitude_dict[band] = mag.copy()

        # Number of galaxies and id
        self.ctx.numgalaxies = self.ctx.galaxies.z.size
        self.ctx.galaxies.id = np.arange(self.ctx.numgalaxies)

        # Backward compatibility: blue_red is 1 for blue (and any other
        # non-red type) and 0 for red.
        # NOTE(review): assumes lum_funcs has "blue" and "red" entries — confirm
        # when galaxy_types contains other types.
        self.ctx.galaxies.blue_red = np.ones(len(self.ctx.galaxies.z), dtype=np.ushort)
        self.ctx.galaxies.blue_red[
            self.ctx.galaxies.galaxy_type == lum_funcs["blue"].galaxy_type
        ] = 1
        self.ctx.galaxies.blue_red[
            self.ctx.galaxies.galaxy_type == lum_funcs["red"].galaxy_type
        ] = 0

        LOGGER.info(
            f"galaxy counts n_total={self.ctx.numgalaxies}"
            f" mem_mb_current={utils.memory_usage_psutil():5.1f}"
        )

        # Best-effort release of cached integration tables held by the
        # calculator. Not every calculator is guaranteed to carry these
        # attributes, so failures are tolerated but no longer silently hidden.
        try:
            del mag_calc.func.templates_int_table_dict
            del mag_calc.func.z_grid
            del mag_calc.func.excess_b_v_grid
            del mag_calc.func
        except Exception as err:
            LOGGER.debug(f"could not release magnitude calculator caches: {err!r}")

    def __str__(self):
        return "sample gal photo"
def compute_templates_extinction_appmag_for_galaxies(
    galaxy_type,
    par,
    n_templates,
    cosmo,
    w,
    redshifts,
    absmags,
    x_pixel,
    y_pixel,
    mag_calc,
    extinction_eval,
):
    """
    Sample template coefficients and compute extinction and reference magnitudes.

    The sampled template coefficients are rescaled so that the magnitude at
    z=0 in ``par.lum_fct_filter_band`` matches ``absmags``. They are then
    converted to apparent coefficients using the luminosity distance.

    :param galaxy_type: galaxy type key passed to the coefficient sampler
    :param par: parameters; reads ``lum_fct_filter_band`` and
        ``reference_band`` (plus whatever the coefficient sampler reads)
    :param n_templates: number of SED templates
    :param cosmo: PyCosmo cosmology
    :param w: WCS, or None if ``x_pixel``/``y_pixel`` are ra/dec
    :param redshifts: galaxy redshifts
    :param absmags: absolute magnitudes drawn from the luminosity function
    :param x_pixel: x positions (or ra)
    :param y_pixel: y positions (or dec)
    :param mag_calc: magnitude calculator accepting short filter names
    :param extinction_eval: callable ``(w, x, y) -> E(B-V)``
    :return: tuple ``(template_coeffs, template_coeffs_abs, excess_b_v,
        app_mag_ref)``; ``excess_b_v`` is always at least 1-D
    """
    template_coeffs_abs = sample_template_coeff_lumfuncs(
        par=par, redshift_z={galaxy_type: redshifts}, n_templates=n_templates
    )[galaxy_type]

    # Calculate absolute magnitudes according to coefficients and adjust
    # coefficients according to drawn magnitudes
    mag_z0 = mag_calc(
        redshifts=np.zeros_like(redshifts),
        excess_b_v=np.zeros_like(redshifts),
        coeffs=template_coeffs_abs,
        filter_names=[par.lum_fct_filter_band],
    )
    template_coeffs_abs *= np.expand_dims(
        10 ** (0.4 * (mag_z0[par.lum_fct_filter_band] - absmags)), -1
    )

    # Transform to apparent coefficients. 10e-6 == 1e-5, i.e. 10 pc if lum_dist
    # is in Mpc (presumably the absolute-magnitude reference distance).
    lum_dist = galaxy_sampling_util.apply_pycosmo_distfun(
        cosmo.background.dist_lum_a, redshifts
    )
    template_coeffs = template_coeffs_abs * np.expand_dims(
        (10e-6 / lum_dist) ** 2 / (1 + redshifts), -1
    )

    # The evaluator may yield a scalar or an array (e.g. np.zeros_like(x) of
    # shape (1,) for a single galaxy). Wrapping unconditionally with
    # np.array([...]) turned the latter into shape (1, 1), which breaks the
    # later concatenation of 1-D catalog columns. atleast_1d wraps scalars
    # and leaves arrays untouched.
    excess_b_v = np.atleast_1d(extinction_eval(w, x_pixel, y_pixel))

    # Calculate apparent reference band magnitude
    app_mag_ref = mag_calc(
        redshifts=redshifts,
        excess_b_v=excess_b_v,
        coeffs=template_coeffs,
        filter_names=[par.reference_band],
    )[par.reference_band]

    return template_coeffs, template_coeffs_abs, excess_b_v, app_mag_ref
def in_pos(x, y, par):
    """
    Return a boolean mask of positions strictly inside the image.

    A position is inside when ``0 < x < par.size_x`` and ``0 < y < par.size_y``;
    positions exactly on an edge are excluded.

    :param x: x pixel coordinates (array)
    :param y: y pixel coordinates (array)
    :param par: parameters providing ``size_x`` and ``size_y``
    :return: boolean array, True for positions inside the image
    """
    inside_x = (x > 0) & (x < par.size_x)
    inside_y = (y > 0) & (y < par.size_y)
    return inside_x & inside_y