Coverage for src/ufig/psf_estimation/psf_predictions.py: 92%
77 statements
« prev ^ index » next coverage.py v7.10.2, created at 2025-08-07 15:17 +0000
« prev ^ index » next coverage.py v7.10.2, created at 2025-08-07 15:17 +0000
1# Copyright (C) 2025 ETH Zurich
2# Institute for Particle Physics and Astrophysics
3# Author: Silvan Fischbacher
4# created: Mon Jul 28 2025
6import h5py
7import numpy as np
8from cosmic_toolbox import arraytools as at
9from cosmic_toolbox import logger
11from . import correct_brighter_fatter, psf_utils
12from .tiled_regressor import TiledRobustPolynomialRegressor as Regressor
# Value written into every PSF parameter at positions that the PSF model does
# not cover (see predict_psf_for_catalogue).
ERR_VAL = 999.0

LOGGER = logger.get_logger(__file__)
def colnames_derivative(cols, ax):
    """
    Build the names of the derivative columns of ``cols`` along one axis.

    Every name is the original column name followed by ``_dd`` and the axis
    label, e.g. ``colnames_derivative(["psf_fwhm"], "x")`` gives
    ``["psf_fwhm_ddx"]``.

    :param cols: Iterable of column names to derive from
    :param ax: Axis label for the derivative ('x' or 'y')
    :return: List of derivative column names, in the same order as ``cols``
    """
    suffix = f"_dd{ax}"
    return [f"{col}{suffix}" for col in cols]
def get_model_derivatives(
    cat, filepath_psfmodel, cols, psfmodel_corr_brighter_fatter, delta=1e-2
):
    """
    Calculate spatial derivatives of PSF model parameters by central differences.

    For each axis ('x' then 'y') the PSF model is evaluated twice, with the
    catalogue position column (``X_IMAGE`` / ``Y_IMAGE``) shifted by
    ``-delta/2`` and ``+delta/2``. The difference of the two predictions is
    divided by ``delta``. The results are stored in new columns named
    ``<col>_ddx`` / ``<col>_ddy``.

    :param cat: Input catalog with X_IMAGE, Y_IMAGE columns
    :param filepath_psfmodel: Path to the PSF model file
    :param cols: List of column names (predicted PSF parameters) to differentiate
    :param psfmodel_corr_brighter_fatter: Parameters for brighter-fatter correction
    :param delta: Step size for finite difference calculation
    :return: Catalog with additional columns for derivatives
    """
    half_step = delta / 2.0

    def _predict_shifted(base, column, shift):
        # Evaluate the PSF model on a copy of ``base`` with ``column`` offset by ``shift``.
        shifted = base.copy()
        shifted[column] += shift
        return predict_psf_for_catalogue_storing(
            shifted, filepath_psfmodel, psfmodel_corr_brighter_fatter
        )[0]

    for axis in ("x", "y"):
        position_col = f"{axis.upper()}_IMAGE"
        derivative_names = colnames_derivative(cols, axis)

        pred_minus = _predict_shifted(cat, position_col, -half_step)
        pred_plus = _predict_shifted(cat, position_col, half_step)

        cat = at.add_cols(cat, names=derivative_names)
        for source_col, target_col in zip(cols, derivative_names):
            cat[target_col] = (pred_plus[source_col] - pred_minus[source_col]) / delta

    return cat
def predict_psf(position_xy, position_weights, regressor, settings, n_per_chunk=1000):
    """
    Evaluate a fitted PSF regressor at a set of positions.

    Steps:

    1. Rescale the positions with ``settings["scale_pos"]``.
    2. Append the per-position weights as extra columns and pass the result
       to ``regressor.predict``.
    3. Map the output back with ``settings["scale_par"]``.

    Rows whose weights sum to zero, or whose prediction contains a non-finite
    value, are flagged and their parameters set to 0.

    :param position_xy: Array of (x,y) positions, shape (n, 2)
    :param position_weights: Weights for each position, shape (n, m)
    :param regressor: Fitted regressor object providing ``predict(X, batch_size=...)``
    :param settings: Dictionary with model settings, including ``scale_pos`` and
        ``scale_par``
    :param n_per_chunk: Batch size passed to ``regressor.predict``
    :return: Tuple of predicted parameters and a mask for positions with no coverage
    """
    scaled_xy = psf_utils.transform_forward(position_xy, scale=settings["scale_pos"])
    regressor_input = np.concatenate([scaled_xy, position_weights], axis=1)
    scaled_par = regressor.predict(regressor_input, batch_size=n_per_chunk)
    par = psf_utils.transform_inverse(scaled_par, settings["scale_par"])

    # A row is treated as uncovered if its weights sum to zero or if the
    # regressor produced NaN/inf for any of its parameters.
    zero_weight = position_weights.sum(axis=1) == 0
    non_finite = (~np.isfinite(par)).any(axis=1)
    no_coverage = zero_weight | non_finite

    par[no_coverage] = 0
    return par, no_coverage
def predict_psf_with_file(position_xy, filepath_psfmodel, id_pointing="all"):
    """
    Predict PSF parameters at given positions using a saved PSF model file.

    This is a generator: nothing runs until the first ``next()`` call. That
    includes input validation and file access. It currently yields exactly
    once.

    :param position_xy: Array of (x,y) positions, shape (n, 2)
    :param filepath_psfmodel: Path to the PSF model file (HDF5)
    :param id_pointing: ID of the pointing to use; only 'all' is supported
    :return: Generator yielding a tuple of three items:
        - the predicted parameters, as a record array with one field per
          entry of the file's ``par_names``;
        - a boolean mask of positions without coverage;
        - the number of exposures per position.
    :raises ValueError: if position_xy is not a 2D array with two columns
    :raises NotImplementedError: if id_pointing is not 'all'
    """
    # Check ndim first: indexing shape[1] on a 1-D array would raise IndexError
    # instead of the intended ValueError.
    if position_xy.ndim != 2 or position_xy.shape[1] != 2:
        raise ValueError(
            f"Invalid position_xy shape (should be n_obj x 2) {position_xy.shape}"
        )

    # Reject unsupported pointings before opening and reading the model file.
    if id_pointing != "all":
        raise NotImplementedError(
            "This feature is not yet implemented due to the polynomial coefficients"
            " covariances which are a tiny bit tricky but definitely doable"
        )

    # Setup interpolator: read everything needed from the file, then close it.
    with h5py.File(filepath_psfmodel, "r") as fh5:
        par_names = at.set_loading_dtypes(fh5["par_names"][...])
        pointings_maps = fh5["map_pointings"]
        # pointings_maps is an h5py object, so the weights must be computed
        # while the file is still open.
        position_weights = psf_utils.get_position_weights(
            position_xy[:, 0], position_xy[:, 1], pointings_maps
        )
        poly_coeffs = fh5["arr_pointings_polycoeffs"][...]
        unseen_pointings = fh5["unseen_pointings"][...]
        settings = {
            key: at.set_loading_dtypes(fh5["settings"][key][...])
            for key in fh5["settings"]
        }

    # Files that do not store a polynomial type get "chebyshev".
    settings.setdefault("polynomial_type", "chebyshev")
    LOGGER.debug(f"polynomial_type={settings['polynomial_type']}")

    # Add debugging information
    if "scale_par" in settings:
        LOGGER.debug(f"Loaded scale_par shape: {settings['scale_par'].shape}")
        LOGGER.debug(f"First few scale_par values: {settings['scale_par'][:3]}")

    regressor = Regressor(
        poly_order=settings["poly_order"],
        ridge_alpha=settings["ridge_alpha"],
        polynomial_type=settings["polynomial_type"],
        poly_coefficients=poly_coeffs,
        unseen_pointings=unseen_pointings,
    )

    LOGGER.info(
        f"prediction for cnn models n_pos={position_xy.shape[0]} id_pointing=all"
    )

    position_par_post, select_no_coverage = predict_psf(
        position_xy, position_weights, regressor, settings
    )

    # np.rec is the public alias for the record-array helpers; np.core is
    # deprecated since NumPy 2.0.
    position_par_post = np.rec.fromarrays(
        position_par_post.T, names=",".join(par_names)
    )
    psf_utils.postprocess_catalog(position_par_post)
    n_exposures = psf_utils.position_weights_to_nexp(position_weights)
    yield position_par_post, select_no_coverage, n_exposures
def _fill_no_coverage(position_par_post, select_no_coverage):
    """Set every parameter field of the rows flagged in ``select_no_coverage`` to ERR_VAL."""
    for par_name in position_par_post.dtype.names:
        position_par_post[par_name][select_no_coverage] = ERR_VAL


def predict_psf_for_catalogue(
    cat, filepath_psfmodel, id_pointing="all", psfmodel_corr_brighter_fatter=None
):
    """
    Predict PSF parameters for a given catalog.

    Positions are taken as ``(X_IMAGE - 0.5, Y_IMAGE - 0.5)``. Every parameter
    of a position without model coverage is set to ``ERR_VAL``, including
    when the brighter-fatter correction is applied.

    :param cat: Input catalog with X_IMAGE, Y_IMAGE (and MAG_AUTO if applicable) columns
    :param filepath_psfmodel: Path to the PSF model file
    :param id_pointing: ID of the pointing to use, or 'all'
    :param psfmodel_corr_brighter_fatter: Parameters for brighter-fatter correction;
        the correction is applied only if its ``apply_to_galaxies`` entry is truthy
    :return: Generator yielding predicted parameters and number of exposures
    """
    position_xy = np.stack((cat["X_IMAGE"] - 0.5, cat["Y_IMAGE"] - 0.5), axis=-1)

    for position_par_post, select_no_coverage, n_exposures in predict_psf_with_file(
        position_xy, filepath_psfmodel, id_pointing=id_pointing
    ):
        # Set points without coverage to error value
        _fill_no_coverage(position_par_post, select_no_coverage)

        # Apply brighter-fatter correction
        if (
            psfmodel_corr_brighter_fatter is not None
            and "apply_to_galaxies" in psfmodel_corr_brighter_fatter
            and psfmodel_corr_brighter_fatter["apply_to_galaxies"]
        ):
            LOGGER.info("added brighter-fatter to PSF prediction")
            (
                position_par_post["psf_fwhm_ipt"],
                position_par_post["psf_e1_ipt"],
                position_par_post["psf_e2_ipt"],
            ) = correct_brighter_fatter.brighter_fatter_add(
                col_mag=cat["MAG_AUTO"],
                col_fwhm=position_par_post["psf_fwhm_ipt"],
                col_e1=position_par_post["psf_e1_ipt"],
                col_e2=position_par_post["psf_e2_ipt"],
                dict_corr=psfmodel_corr_brighter_fatter,
            )
            # The correction overwrites all rows, including those set to
            # ERR_VAL above. Restore the sentinel so positions without
            # coverage remain identifiable.
            _fill_no_coverage(position_par_post, select_no_coverage)

        yield position_par_post, n_exposures
def predict_psf_for_catalogue_storing(
    cat_in, filepath_psf_model, psfmodel_corr_brighter_fatter
):
    """
    Predict PSF parameters for a catalog and format for storage.

    The output catalogue contains ``X_IMAGE``, ``Y_IMAGE``, ``id`` (only if
    present in the input), and one column per predicted PSF parameter.

    :param cat_in: Input catalog with X_IMAGE, Y_IMAGE (and optionally id) columns
    :param filepath_psf_model: Path to the PSF model file
    :param psfmodel_corr_brighter_fatter: Parameters for brighter-fatter correction
    :return: Tuple of output catalog with predicted parameters and number of exposures
    """
    # Predict using the generator function (it yields a single result)
    position_par_post, n_exposures = next(
        predict_psf_for_catalogue(
            cat=cat_in,
            filepath_psfmodel=filepath_psf_model,
            psfmodel_corr_brighter_fatter=psfmodel_corr_brighter_fatter,
        )
    )

    # Columns copied verbatim from the input. For a simulated catalog we also
    # keep the galaxy id. It must be copied explicitly below; otherwise the
    # column keeps the uninitialized contents from np.empty.
    passthrough_cols = ("X_IMAGE", "Y_IMAGE")
    if "id" in cat_in.dtype.names:
        passthrough_cols += ("id",)

    dtypes = at.get_dtype(passthrough_cols + position_par_post.dtype.names)
    cat_out = np.empty(len(cat_in), dtype=dtypes)

    # Copy predicted parameters
    for par_name in position_par_post.dtype.names:
        cat_out[par_name] = position_par_post[par_name]

    # Copy position information (and id, if present)
    for col in passthrough_cols:
        cat_out[col] = cat_in[col]

    return cat_out, n_exposures