Coverage for src/galsbi/galsbi.py: 99%
160 statements
« prev ^ index » next coverage.py v7.10.5, created at 2025-08-26 10:41 +0000
« prev ^ index » next coverage.py v7.10.5, created at 2025-08-26 10:41 +0000
1# Copyright (C) 2024 ETH Zurich
2# Institute for Particle Physics and Astrophysics
3# Author: Silvan Fischbacher
4# created: Thu Aug 08 2024
6import contextlib
7import importlib
9import h5py
10import numpy as np
11from astropy.io import fits
12from astropy.table import Table
13from cosmic_toolbox import arraytools as at
14from cosmic_toolbox import logger
15from ufig import run_util
17from . import citations, load
# Module-level logger: GalSBI uses it to report which catalog is being generated
# (info) and to warn when no catalogs are found on disk (warning).
LOGGER = logger.get_logger(__name__)
class GalSBI:
    """
    This class is the main interface to the model. It provides methods to generate
    mock galaxy catalogs, to load the generated catalogs and images, and to cite the
    model.
    """

    # Columns of the intrinsic (ucat) catalogs whose values differ between bands.
    # When catalogs of several bands are combined they are stored as f"{par} {band}".
    _UCAT_BAND_DEPENDENT_PARAMS = frozenset(
        ("int_mag", "mag", "abs_mag", "bkg_noise_amp")
    )

    # Columns of the SExtractor catalogs that are shared by all bands. Every other
    # column is stored as f"{par} {band}" when catalogs are combined.
    _SE_BAND_INDEPENDENT_PARAMS = frozenset(
        (
            "dec",
            "ra",
            "z",
            "e1",
            "e2",
            "r50",
            "r50_arcsec",
            "r50_phys",
            "sersic_n",
            "galaxy_type",
            "id",
            "x",
            "y",
            "star_gal",
        )
    )

    def __init__(self, name, verbosity="info"):
        """
        :param name: name of the model to use
        :param verbosity: verbosity level of the logger, either "debug", "info",
                          "warning", "error" or "critical"
        """
        self.name = name
        self.mode = None
        self.filters = None
        self.verbosity = verbosity
        # Filled in by generate_catalog / _run. They are initialised here so that the
        # attributes exist even before any catalog has been generated.
        self.file_name = None
        self.catalog_name = None
        self.ctx = None

    @staticmethod
    def _as_index_list(model_index):
        """
        Normalizes the model_index argument of generate_catalog to an iterable.

        :param model_index: a single integer (Python int or numpy integer) or an
                            iterable of integers
        :return: a one-element list for a single integer, otherwise model_index
                 unchanged
        """
        # np.integer is accepted too: numpy integers are not instances of int and
        # would otherwise be iterated over (and fail).
        if isinstance(model_index, (int, np.integer)):
            return [model_index]
        return model_index

    def generate_catalog(
        self,
        mode="intrinsic",
        config_file=None,
        model_index=0,
        file_name="GalSBI_sim",
        verbosity=None,
        **kwargs,
    ):
        """
        Generates a mock galaxy catalog using the model and configuration specified. The
        parameter model_index is used to select a specific set of model parameters from
        the ABC posterior. If a list of model parameters is provided, catalogs are
        generated for each set of parameters. The saved catalogs and images are named
        according to the file_name and model_index.

        Names of the files
        ------------------
        - Intrinsic ucat galaxy catalog: f"{file_name}_{index}_{band}_ucat.gal.cat"
        - Intrinsic ucat star catalog: f"{file_name}_{index}_{band}_ucat.star.cat"
        - Output catalog: f"{file_name}_{index}_{band}_se.cat"
        - Output image: f"{file_name}_{index}_{band}_image.fits"
        - Segmentation map: f"{file_name}_{index}_{band}_se_seg.h5"
        - Background map: f"{file_name}_{index}_{band}_se_bkg.h5"
        - SED catalog: f"{file_name}_{index}_sed.cat"

        :param mode: mode to use for generating the catalog, either "intrinsic", "emu",
                     "image", "image+SE", "config_file"
        :param config_file: dictionary or path to a configuration file to use for
                            generating the catalog (only used if mode="config_file")
        :param model_index: index (Python or numpy integer) or iterable of indices of
                            the model parameters to use for generating the catalog
        :param file_name: filename of the catalog and images to generate
        :param verbosity: verbosity level of the logger, either "debug", "info",
                          "warning", "error" or "critical"
        :param kwargs: additional keyword arguments to pass to the workflow (overwrites
                       the values from the model parameters and config file). The
                       output-name keys (galaxy_catalog_name_format,
                       star_catalog_name_format,
                       sextractor_forced_photo_catalog_name_format,
                       galaxy_sed_catalog_name, image_name_format, tile_name) are
                       always set by this method from file_name and the index.
        """
        if verbosity is not None:
            self.verbosity = verbosity
        logger.set_all_loggers_level(self.verbosity)
        self.mode = mode
        model_parameters = load.load_abc_posterior(self.name)
        config = load.load_config(self.name, mode, config_file)

        self.file_name = file_name
        self.catalog_name = file_name  # for backward compatibility

        for index in self._as_index_list(model_index):
            LOGGER.info(
                "Generating catalog for model"
                f" {self.name} and mode {mode} with index {index}"
            )
            # The doubled braces survive the f-string as "{}{}" placeholders that
            # the workflow fills in itself.
            kwargs["galaxy_catalog_name_format"] = (
                f"{file_name}_{index}_{{}}{{}}_ucat.gal.cat"
            )
            kwargs["star_catalog_name_format"] = (
                f"{file_name}_{index}_{{}}{{}}_ucat.star.cat"
            )
            kwargs["sextractor_forced_photo_catalog_name_format"] = (
                f"{file_name}_{index}_{{}}{{}}_se.cat"
            )
            kwargs["galaxy_sed_catalog_name"] = f"{file_name}_{index}_sed.cat"
            kwargs["image_name_format"] = f"{file_name}_{index}_{{}}{{}}_image.fits"
            kwargs["tile_name"] = ""
            self._run(config, model_parameters[index], **kwargs)

    __call__ = generate_catalog

    def _run(self, config, model_parameters, **kwargs):
        """
        Runs the workflow with the given configuration and model parameters.

        Sets self.filters (from kwargs["filters"] if given, otherwise from the
        ``filters`` attribute of the imported config module) and stores the result
        of the workflow in self.ctx.

        :param config: configuration to use for generating the catalog; imported with
                       importlib when "filters" is not among the keyword arguments
        :param model_parameters: model parameters to use for generating the catalog;
                                 every field of its dtype is forwarded to the workflow
        :param kwargs: additional keyword arguments to pass to the workflow (overwrites
                       the values from the model parameters and config file)
        """
        names = model_parameters.dtype.names
        run_kwargs = {col: model_parameters[col] for col in names}
        if "moffat_beta1" in names and "moffat_beta2" in names:
            # NOTE(review): the [0] assumes each field holds a length-1 array
            # rather than a scalar; kept as in the original implementation.
            run_kwargs["psf_beta"] = [
                model_parameters["moffat_beta1"][0],
                model_parameters["moffat_beta2"][0],
            ]
        # User-supplied kwargs take precedence over the posterior values.
        run_kwargs.update(kwargs)

        if "filters" in run_kwargs:
            self.filters = run_kwargs["filters"]
        else:
            config_module = importlib.import_module(config)
            self.filters = config_module.filters

        self.ctx = run_util.run_ufig_from_config(config, **run_kwargs)

    def cite(self):
        """
        Prints all the papers that should be cited when using the configuration
        specified, followed by an example text for a publication.
        """
        print("\033[1mPlease cite the following papers\033[0m")
        print("=================================")
        print("\033[1mFor using the GalSBI model:\033[0m")
        citations.cite_galsbi_release()
        print("\033[1mFor using the galsbi python package:\033[0m")
        # self.mode is the mode of the last generate_catalog call (None before).
        citations.cite_code_release(self.mode)
        print("")

        print(
            "\033[1mFor the galaxy population model and redshift distribution:\033[0m"
        )
        citations.cite_abc_posterior(self.name)
        print("")
        print("Example:")
        print("--------")
        print(
            "We use the GalSBI framework (PAPERS GalSBI release) to generate mock"
            " galaxy catalogs. The galaxy population model corresponds to the"
            " posterior from (PAPER model). (...) "
            "Acknowledgements: We acknowledge the use of the following software: "
            "(numpy), (scipy), (PAPERS code release), (...)"
        )

    def load_catalogs(self, output_format="rec", model_index=0, combine=False):
        """
        Loads the catalogs generated by the model. Catalog files that do not exist
        are skipped; a warning is logged if none is found.

        :param output_format: format of the output, either "rec", "df" or "fits"
        :param model_index: index of the model parameters to use for loading the
                            catalogs
        :param combine: if True, combines the catalogs from all bands into a single
                        catalog
        :return: dictionary of catalogs in the specified format. Without combine the
                 keys are f"ucat galaxies {band}", f"ucat stars {band}",
                 f"sextractor {band}", "sed" and "restframe_wavelength_in_A" (only
                 those that were found). "restframe_wavelength_in_A" is never
                 converted.
        :raises RuntimeError: if no catalogs were generated with this instance
        :raises ValueError: if output_format is unknown
        """
        if self.filters is None:
            raise RuntimeError("please generate catalogs first")

        if output_format == "rec":

            def convert(x):
                return x

        elif output_format == "df":
            convert = at.rec2pd
        elif output_format == "fits":
            convert = Table
        else:
            raise ValueError(f"Unknown output format {output_format}")

        # (file-name suffix, output key prefix) for each per-band catalog file
        per_band_files = (
            ("ucat.gal.cat", "ucat galaxies"),
            ("ucat.star.cat", "ucat stars"),
            ("se.cat", "sextractor"),
        )
        output = {}
        for band in self.filters:
            for suffix, label in per_band_files:
                catalog_name = f"{self.file_name}_{model_index}_{band}_{suffix}"
                with contextlib.suppress(FileNotFoundError):
                    output[f"{label} {band}"] = at.load_hdf(catalog_name)

        with contextlib.suppress(FileNotFoundError):
            sed_catalog_name = f"{self.file_name}_{model_index}_sed.cat"
            with h5py.File(sed_catalog_name, "r") as fh5:
                output["sed"] = fh5["data"][:]
                output["restframe_wavelength_in_A"] = fh5["restframe_wavelength_in_A"][
                    :
                ]

        if len(output) == 0:
            LOGGER.warning(
                "No catalogs found. Did you already generate catalogs? Does the "
                "model_index match the one used for generating the catalogs?"
            )

        catalogs = self._build_combined_catalogs(output) if combine else output
        return {
            key: value if key == "restframe_wavelength_in_A" else convert(value)
            for key, value in catalogs.items()
        }

    def load_images(self, model_index=0):
        """
        Loads the images generated by the model. This includes the actual image,
        the segmentation map and the background map. Files that do not exist are
        skipped.

        :param model_index: index of the model parameters to use for loading the images
        :return: dictionary with keys f"image {band}", f"segmentation {band}" and
                 f"background {band}" (only those that were found) mapping to the
                 loaded arrays
        :raises RuntimeError: if no catalogs were generated with this instance
        """
        if self.filters is None:
            raise RuntimeError("please generate catalogs first")

        output = {}
        for band in self.filters:
            image_name = f"{self.file_name}_{model_index}_{band}_image.fits"
            with contextlib.suppress(FileNotFoundError):
                # The context manager closes the file even if reading the data fails.
                with fits.open(image_name) as hdul:
                    output[f"image {band}"] = hdul[0].data
            segmap_name = f"{self.file_name}_{model_index}_{band}_se_seg.h5"
            with contextlib.suppress(FileNotFoundError):
                output[f"segmentation {band}"] = at.load_hdf_cols(segmap_name)[
                    "SEGMENTATION"
                ]
            bkgmap_name = f"{self.file_name}_{model_index}_{band}_se_bkg.h5"
            with contextlib.suppress(FileNotFoundError):
                output[f"background {band}"] = at.load_hdf_cols(bkgmap_name)[
                    "BACKGROUND"
                ]
        return output

    @staticmethod
    def _merge_bands(catalogs, label, bands, is_band_dependent):
        """
        Merges the per-band catalogs stored under f"{label} {band}" into one
        dictionary of columns.

        :param catalogs: dictionary containing a structured array for every band
        :param label: catalog label used in the keys, e.g. "ucat galaxies"
        :param bands: bands to merge, in order
        :param is_band_dependent: callable returning True for column names that have
                                  to be kept separately for each band
        :return: dictionary of columns; band-dependent columns are stored as
                 f"{par} {band}", all other columns once (the last band's values
                 are kept)
        """
        merged = {}
        for band in bands:
            cat = catalogs[f"{label} {band}"]
            for par in cat.dtype.names:
                key = f"{par} {band}" if is_band_dependent(par) else par
                merged[key] = cat[par]
        return merged

    @staticmethod
    def _match_seds(seds, ids, galaxy_type):
        """
        Assigns to every detected object the SED row selected by its id.

        :param seds: 2D array of SEDs whose rows are indexed by the ``id`` column
        :param ids: id of every detected object
        :param galaxy_type: galaxy type of every detected object; objects with a
                            negative value are treated as unmatched
        :return: array of shape (len(ids), seds.shape[1]); rows of unmatched
                 objects are NaN
        """
        matched_gals = galaxy_type >= 0
        sed_data = np.full((len(ids), seds.shape[1]), np.nan, dtype=seds.dtype)
        sed_data[matched_gals] = seds[ids[matched_gals]]
        return sed_data

    def _build_combined_catalogs(self, catalogs):
        """
        Combines the per-band catalogs into one catalog per catalog type.

        A catalog type is combined only if the catalog of the first band is present;
        then the catalogs of all bands of self.filters are expected.

        :param catalogs: dictionary as built by load_catalogs (before conversion)
        :return: dictionary with keys "ucat galaxies", "ucat stars", "sextractor" and
                 "restframe_wavelength_in_A" (only those that could be built)
        """
        combined_catalogs = {}
        first_band = self.filters[0]

        def is_ucat_band_dependent(par):
            return par in self._UCAT_BAND_DEPENDENT_PARAMS

        def is_se_band_dependent(par):
            return par not in self._SE_BAND_INDEPENDENT_PARAMS

        if f"ucat galaxies {first_band}" in catalogs:
            new_cat = self._merge_bands(
                catalogs, "ucat galaxies", self.filters, is_ucat_band_dependent
            )
            if "sed" in catalogs:
                new_cat["sed"] = catalogs["sed"]["sed"]
            combined_catalogs["ucat galaxies"] = at.dict2rec(new_cat)

        if f"ucat stars {first_band}" in catalogs:
            new_cat = self._merge_bands(
                catalogs, "ucat stars", self.filters, is_ucat_band_dependent
            )
            combined_catalogs["ucat stars"] = at.dict2rec(new_cat)

        if f"sextractor {first_band}" in catalogs:
            new_cat = self._merge_bands(
                catalogs, "sextractor", self.filters, is_se_band_dependent
            )
            if "sed" in catalogs:
                new_cat["sed"] = self._match_seds(
                    catalogs["sed"]["sed"], new_cat["id"], new_cat["galaxy_type"]
                )
            combined_catalogs["sextractor"] = at.dict2rec(new_cat)

        if "restframe_wavelength_in_A" in catalogs:
            combined_catalogs["restframe_wavelength_in_A"] = catalogs[
                "restframe_wavelength_in_A"
            ]
        return combined_catalogs