Coverage for src/galsbi/galsbi.py: 99%

160 statements  

« prev     ^ index     » next       coverage.py v7.10.5, created at 2025-08-26 10:41 +0000

1# Copyright (C) 2024 ETH Zurich 

2# Institute for Particle Physics and Astrophysics 

3# Author: Silvan Fischbacher 

4# created: Thu Aug 08 2024 

5 

6import contextlib 

7import importlib 

8 

9import h5py 

10import numpy as np 

11from astropy.io import fits 

12from astropy.table import Table 

13from cosmic_toolbox import arraytools as at 

14from cosmic_toolbox import logger 

15from ufig import run_util 

16 

17from . import citations, load 

18 

# Module-level logger for this file, obtained from cosmic_toolbox's logger helper.
LOGGER = logger.get_logger(__name__)

20 

21 

class GalSBI:
    """
    This class is the main interface to the model. It provides methods to generate
    mock galaxy catalogs, to load the generated catalogs and images, and to cite the
    model.
    """

    def __init__(self, name, verbosity="info"):
        """
        :param name: name of the model to use
        :param verbosity: verbosity level of the logger, either "debug", "info",
            "warning", "error" or "critical"
        """
        self.name = name
        self.mode = None
        self.filters = None
        self.verbosity = verbosity
        # Populated by generate_catalog / _run. Initialized here so that the
        # attributes exist (as None) before any catalog has been generated.
        self.file_name = None
        self.catalog_name = None
        self.ctx = None

    def generate_catalog(
        self,
        mode="intrinsic",
        config_file=None,
        model_index=0,
        file_name="GalSBI_sim",
        verbosity=None,
        **kwargs,
    ):
        """
        Generates a mock galaxy catalog using the model and configuration specified. The
        parameter model_index is used to select a specific set of model parameters from
        the ABC posterior. If a list of model parameters is provided, catalogs are
        generated for each set of parameters. The saved catalogs and images are named
        according to the file_name and model_index.

        Names of the files
        ------------------
        - Intrinsic ucat galaxy catalog: f"{file_name}_{index}_{band}_ucat.gal.cat"
        - Intrinsic ucat star catalog: f"{file_name}_{index}_{band}_ucat.star.cat"
        - Output catalog: f"{file_name}_{index}_{band}_se.cat"
        - Output image: f"{file_name}_{index}_{band}_image.fits"
        - Segmentation map: f"{file_name}_{index}_{band}_se_seg.h5"
        - Background map: f"{file_name}_{index}_{band}_se_bkg.h5"
        - SED catalog: f"{file_name}_{index}_sed.cat"

        :param mode: mode to use for generating the catalog, either "intrinsic", "emu",
            "image", "image+SE", "config_file"
        :param config_file: dictionary or path to a configuration file to use for
            generating the catalog (only used if mode="config_file")
        :param model_index: index (Python or NumPy integer) or iterable of indices of
            the model parameters to use for generating the catalog
        :param file_name: filename of the catalog and images to generate
        :param verbosity: verbosity level of the logger, either "debug", "info",
            "warning", "error" or "critical"
        :param kwargs: additional keyword arguments to pass to the workflow (overwrites
            the values from the model parameters and config file). The keys
            galaxy_catalog_name_format, star_catalog_name_format,
            sextractor_forced_photo_catalog_name_format, galaxy_sed_catalog_name,
            image_name_format and tile_name are always set by this method.
        """
        if verbosity is not None:
            self.verbosity = verbosity
        logger.set_all_loggers_level(self.verbosity)
        self.mode = mode
        model_parameters = load.load_abc_posterior(self.name)
        config = load.load_config(self.name, mode, config_file)

        self.file_name = file_name
        self.catalog_name = file_name  # for backward compatibility

        for index in self._as_index_list(model_index):
            LOGGER.info(
                "Generating catalog for model"
                f" {self.name} and mode {mode} with index {index}"
            )
            # Output names embed the model index; the "{}{}" placeholders are
            # filled in by the workflow.
            kwargs["galaxy_catalog_name_format"] = (
                f"{file_name}_{index}_{{}}{{}}_ucat.gal.cat"
            )
            kwargs["star_catalog_name_format"] = (
                f"{file_name}_{index}_{{}}{{}}_ucat.star.cat"
            )
            kwargs["sextractor_forced_photo_catalog_name_format"] = (
                f"{file_name}_{index}_{{}}{{}}_se.cat"
            )
            kwargs["galaxy_sed_catalog_name"] = f"{file_name}_{index}_sed.cat"
            kwargs["image_name_format"] = f"{file_name}_{index}_{{}}{{}}_image.fits"
            kwargs["tile_name"] = ""
            self._run(config, model_parameters[index], **kwargs)

    __call__ = generate_catalog

    @staticmethod
    def _as_index_list(model_index):
        """
        Normalizes model_index for iteration.

        A single integer (Python int or NumPy integer scalar) is wrapped in a list;
        any other value is returned unchanged and iterated by the caller.
        """
        if isinstance(model_index, (int, np.integer)):
            return [model_index]
        return model_index

    def _run(self, config, model_parameters, **kwargs):
        """
        Runs the workflow with the given configuration and model parameters.

        Also sets self.filters (from kwargs["filters"] if given, otherwise from the
        `filters` attribute of the imported config module) and stores the returned
        workflow context in self.ctx.

        :param config: configuration to use for generating the catalog
        :param model_parameters: model parameters to use for generating the catalog
        :param kwargs: additional keyword arguments to pass to the workflow (overwrites
            the values from the model parameters and config file)
        """
        names = model_parameters.dtype.names
        kargs = {col: model_parameters[col] for col in names}
        if "moffat_beta1" in names and "moffat_beta2" in names:
            # NOTE(review): indexing with [0] assumes these fields are array-valued
            # for a single parameter set; confirm against load.load_abc_posterior.
            kargs["psf_beta"] = [
                model_parameters["moffat_beta1"][0],
                model_parameters["moffat_beta2"][0],
            ]
        kargs.update(kwargs)
        if "filters" in kargs:
            self.filters = kargs["filters"]
        else:
            config_module = importlib.import_module(config)
            self.filters = config_module.filters

        self.ctx = run_util.run_ufig_from_config(config, **kargs)

    def cite(self):
        """
        Prints all the papers that should be cited when using the configuration
        specified
        """
        print("\033[1mPlease cite the following papers\033[0m")
        print("=================================")
        print("\033[1mFor using the GalSBI model:\033[0m")
        citations.cite_galsbi_release()
        print("\033[1mFor using the galsbi python package:\033[0m")
        citations.cite_code_release(self.mode)
        print("")

        print(
            "\033[1mFor the galaxy population model and redshift distribution:\033[0m"
        )
        citations.cite_abc_posterior(self.name)
        print("")
        print("Example:")
        print("--------")
        print(
            "We use the GalSBI framework (PAPERS GalSBI release) to generate mock"
            " galaxy catalogs. The galaxy population model corresponds to the"
            " posterior from (PAPER model). (...) "
            "Acknowledgements: We acknowledge the use of the following software: "
            "(numpy), (scipy), (PAPERS code release), (...)"
        )

    def _check_catalogs_generated(self):
        """
        Raises RuntimeError if the filter bands are still unknown, i.e. no catalog
        has been generated with this instance yet.
        """
        if self.filters is None:
            raise RuntimeError("please generate catalogs first")

    def load_catalogs(self, output_format="rec", model_index=0, combine=False):
        """
        Loads the catalogs generated by the model. Missing files are skipped silently;
        a warning is logged if no catalog at all is found.

        :param output_format: format of the output, either "rec", "df" or "fits"
        :param model_index: index of the model parameters to use for loading the
            catalogs
        :param combine: if True, combines the catalogs from all bands into a single
            catalog
        :return: catalogs in the specified format (the rest-frame wavelength grid,
            if present, is returned unconverted)
        :raises RuntimeError: if no catalogs have been generated yet
        :raises ValueError: if output_format is unknown
        """
        self._check_catalogs_generated()

        if output_format == "rec":
            convert = lambda x: x  # noqa: E731
        elif output_format == "df":
            convert = at.rec2pd
        elif output_format == "fits":
            convert = Table
        else:
            raise ValueError(f"Unknown output format {output_format}")

        prefix = f"{self.file_name}_{model_index}"
        output = {}
        for band in self.filters:
            for key, suffix in (
                ("ucat galaxies", "ucat.gal.cat"),
                ("ucat stars", "ucat.star.cat"),
                ("sextractor", "se.cat"),
            ):
                with contextlib.suppress(FileNotFoundError):
                    output[f"{key} {band}"] = at.load_hdf(f"{prefix}_{band}_{suffix}")
        with (
            contextlib.suppress(FileNotFoundError),
            h5py.File(f"{prefix}_sed.cat", "r") as fh5,
        ):
            output["sed"] = fh5["data"][:]
            output["restframe_wavelength_in_A"] = fh5["restframe_wavelength_in_A"][:]

        if len(output) == 0:
            LOGGER.warning(
                "No catalogs found. Did you already generate catalogs? Does the "
                "model_index match the one used for generating the catalogs?"
            )

        if combine:
            output = self._build_combined_catalogs(output)
        # The wavelength grid is a plain array, not a catalog: never convert it.
        return {
            key: value if key == "restframe_wavelength_in_A" else convert(value)
            for key, value in output.items()
        }

    def load_images(self, model_index=0):
        """
        Loads the images generated by the model. This includes the actual image,
        the segmentation map and the background map. Missing files are skipped.

        :param model_index: index of the model parameters to use for loading the images
        :return: images as numpy arrays, keyed by "image {band}",
            "segmentation {band}" and "background {band}"
        :raises RuntimeError: if no catalogs have been generated yet
        """
        self._check_catalogs_generated()

        prefix = f"{self.file_name}_{model_index}"
        output = {}
        for band in self.filters:
            # Context manager guarantees the FITS handle is closed even on error.
            with (
                contextlib.suppress(FileNotFoundError),
                fits.open(f"{prefix}_{band}_image.fits") as hdul,
            ):
                output[f"image {band}"] = hdul[0].data
            with contextlib.suppress(FileNotFoundError):
                output[f"segmentation {band}"] = at.load_hdf_cols(
                    f"{prefix}_{band}_se_seg.h5"
                )["SEGMENTATION"]
            with contextlib.suppress(FileNotFoundError):
                output[f"background {band}"] = at.load_hdf_cols(
                    f"{prefix}_{band}_se_bkg.h5"
                )["BACKGROUND"]
        return output

    @staticmethod
    def _merge_bands(catalogs, prefix, bands, is_band_dependent):
        """
        Merges the per-band catalogs f"{prefix} {band}" into one column dictionary.

        Band-dependent columns are renamed to f"{column} {band}"; all other columns
        keep their name (the value from the last band wins).

        :param catalogs: dictionary of per-band structured arrays
        :param prefix: catalog type prefix, e.g. "ucat galaxies"
        :param bands: bands to merge, in order
        :param is_band_dependent: predicate on a column name
        :return: dictionary mapping column names to column arrays
        """
        merged = {}
        for band in bands:
            cat = catalogs[f"{prefix} {band}"]
            for par in cat.dtype.names:
                if is_band_dependent(par):
                    merged[f"{par} {band}"] = cat[par]
                else:
                    merged[par] = cat[par]
        return merged

    def _build_combined_catalogs(self, catalogs):
        """
        Combines the per-band catalogs of each type ("ucat galaxies", "ucat stars",
        "sextractor") into one catalog per type. A type is only combined if the
        catalog for the first band is present. SEDs, if loaded, are attached to the
        galaxy catalogs.

        :param catalogs: dictionary as built by load_catalogs
        :return: dictionary of combined catalogs (plus the rest-frame wavelength
            grid if present)
        """
        # ucat: only these columns differ between bands.
        band_dep_params = {"int_mag", "mag", "abs_mag", "bkg_noise_amp"}
        # SExtractor: only these columns are shared between bands.
        band_ind_params = {
            "dec",
            "ra",
            "z",
            "e1",
            "e2",
            "r50",
            "r50_arcsec",
            "r50_phys",
            "sersic_n",
            "galaxy_type",
            "id",
            "x",
            "y",
            "star_gal",
        }
        combined_catalogs = {}
        first_band = self.filters[0]

        if f"ucat galaxies {first_band}" in catalogs:
            new_cat = self._merge_bands(
                catalogs,
                "ucat galaxies",
                self.filters,
                lambda par: par in band_dep_params,
            )
            if "sed" in catalogs:
                new_cat["sed"] = catalogs["sed"]["sed"]
            combined_catalogs["ucat galaxies"] = at.dict2rec(new_cat)

        if f"ucat stars {first_band}" in catalogs:
            new_cat = self._merge_bands(
                catalogs,
                "ucat stars",
                self.filters,
                lambda par: par in band_dep_params,
            )
            combined_catalogs["ucat stars"] = at.dict2rec(new_cat)

        if f"sextractor {first_band}" in catalogs:
            new_cat = self._merge_bands(
                catalogs,
                "sextractor",
                self.filters,
                lambda par: par not in band_ind_params,
            )
            if "sed" in catalogs:
                sed_catalog = catalogs["sed"]
                ids = new_cat["id"]
                # Detections with galaxy_type >= 0 are treated as matched galaxies
                # whose id indexes the SED catalog; all others get NaN SEDs.
                matched_gals = new_cat["galaxy_type"] >= 0
                sed_data = np.full(
                    (len(ids), sed_catalog["sed"].shape[1]),
                    np.nan,
                    dtype=sed_catalog["sed"].dtype,
                )
                sed_data[matched_gals] = sed_catalog["sed"][ids[matched_gals]]
                new_cat["sed"] = sed_data
            combined_catalogs["sextractor"] = at.dict2rec(new_cat)

        if "restframe_wavelength_in_A" in catalogs:
            combined_catalogs["restframe_wavelength_in_A"] = catalogs[
                "restframe_wavelength_in_A"
            ]
        return combined_catalogs