Coverage for src/ufig/psf_estimation/save_model.py: 90%

133 statements  

« prev     ^ index     » next       coverage.py v7.10.2, created at 2025-08-07 15:17 +0000

1# Copyright (C) 2025 ETH Zurich 

2# Institute for Particle Physics and Astrophysics 

3# Author: Silvan Fischbacher 

4# created: Mon Jul 21 2025 

5 

6 

7import h5py 

8import numpy as np 

9from cosmic_toolbox import arraytools as at 

10from cosmic_toolbox import logger 

11 

12from ufig.psf_estimation import psf_utils 

13 

# Module-level logger for this file.
LOGGER = logger.get_logger(__file__)
# Keyword arguments applied to HDF5 dataset creation to enable gzip compression.
HDF5_COMPRESS = {"compression": "gzip", "compression_opts": 9, "shuffle": True}
# Sentinel value used to mark sources without a CNN measurement in the output catalog.
ERR_VAL = 999.0

17 

18 

class PSFSave:
    """
    Persist PSF models and their predictions in HDF5 format.

    Responsibilities:

    - writing the PSF model itself to an HDF5 file
    - writing output catalogs that carry the PSF predictions
    - storing diagnostic data such as star stamp cubes
    - producing grid predictions for visualization
    """

    def __init__(self, config):
        """
        Store the configuration used by the save routines.

        :param config: mapping with options read elsewhere in this class,
            e.g. "save_star_cube", "psfmodel_corr_brighter_fatter",
            "astrometry_errors", "precision".
        """
        self.config = config

36 

37 def save_psf_model( 

38 self, 

39 filepath_out, 

40 processed_data, 

41 cnn_results, 

42 model_results, 

43 filepath_sysmaps=None, 

44 filepath_cat_out=None, 

45 ): 

46 """ 

47 Save the PSF model and predictions to an HDF5 file. 

48 """ 

49 LOGGER.info(f"Saving PSF model to: {filepath_out}") 

50 # Create the main HDF5 model file 

51 self._save_hdf5_model(filepath_out, filepath_sysmaps, model_results) 

52 

53 # Generate and save predictions 

54 predictions = self._generate_predictions( 

55 filepath_out, processed_data, model_results 

56 ) 

57 

58 # Update HDF5 file with predictions 

59 self._save_predictions_to_hdf5(filepath_out, predictions) 

60 

61 if filepath_cat_out is not None: 61 ↛ 71line 61 didn't jump to line 71 because the condition on line 61 was always true

62 self._save_catalog_output( 

63 filepath_cat_out, 

64 processed_data, 

65 cnn_results, 

66 model_results, 

67 predictions, 

68 ) 

69 

70 # Save star cube if requested 

71 if self.config.get("save_star_cube", False): 

72 self._save_star_cube(filepath_cat_out, cnn_results) 

73 

74 LOGGER.info("PSF model saving complete") 

75 

76 def _save_hdf5_model(self, filepath_out, filepath_sysmaps, model_results): 

77 """Save the core PSF model to HDF5 format.""" 

78 LOGGER.debug("Creating HDF5 PSF model file") 

79 

80 with h5py.File(filepath_out, mode="w") as fh5_out: 

81 # Copy systematics maps from input file 

82 if filepath_sysmaps is not None: 82 ↛ 86line 82 didn't jump to line 86 because the condition on line 82 was always true

83 self._copy_systematics_maps(fh5_out, filepath_sysmaps) 

84 

85 # Save fitting settings 

86 settings_fit = model_results["settings_fit"] 

87 for key in settings_fit: 

88 at.replace_hdf5_dataset( 

89 fobj=fh5_out, name=f"settings/{key}", data=settings_fit[key] 

90 ) 

91 

92 # Save polynomial model components 

93 regressor = model_results["regressor"] 

94 col_names_ipt = model_results["col_names_ipt"] 

95 

96 at.replace_hdf5_dataset( 

97 fobj=fh5_out, name="par_names", data=col_names_ipt, **HDF5_COMPRESS 

98 ) 

99 at.replace_hdf5_dataset( 

100 fobj=fh5_out, 

101 name="arr_pointings_polycoeffs", 

102 data=regressor.arr_pointings_polycoeffs, 

103 **HDF5_COMPRESS, 

104 ) 

105 at.replace_hdf5_dataset( 

106 fobj=fh5_out, 

107 name="unseen_pointings", 

108 data=regressor.unseen_pointings, 

109 **HDF5_COMPRESS, 

110 ) 

111 at.replace_hdf5_dataset( 

112 fobj=fh5_out, 

113 name="set_unseen_to_mean", 

114 data=regressor.set_unseen_to_mean, 

115 ) 

116 

117 def _copy_systematics_maps(self, fh5_out, filepath_sys): 

118 """Copy systematics maps from input file.""" 

119 LOGGER.debug("Copying systematics maps") 

120 

121 with h5py.File(filepath_sys, mode="r") as fh5_sys: 

122 for key in fh5_sys: 

123 try: 

124 fh5_out.create_dataset(name=key, data=fh5_sys[key], **HDF5_COMPRESS) 

125 except Exception as err: 

126 LOGGER.debug( 

127 f"Failed to compress {key}: {err}, copying uncompressed" 

128 ) 

129 fh5_sys.copy(key, fh5_out) 

130 

131 # Copy pointing attributes 

132 if "map_pointings" in fh5_sys: 132 ↛ exitline 132 didn't jump to the function exit

133 fh5_out["map_pointings"].attrs["n_pointings"] = fh5_sys[ 

134 "map_pointings" 

135 ].attrs["n_pointings"] 

136 

137 def _generate_predictions(self, filepath_out, processed_data, model_results): 

138 """Generate PSF predictions for various data sets.""" 

139 LOGGER.debug("Generating PSF predictions") 

140 

141 predictions = {} 

142 psfmodel_corr_brighter_fatter = self.config.get("psfmodel_corr_brighter_fatter") 

143 

144 # Predict for training stars 

145 cat_fit = model_results["cat_fit"] 

146 predictions["star_train_psf"], _ = self._predict_for_catalog( 

147 cat_fit, filepath_out, psfmodel_corr_brighter_fatter 

148 ) 

149 

150 # Predict for full input catalog 

151 cat_full = processed_data.get("cat_original", processed_data["cat_gaia"]) 

152 predictions["cat_psf"], predictions["n_exposures"] = self._predict_for_catalog( 

153 cat_full, filepath_out, psfmodel_corr_brighter_fatter 

154 ) 

155 

156 # Generate grid predictions 

157 predictions["grid_psf"] = self._generate_grid_predictions( 

158 filepath_out, 

159 processed_data["image_shape"], 

160 ) 

161 

162 # Add astrometric derivatives if requested 

163 if self.config.get("astrometry_errors", False): 

164 col_names_astrometry_ipt = [ 

165 f"{col}_ipt" for col in ["astrometry_diff_x", "astrometry_diff_y"] 

166 ] 

167 predictions["star_train_psf"] = self._add_model_derivatives( 

168 predictions["star_train_psf"], 

169 filepath_out, 

170 col_names_astrometry_ipt, 

171 psfmodel_corr_brighter_fatter, 

172 ) 

173 predictions["cat_psf"] = self._add_model_derivatives( 

174 predictions["cat_psf"], 

175 filepath_out, 

176 col_names_astrometry_ipt, 

177 psfmodel_corr_brighter_fatter, 

178 ) 

179 

180 return predictions 

181 

182 def _predict_for_catalog(self, cat, filepath_out, psfmodel_corr_brighter_fatter): 

183 """Generate PSF predictions for a catalog.""" 

184 # Import here to avoid circular dependencies 

185 from .psf_predictions import predict_psf_for_catalogue_storing 

186 

187 return predict_psf_for_catalogue_storing( 

188 cat, filepath_out, psfmodel_corr_brighter_fatter 

189 ) 

190 

191 def _generate_grid_predictions(self, filepath_out, img_shape): 

192 """ 

193 Generate PSF predictions on a regular grid. 

194 Brighter-fatter correction can not be applied here as it requires 

195 specific magnitudes and ellipticities which are not available for grid points. 

196 """ 

197 LOGGER.debug("Generating grid predictions") 

198 

199 # Create regular grid across image 

200 x_grid, y_grid = np.meshgrid( 

201 np.arange(0, img_shape[1], 100), np.arange(0, img_shape[0], 100) 

202 ) 

203 x_grid = x_grid.ravel() + 0.5 

204 y_grid = y_grid.ravel() + 0.5 

205 

206 # Create catalog for grid points 

207 grid_cat = np.empty(len(x_grid), dtype=at.get_dtype(["X_IMAGE", "Y_IMAGE"])) 

208 grid_cat["X_IMAGE"] = x_grid 

209 grid_cat["Y_IMAGE"] = y_grid 

210 

211 # Generate predictions 

212 grid_psf, _ = self._predict_for_catalog( 

213 grid_cat, filepath_out, psfmodel_corr_brighter_fatter=None 

214 ) 

215 

216 return grid_psf 

217 

218 def _add_model_derivatives( 

219 self, cat, filepath_out, cols, psfmodel_corr_brighter_fatter, delta=1e-2 

220 ): 

221 """Add spatial derivative predictions.""" 

222 # Import here to avoid circular dependencies 

223 from .psf_predictions import get_model_derivatives 

224 

225 return get_model_derivatives( 

226 cat, filepath_out, cols, psfmodel_corr_brighter_fatter, delta 

227 ) 

228 

229 def _save_predictions_to_hdf5(self, filepath_out, predictions): 

230 """Save predictions to the HDF5 model file.""" 

231 LOGGER.debug("Saving predictions to HDF5 file") 

232 

233 with h5py.File(filepath_out, mode="r+") as fh5_out: 

234 # Save full catalog predictions 

235 at.replace_hdf5_dataset( 

236 fobj=fh5_out, 

237 name="predictions", 

238 data=at.set_storing_dtypes(predictions["cat_psf"]), 

239 **HDF5_COMPRESS, 

240 ) 

241 

242 # Save training star data and predictions 

243 if "star_train_psf" in predictions: 243 ↛ 254line 243 didn't jump to line 254 because the condition on line 243 was always true

244 data = predictions["star_train_psf"] 

245 

246 at.replace_hdf5_dataset( 

247 fobj=fh5_out, 

248 name="star_train_prediction", 

249 data=at.set_storing_dtypes(data), 

250 **HDF5_COMPRESS, 

251 ) 

252 

253 # Save grid predictions 

254 at.replace_hdf5_dataset( 

255 fobj=fh5_out, 

256 name="grid_psf", 

257 data=at.set_storing_dtypes(predictions["grid_psf"]), 

258 **HDF5_COMPRESS, 

259 ) 

260 

261 def _save_catalog_output( 

262 self, filepath_cat_out, processed_data, cnn_results, model_results, predictions 

263 ): 

264 """Save catalog output with PSF predictions.""" 

265 LOGGER.debug(f"Saving catalog output: {filepath_cat_out}") 

266 

267 # Get original catalog 

268 cat = processed_data.get("cat_original", processed_data["cat_gaia"]) 

269 

270 # Define output columns 

271 col_names_cnn_fit = [ 

272 col for col in model_results["col_names_fit"] if col.endswith("_cnn") 

273 ] 

274 col_names_ipt = model_results["col_names_ipt"] 

275 

276 cols_new = ( 

277 col_names_cnn_fit 

278 + col_names_ipt 

279 + [ 

280 "FLAGS_STARS:i4", 

281 "N_EXPOSURES:i2", 

282 "gaia_ra_match", 

283 "gaia_dec_match", 

284 "gaia_id_match:i8", 

285 ] 

286 ) 

287 

288 # Create output catalog 

289 cat_save_psf = at.ensure_cols(cat, names=cols_new) 

290 

291 # Fill with predictions and metadata 

292 self._fill_catalog_output( 

293 cat_save_psf, processed_data, cnn_results, model_results, predictions 

294 ) 

295 

296 # Convert precision if requested 

297 precision = self.config.get("precision", float) 

298 if precision == np.float32: 

299 cat_save_psf = at.rec_float64_to_float32(cat_save_psf) 

300 LOGGER.info("Converted catalog to float32 precision") 

301 

302 # Save catalog 

303 at.save_hdf_cols(filepath_cat_out, cat_save_psf) 

304 LOGGER.info(f"Catalog saved with {len(cat_save_psf)} sources") 

305 

    def _fill_catalog_output(
        self, cat_save_psf, processed_data, cnn_results, model_results, predictions
    ):
        """
        Fill catalog output with predictions and metadata.

        Writes in place into ``cat_save_psf``: the interpolated PSF columns,
        the CNN-fitted columns for stars that passed the CNN selection, the
        GAIA match columns, and the exposure count.
        """
        # Get masks and data (all optional; missing keys yield None)
        cat_gaia = processed_data.get("cat_gaia", None)
        flags_gaia = processed_data.get("flags_gaia", None)
        flags_all = processed_data.get("flags_all", None)
        cat_cnn = cnn_results.get("cat_cnn", None)
        flags_cnn = cnn_results.get("flags_cnn", None)
        # NOTE(review): if a flags array is None, the comparison below yields
        # the scalar False instead of a boolean mask; the None cases are
        # guarded before the masks are used, but confirm all call paths
        # actually provide these flags.
        select_gaia = flags_gaia == 0
        select_cnn = flags_cnn == 0
        select_all = flags_all == 0

        # Fill interpolated PSF parameters predicted for the full catalog
        col_names_ipt = model_results["col_names_ipt"]
        for col_name in col_names_ipt:
            if col_name in predictions["cat_psf"].dtype.names:
                cat_save_psf[col_name] = predictions["cat_psf"][col_name]

        # Fill CNN predicted parameters for stars used in fitting
        if cat_cnn is not None and flags_cnn is not None:
            col_names_cnn_fit = [
                col for col in model_results["col_names_fit"] if col.endswith("_cnn")
            ]
            for col_name in col_names_cnn_fit:
                if col_name in cat_cnn.dtype.names:
                    # Initialize with ERR_VAL so sources without a CNN
                    # measurement are clearly marked
                    cat_save_psf[col_name] = ERR_VAL

                    # Fill values for stars that were used in CNN
                    if flags_gaia is not None:
                        # Map from CNN indices to full catalog indices —
                        # the CNN catalog appears to be a subset of the
                        # GAIA-selected rows (TODO confirm with caller)
                        idx_full = np.where(select_gaia)[0][select_cnn]
                        cat_save_psf[col_name][idx_full] = cat_cnn[col_name][select_cnn]

        # Fill GAIA match information
        if cat_gaia is not None and flags_gaia is not None:
            cat_save_psf["gaia_ra_match"][select_all] = cat_gaia["gaia_ra_match"]
            cat_save_psf["gaia_dec_match"][select_all] = cat_gaia["gaia_dec_match"]
            cat_save_psf["gaia_id_match"][select_all] = cat_gaia["gaia_id_match"]
        else:
            # NOTE(review): FLAGS_STARS is only written in this no-GAIA branch;
            # when GAIA data is present it is presumably set elsewhere — verify.
            cat_save_psf["FLAGS_STARS"] = 0

        # Fill exposure information (defaults to a single exposure)
        cat_save_psf["N_EXPOSURES"] = predictions.get("n_exposures", 1)

352 

353 def _save_star_cube(self, filepath_cat_out, cnn_results): 

354 """Save star stamp cube for detailed analysis.""" 

355 LOGGER.debug("Saving star cube") 

356 

357 # Check if we have the necessary data 

358 if "cube_cnn" not in cnn_results: 358 ↛ 359line 358 didn't jump to line 359 because the condition on line 358 was never true

359 LOGGER.warning("No cube data to save") 

360 return 

361 

362 cube = cnn_results["cube_cnn"] 

363 cat_cnn = cnn_results["cat_cnn"] 

364 

365 psf_utils.write_star_cube( 

366 star_cube=cube, cat=cat_cnn, filepath_cat_out=filepath_cat_out 

367 )