Coverage for src/ufig/psf_estimation/save_model.py: 90%
133 statements
« prev ^ index » next coverage.py v7.10.2, created at 2025-08-07 15:17 +0000
# Copyright (C) 2025 ETH Zurich
# Institute for Particle Physics and Astrophysics
# Author: Silvan Fischbacher
# created: Mon Jul 21 2025
7import h5py
8import numpy as np
9from cosmic_toolbox import arraytools as at
10from cosmic_toolbox import logger
12from ufig.psf_estimation import psf_utils
# Module-level logger for this file.
LOGGER = logger.get_logger(__file__)

# Keyword arguments applied to every compressed HDF5 dataset written here.
HDF5_COMPRESS = {
    "compression": "gzip",
    "compression_opts": 9,
    "shuffle": True,
}

# Sentinel value written to catalog columns that have no valid measurement.
ERR_VAL = 999.0
class PSFSave:
    """
    Saves PSF models and predictions to HDF5 format.

    This class manages:

    - Saving PSF models to HDF5 format
    - Writing output catalogs with PSF predictions
    - Storing diagnostic data and star cubes
    - Creating grid predictions for visualization
    """

    def __init__(self, config):
        """
        Initialize the PSF save utility with configuration parameters.

        :param config: dict-like configuration; keys read by this class are
            "save_star_cube", "psfmodel_corr_brighter_fatter",
            "astrometry_errors" and "precision".
        """
        self.config = config

    def save_psf_model(
        self,
        filepath_out,
        processed_data,
        cnn_results,
        model_results,
        filepath_sysmaps=None,
        filepath_cat_out=None,
    ):
        """
        Save the PSF model and predictions to an HDF5 file.

        :param filepath_out: path of the HDF5 model file to create
        :param processed_data: dict with input catalogs and image metadata
            (keys used: "cat_original"/"cat_gaia", "image_shape", flags)
        :param cnn_results: dict with CNN measurements
            (keys used: "cat_cnn", "flags_cnn", "cube_cnn")
        :param model_results: dict with the fitted model
            (keys used: "settings_fit", "regressor", "col_names_ipt",
            "col_names_fit", "cat_fit")
        :param filepath_sysmaps: optional path to a systematics-maps file
            whose datasets are copied into the model file
        :param filepath_cat_out: optional path for the output catalog with
            PSF predictions; if None, no catalog is written
        """
        LOGGER.info(f"Saving PSF model to: {filepath_out}")

        # Create the main HDF5 model file
        self._save_hdf5_model(filepath_out, filepath_sysmaps, model_results)

        # Generate and save predictions
        predictions = self._generate_predictions(
            filepath_out, processed_data, model_results
        )

        # Update HDF5 file with predictions
        self._save_predictions_to_hdf5(filepath_out, predictions)

        if filepath_cat_out is not None:
            self._save_catalog_output(
                filepath_cat_out,
                processed_data,
                cnn_results,
                model_results,
                predictions,
            )

        # Save star cube if requested; _save_star_cube itself skips safely
        # when no output catalog path was given.
        if self.config.get("save_star_cube", False):
            self._save_star_cube(filepath_cat_out, cnn_results)

        LOGGER.info("PSF model saving complete")

    def _save_hdf5_model(self, filepath_out, filepath_sysmaps, model_results):
        """
        Save the core PSF model (settings + polynomial regressor) to HDF5.

        :param filepath_out: path of the HDF5 file to create (overwritten)
        :param filepath_sysmaps: optional systematics-maps file to copy from
        :param model_results: dict providing "settings_fit", "regressor"
            and "col_names_ipt"
        """
        LOGGER.debug("Creating HDF5 PSF model file")

        with h5py.File(filepath_out, mode="w") as fh5_out:
            # Copy systematics maps from input file
            if filepath_sysmaps is not None:
                self._copy_systematics_maps(fh5_out, filepath_sysmaps)

            # Save fitting settings
            settings_fit = model_results["settings_fit"]
            for key in settings_fit:
                at.replace_hdf5_dataset(
                    fobj=fh5_out, name=f"settings/{key}", data=settings_fit[key]
                )

            # Save polynomial model components
            regressor = model_results["regressor"]
            col_names_ipt = model_results["col_names_ipt"]

            at.replace_hdf5_dataset(
                fobj=fh5_out, name="par_names", data=col_names_ipt, **HDF5_COMPRESS
            )
            at.replace_hdf5_dataset(
                fobj=fh5_out,
                name="arr_pointings_polycoeffs",
                data=regressor.arr_pointings_polycoeffs,
                **HDF5_COMPRESS,
            )
            at.replace_hdf5_dataset(
                fobj=fh5_out,
                name="unseen_pointings",
                data=regressor.unseen_pointings,
                **HDF5_COMPRESS,
            )
            # Scalar flag, stored uncompressed.
            at.replace_hdf5_dataset(
                fobj=fh5_out,
                name="set_unseen_to_mean",
                data=regressor.set_unseen_to_mean,
            )

    def _copy_systematics_maps(self, fh5_out, filepath_sys):
        """
        Copy all datasets of the systematics-maps file into the output file.

        Datasets are compressed when possible; datasets that cannot be
        compressed (e.g. non-array groups) are copied as-is.

        :param fh5_out: open, writable h5py.File to copy into
        :param filepath_sys: path of the systematics-maps HDF5 file
        """
        LOGGER.debug("Copying systematics maps")

        with h5py.File(filepath_sys, mode="r") as fh5_sys:
            for key in fh5_sys:
                try:
                    fh5_out.create_dataset(name=key, data=fh5_sys[key], **HDF5_COMPRESS)
                except Exception as err:
                    # Best-effort: fall back to an uncompressed copy rather
                    # than failing the whole model save.
                    LOGGER.debug(
                        f"Failed to compress {key}: {err}, copying uncompressed"
                    )
                    fh5_sys.copy(key, fh5_out)

            # Copy pointing attributes
            if "map_pointings" in fh5_sys:
                fh5_out["map_pointings"].attrs["n_pointings"] = fh5_sys[
                    "map_pointings"
                ].attrs["n_pointings"]

    def _generate_predictions(self, filepath_out, processed_data, model_results):
        """
        Generate PSF predictions for training stars, the full catalog and a
        regular grid; optionally add astrometric spatial derivatives.

        :param filepath_out: path of the already-written HDF5 model file
        :param processed_data: dict with "cat_original"/"cat_gaia" and
            "image_shape"
        :param model_results: dict with "cat_fit" (training stars)
        :return: dict with keys "star_train_psf", "cat_psf", "n_exposures"
            and "grid_psf"
        """
        LOGGER.debug("Generating PSF predictions")

        predictions = {}
        psfmodel_corr_brighter_fatter = self.config.get("psfmodel_corr_brighter_fatter")

        # Predict for training stars
        cat_fit = model_results["cat_fit"]
        predictions["star_train_psf"], _ = self._predict_for_catalog(
            cat_fit, filepath_out, psfmodel_corr_brighter_fatter
        )

        # Predict for full input catalog
        cat_full = processed_data.get("cat_original", processed_data["cat_gaia"])
        predictions["cat_psf"], predictions["n_exposures"] = self._predict_for_catalog(
            cat_full, filepath_out, psfmodel_corr_brighter_fatter
        )

        # Generate grid predictions
        predictions["grid_psf"] = self._generate_grid_predictions(
            filepath_out,
            processed_data["image_shape"],
        )

        # Add astrometric derivatives if requested
        if self.config.get("astrometry_errors", False):
            col_names_astrometry_ipt = [
                f"{col}_ipt" for col in ["astrometry_diff_x", "astrometry_diff_y"]
            ]
            predictions["star_train_psf"] = self._add_model_derivatives(
                predictions["star_train_psf"],
                filepath_out,
                col_names_astrometry_ipt,
                psfmodel_corr_brighter_fatter,
            )
            predictions["cat_psf"] = self._add_model_derivatives(
                predictions["cat_psf"],
                filepath_out,
                col_names_astrometry_ipt,
                psfmodel_corr_brighter_fatter,
            )

        return predictions

    def _predict_for_catalog(self, cat, filepath_out, psfmodel_corr_brighter_fatter):
        """
        Generate PSF predictions for a catalog.

        :param cat: structured array with at least X_IMAGE/Y_IMAGE columns
        :param filepath_out: path of the HDF5 model file to predict from
        :param psfmodel_corr_brighter_fatter: brighter-fatter correction
            settings, or None to disable
        :return: (prediction catalog, number of exposures)
        """
        # Import here to avoid circular dependencies
        from .psf_predictions import predict_psf_for_catalogue_storing

        return predict_psf_for_catalogue_storing(
            cat, filepath_out, psfmodel_corr_brighter_fatter
        )

    def _generate_grid_predictions(self, filepath_out, img_shape, grid_spacing=100):
        """
        Generate PSF predictions on a regular grid.

        Brighter-fatter correction can not be applied here as it requires
        specific magnitudes and ellipticities which are not available for grid
        points.

        :param filepath_out: path of the HDF5 model file to predict from
        :param img_shape: (height, width) of the image in pixels
        :param grid_spacing: spacing of the grid points in pixels
            (default 100, matching previous hard-coded behavior)
        :return: prediction catalog for the grid points
        """
        LOGGER.debug("Generating grid predictions")

        # Create regular grid across image
        x_grid, y_grid = np.meshgrid(
            np.arange(0, img_shape[1], grid_spacing),
            np.arange(0, img_shape[0], grid_spacing),
        )
        # +0.5 shifts to pixel centers
        x_grid = x_grid.ravel() + 0.5
        y_grid = y_grid.ravel() + 0.5

        # Create catalog for grid points
        grid_cat = np.empty(len(x_grid), dtype=at.get_dtype(["X_IMAGE", "Y_IMAGE"]))
        grid_cat["X_IMAGE"] = x_grid
        grid_cat["Y_IMAGE"] = y_grid

        # Generate predictions
        grid_psf, _ = self._predict_for_catalog(
            grid_cat, filepath_out, psfmodel_corr_brighter_fatter=None
        )

        return grid_psf

    def _add_model_derivatives(
        self, cat, filepath_out, cols, psfmodel_corr_brighter_fatter, delta=1e-2
    ):
        """
        Add spatial derivative predictions for the given columns.

        :param cat: prediction catalog to extend
        :param filepath_out: path of the HDF5 model file to predict from
        :param cols: column names for which to compute derivatives
        :param psfmodel_corr_brighter_fatter: brighter-fatter correction
            settings, or None to disable
        :param delta: finite-difference step in pixels
        :return: catalog extended with derivative columns
        """
        # Import here to avoid circular dependencies
        from .psf_predictions import get_model_derivatives

        return get_model_derivatives(
            cat, filepath_out, cols, psfmodel_corr_brighter_fatter, delta
        )

    def _save_predictions_to_hdf5(self, filepath_out, predictions):
        """
        Save predictions to the (already existing) HDF5 model file.

        :param filepath_out: path of the HDF5 model file (opened r+)
        :param predictions: dict with "cat_psf", "grid_psf" and optionally
            "star_train_psf"
        """
        LOGGER.debug("Saving predictions to HDF5 file")

        with h5py.File(filepath_out, mode="r+") as fh5_out:
            # Save full catalog predictions
            at.replace_hdf5_dataset(
                fobj=fh5_out,
                name="predictions",
                data=at.set_storing_dtypes(predictions["cat_psf"]),
                **HDF5_COMPRESS,
            )

            # Save training star data and predictions
            if "star_train_psf" in predictions:
                data = predictions["star_train_psf"]

                at.replace_hdf5_dataset(
                    fobj=fh5_out,
                    name="star_train_prediction",
                    data=at.set_storing_dtypes(data),
                    **HDF5_COMPRESS,
                )

            # Save grid predictions
            at.replace_hdf5_dataset(
                fobj=fh5_out,
                name="grid_psf",
                data=at.set_storing_dtypes(predictions["grid_psf"]),
                **HDF5_COMPRESS,
            )

    def _save_catalog_output(
        self, filepath_cat_out, processed_data, cnn_results, model_results, predictions
    ):
        """
        Save catalog output with PSF predictions.

        :param filepath_cat_out: path of the output catalog file
        :param processed_data: dict with the original catalog and flags
        :param cnn_results: dict with CNN catalog and flags
        :param model_results: dict with fitted/interpolated column names
        :param predictions: dict as produced by ``_generate_predictions``
        """
        LOGGER.debug(f"Saving catalog output: {filepath_cat_out}")

        # Get original catalog
        cat = processed_data.get("cat_original", processed_data["cat_gaia"])

        # Define output columns
        col_names_cnn_fit = [
            col for col in model_results["col_names_fit"] if col.endswith("_cnn")
        ]
        col_names_ipt = model_results["col_names_ipt"]

        cols_new = (
            col_names_cnn_fit
            + col_names_ipt
            + [
                "FLAGS_STARS:i4",
                "N_EXPOSURES:i2",
                "gaia_ra_match",
                "gaia_dec_match",
                "gaia_id_match:i8",
            ]
        )

        # Create output catalog
        cat_save_psf = at.ensure_cols(cat, names=cols_new)

        # Fill with predictions and metadata
        self._fill_catalog_output(
            cat_save_psf, processed_data, cnn_results, model_results, predictions
        )

        # Convert precision if requested
        precision = self.config.get("precision", float)
        if precision == np.float32:
            cat_save_psf = at.rec_float64_to_float32(cat_save_psf)
            LOGGER.info("Converted catalog to float32 precision")

        # Save catalog
        at.save_hdf_cols(filepath_cat_out, cat_save_psf)
        LOGGER.info(f"Catalog saved with {len(cat_save_psf)} sources")

    def _fill_catalog_output(
        self, cat_save_psf, processed_data, cnn_results, model_results, predictions
    ):
        """
        Fill catalog output with predictions and metadata (in place).

        :param cat_save_psf: structured output catalog to fill
        :param processed_data: dict with GAIA catalog and selection flags
        :param cnn_results: dict with CNN catalog and flags
        :param model_results: dict with fitted/interpolated column names
        :param predictions: dict with "cat_psf" and "n_exposures"
        """
        # Get masks and data
        cat_gaia = processed_data.get("cat_gaia", None)
        flags_gaia = processed_data.get("flags_gaia", None)
        flags_all = processed_data.get("flags_all", None)
        cat_cnn = cnn_results.get("cat_cnn", None)
        flags_cnn = cnn_results.get("flags_cnn", None)
        select_gaia = flags_gaia == 0
        select_cnn = flags_cnn == 0
        select_all = flags_all == 0

        # Fill interpolated PSF parameters
        col_names_ipt = model_results["col_names_ipt"]
        for col_name in col_names_ipt:
            if col_name in predictions["cat_psf"].dtype.names:
                cat_save_psf[col_name] = predictions["cat_psf"][col_name]

        # Fill CNN predicted parameters for stars used in fitting
        if cat_cnn is not None and flags_cnn is not None:
            col_names_cnn_fit = [
                col for col in model_results["col_names_fit"] if col.endswith("_cnn")
            ]
            for col_name in col_names_cnn_fit:
                if col_name in cat_cnn.dtype.names:
                    # Initialize with ERR_VAL
                    cat_save_psf[col_name] = ERR_VAL

                    # Fill values for stars that were used in CNN
                    if flags_gaia is not None:
                        # Map from CNN indices to full catalog indices
                        idx_full = np.where(select_gaia)[0][select_cnn]
                        cat_save_psf[col_name][idx_full] = cat_cnn[col_name][select_cnn]

        # Fill GAIA match information
        if cat_gaia is not None and flags_gaia is not None:
            cat_save_psf["gaia_ra_match"][select_all] = cat_gaia["gaia_ra_match"]
            cat_save_psf["gaia_dec_match"][select_all] = cat_gaia["gaia_dec_match"]
            cat_save_psf["gaia_id_match"][select_all] = cat_gaia["gaia_id_match"]
        else:
            cat_save_psf["FLAGS_STARS"] = 0

        # Fill exposure information
        cat_save_psf["N_EXPOSURES"] = predictions.get("n_exposures", 1)

    def _save_star_cube(self, filepath_cat_out, cnn_results):
        """
        Save star stamp cube for detailed analysis.

        :param filepath_cat_out: output catalog path the cube is written
            alongside; if None, the cube cannot be written and is skipped
        :param cnn_results: dict expected to contain "cube_cnn" and "cat_cnn"
        """
        LOGGER.debug("Saving star cube")

        # Bugfix: "save_star_cube" can be enabled while no output catalog
        # path was provided (the two settings are independent in
        # save_psf_model); writing with a None path would fail, so skip.
        if filepath_cat_out is None:
            LOGGER.warning("No output catalog path given, star cube not saved")
            return

        # Check if we have the necessary data
        if "cube_cnn" not in cnn_results:
            LOGGER.warning("No cube data to save")
            return

        cube = cnn_results["cube_cnn"]
        cat_cnn = cnn_results["cat_cnn"]

        psf_utils.write_star_cube(
            star_cube=cube, cat=cat_cnn, filepath_cat_out=filepath_cat_out
        )