Coverage for src/ufig/psf_estimation/polynomial_fitting.py: 96%
134 statements
« prev ^ index » next coverage.py v7.10.2, created at 2025-08-07 15:17 +0000
« prev ^ index » next coverage.py v7.10.2, created at 2025-08-07 15:17 +0000
1# Copyright (C) 2025 ETH Zurich
2# Institute for Particle Physics and Astrophysics
3# Author: Silvan Fischbacher
4# created: Mon Jul 21 2025
7import numpy as np
8from cosmic_toolbox import logger
10from ufig.array_util import set_flag_bit
12from . import psf_utils, star_sample_selection_cnn
13from .tiled_regressor import TiledRobustPolynomialRegressor as Regressor
# Module-level logger shared by all classes/functions in this file
LOGGER = logger.get_logger(__file__)
class PolynomialPSFModel:
    """
    Handles polynomial interpolation fitting for PSF spatial variation.

    This class manages:
    - Polynomial model configuration and fitting
    - Iterative outlier removal during fitting
    - Cross-validation for regularization parameter selection
    - Model validation and quality assessment
    """

    def __init__(self, config):
        """
        Initialize the polynomial PSF model with configuration parameters.

        :param config: Configuration dictionary with parameters for polynomial
            fitting (e.g. "image_shape", "poly_order", "psfmodel_ridge_alpha",
            "n_max_refit", "n_sigma_clip", "fraction_validation_stars").
        """
        self.config = config

    def fit_model(self, cat, position_weights):
        """
        Fit polynomial interpolation model to CNN predictions.

        :param cat: Structured catalog of stars with CNN predictions
            (columns ending in "_cnn") and SExtractor moment columns.
        :param position_weights: Position weights for the stars, one row per
            star in ``cat``.
        :return: Dictionary with keys "regressor" (fitted regressor),
            "cat_fit" (stars used for fitting), "cat_clip" (after outlier
            removal), "flags_fit" (per-star flag bits), "col_names_fit",
            "col_names_ipt" (fitted/interpolated column names) and
            "settings_fit" (fitting settings).
        """
        LOGGER.info("Starting polynomial model fitting")

        # Select stars for fitting and prepare parameters
        cat_fit, flags_fit, position_weights_fit = self._select_fitting_stars(
            cat, position_weights
        )

        # Prepare parameter columns and settings
        col_names_fit, col_names_ipt, settings_fit = self._prepare_fitting_config(
            cat_fit
        )

        # Fit model with iterative outlier removal
        regressor, cat_clip = self._fit_with_outlier_removal(
            cat_fit, col_names_fit, position_weights_fit, settings_fit
        )

        # Handle validation stars if requested and refit if needed
        if self.config.get("fraction_validation_stars", 0) > 0:
            regressor = self._handle_validation_stars(
                cat_fit,
                cat_clip,
                flags_fit,
                col_names_fit,
                position_weights_fit,
                settings_fit,
            )

        LOGGER.info("Polynomial model fitting complete")

        return {
            "regressor": regressor,
            "cat_fit": cat_fit,
            "cat_clip": cat_clip,
            "flags_fit": flags_fit,
            "col_names_fit": col_names_fit,
            "col_names_ipt": col_names_ipt,
            "settings_fit": settings_fit,
        }

    def _select_fitting_stars(self, cat_cnn, position_weights):
        """
        Select stars for polynomial fitting based on CNN predictions.

        Stars with non-finite values in any floating-point "_cnn" column are
        removed; the position weights are filtered consistently.

        :param cat_cnn: Catalog of stars with CNN predictions.
        :param position_weights: Position weights for the stars.
        :return: Tuple of (selected catalog, flags, position weights).
        :raises ValueError: if no star has finite CNN predictions.
        """
        # Start with all stars that passed CNN quality cuts
        select_fit = np.ones(len(cat_cnn), dtype=bool)

        # Filter out stars with NaN values in CNN predictions
        cnn_columns = [col for col in cat_cnn.dtype.names if col.endswith("_cnn")]
        for col in cnn_columns:
            # isfinite is only defined for floating dtypes
            if np.issubdtype(cat_cnn[col].dtype, np.floating):
                select_fit &= np.isfinite(cat_cnn[col])

        n_valid = np.sum(select_fit)
        n_total = len(cat_cnn)
        LOGGER.info(
            f"CNN NaN filtering: {n_valid}/{n_total} stars have finite CNN predictions"
        )

        if n_valid == 0:
            raise ValueError(
                "No stars with finite CNN predictions available for polynomial fitting"
            )

        cat_fit = cat_cnn[select_fit]
        flags_fit = np.zeros(len(cat_fit), dtype=np.int32)
        position_weights_fit = position_weights[select_fit]

        LOGGER.info(f"Selected {len(cat_fit)} stars for polynomial fitting")

        return cat_fit, flags_fit, position_weights_fit

    def _prepare_fitting_config(self, cat_fit):
        """
        Prepare parameter columns and fitting configuration.

        :param cat_fit: Catalog of stars selected for fitting.
        :return: Tuple of (parameter names, interpolation names, fitting
            settings).
        """
        # Define parameter columns for fitting: all CNN predictions except
        # positional ones ("x_*"/"y_*"), plus SExtractor moment columns
        col_names_cnn_fit = [
            col
            for col in cat_fit.dtype.names
            if col.endswith("_cnn")
            and not col.startswith("x_")
            and not col.startswith("y_")
        ]
        col_names_mom_fit = ["se_mom_fwhm", "se_mom_win", "se_mom_e1", "se_mom_e2"]

        col_names_fit = col_names_cnn_fit + col_names_mom_fit

        # Convert to interpolation column names
        col_names_cnn_ipt = self._colnames_cnn_fit_to_ipt(col_names_cnn_fit)
        col_names_mom_ipt = self._colnames_mom_fit_to_ipt(col_names_mom_fit)
        col_names_ipt = col_names_cnn_ipt + col_names_mom_ipt

        # Add astrometry columns if requested
        if self.config.get("astrometry_errors", False):
            col_names_astrometry_fit = ["astrometry_diff_x", "astrometry_diff_y"]
            col_names_astrometry_ipt = [
                f"{col}_ipt" for col in col_names_astrometry_fit
            ]
            col_names_fit += col_names_astrometry_fit
            col_names_ipt += col_names_astrometry_ipt

            # Add derivative columns (spatial derivatives of the astrometric
            # residual fields along x and y)
            for ax in ["x", "y"]:
                col_names_ipt += [f"{col}_dd{ax}" for col in col_names_astrometry_ipt]

        # Setup regularization parameters; a dict maps parameter name -> alpha
        ridge_alpha = self.config.get("psfmodel_ridge_alpha", 1e-6)
        if isinstance(ridge_alpha, dict):
            ridge_alpha = [ridge_alpha[p] for p in col_names_fit]
            LOGGER.info(f"Using parameter-specific ridge alpha: {ridge_alpha}")

        # Prepare fitting settings; scale_pos/scale_par hold (offset, scale)
        # pairs used to normalize positions and parameters before fitting
        settings_fit = {
            "n_max_refit": self.config.get("n_max_refit", 10),
            "poly_order": self.config.get("poly_order", 5),
            "ridge_alpha": ridge_alpha,
            "polynomial_type": self.config.get("polynomial_type", "chebyshev"),
            "n_sigma_clip": self.config.get("n_sigma_clip", 3),
            "scale_pos": np.array(
                [
                    [
                        self.config["image_shape"][1] / 2,
                        self.config["image_shape"][1] / 2,
                    ],
                    [
                        self.config["image_shape"][0] / 2,
                        self.config["image_shape"][0] / 2,
                    ],
                ]
            ),
            "scale_par": np.array(
                [
                    [np.mean(cat_fit[p]), np.std(cat_fit[p], ddof=1)]
                    for p in col_names_fit
                ]
            ),
            "raise_underdetermined": self.config.get(
                "psfmodel_raise_underdetermined_error", True
            ),
        }
        # make sure that scale par has non-zero std and set them to 1
        # (avoids division by zero in the forward transform)
        if np.any(settings_fit["scale_par"][:, 1] == 0):
            LOGGER.warning("Some parameters have zero standard deviation, setting to 1")
            settings_fit["scale_par"][:, 1] = np.where(
                settings_fit["scale_par"][:, 1] == 0, 1, settings_fit["scale_par"][:, 1]
            )

        return col_names_fit, col_names_ipt, settings_fit

    def _fit_with_outlier_removal(
        self, cat_fit, col_names_fit, position_weights_fit, settings_fit
    ):
        """
        Fit polynomial model with iterative outlier removal.

        :param cat_fit: Catalog of stars selected for fitting.
        :param col_names_fit: Names of parameters to fit.
        :param position_weights_fit: Position weights for the stars.
        :param settings_fit: Fitting settings including polynomial order and
            ridge alpha.
        :return: Tuple of (fitted regressor, outlier-clipped catalog).
        :raises ValueError: if ``n_max_refit`` is smaller than 1, since no
            model would be fitted at all.
        """
        LOGGER.info("Fitting polynomial model with outlier removal")

        # Guard: with zero iterations, `regressor` would be unbound below
        if settings_fit["n_max_refit"] < 1:
            raise ValueError("n_max_refit must be at least 1")

        cat_select = cat_fit.copy()
        position_weights_select = position_weights_fit.copy()
        n_stars = len(cat_select)

        # Get columns to use for outlier detection
        list_cols_use_outlier = self._get_outlier_removal_columns(col_names_fit)

        # Iterative fitting with outlier removal
        for i_fit in range(settings_fit["n_max_refit"]):
            LOGGER.debug(f"Fitting iteration {i_fit + 1}, n_stars: {n_stars}")

            # Fit model
            regressor = self._fit_single_model(
                cat_select,
                col_names_fit,
                position_weights_select,
                settings_fit,
            )

            # Check for outliers
            select_keep = self._identify_outliers(
                regressor,
                cat_select,
                col_names_fit,
                position_weights_select,
                settings_fit,
                list_cols_use_outlier,
            )

            LOGGER.debug(
                f"Outlier removal: keeping {np.sum(select_keep)}/"
                f"{len(select_keep)} stars"
            )

            # Update selection
            cat_select = cat_select[select_keep]
            position_weights_select = position_weights_select[select_keep]

            # Check for convergence (no star was removed this iteration)
            if n_stars == len(cat_select):
                LOGGER.info("Outlier removal converged")
                break

            n_stars = len(cat_select)

        return regressor, cat_select

    def _fit_single_model(self, cat, col_names_fit, position_weights, settings_fit):
        """
        Fit a single polynomial model.

        :param cat: Catalog of stars to fit (needs "X_IMAGE"/"Y_IMAGE").
        :param col_names_fit: Names of parameters to fit.
        :param position_weights: Position weights for the stars.
        :param settings_fit: Fitting settings (normalization, polynomial
            order, ridge alpha, polynomial type).
        :return: Fitted regressor.
        """
        # Shift SExtractor 1-based pixel centers to 0-based pixel coordinates
        position_xy = np.stack((cat["X_IMAGE"] - 0.5, cat["Y_IMAGE"] - 0.5), axis=-1)
        position_par = np.stack([cat[p] for p in col_names_fit], axis=-1)

        position_xy_transformed = psf_utils.transform_forward(
            position_xy,
            settings_fit["scale_pos"],
        )
        position_xy_transformed_weights = np.concatenate(
            (position_xy_transformed, position_weights), axis=1
        )
        position_par_transformed = psf_utils.transform_forward(
            position_par, settings_fit["scale_par"]
        )

        # Create and fit regressor
        regressor = Regressor(
            poly_order=settings_fit["poly_order"],
            ridge_alpha=settings_fit["ridge_alpha"],
            polynomial_type=settings_fit["polynomial_type"],
            set_unseen_to_mean=True,
        )

        regressor.fit(position_xy_transformed_weights, position_par_transformed)
        return regressor

    def _identify_outliers(
        self,
        regressor,
        cat,
        col_names_fit,
        position_weights,
        settings_fit,
        list_cols_use_outlier,
    ):
        """
        Identify outliers based on model residuals.

        :param regressor: Fitted polynomial regressor.
        :param cat: Catalog of stars to check.
        :param col_names_fit: Names of parameters used in fitting.
        :param position_weights: Position weights for the stars.
        :param settings_fit: Fitting settings including scale parameters.
        :param list_cols_use_outlier: Column indices used for outlier
            detection.
        :return: Boolean mask of stars to keep.
        """
        # Get predictions (same coordinate convention as _fit_single_model)
        position_xy = np.stack((cat["X_IMAGE"] - 0.5, cat["Y_IMAGE"] - 0.5), axis=-1)
        position_par = np.stack([cat[p] for p in col_names_fit], axis=-1)

        position_xy_transformed = psf_utils.transform_forward(
            position_xy, settings_fit["scale_pos"]
        )
        position_xy_transformed_weights = np.concatenate(
            (position_xy_transformed, position_weights), axis=1
        )
        position_par_transformed = psf_utils.transform_forward(
            position_par, settings_fit["scale_par"]
        )

        # Get model predictions
        position_par_pred = regressor.predict(position_xy_transformed_weights)

        # Apply outlier removal (sigma clipping on residuals)
        select_keep = star_sample_selection_cnn.remove_outliers(
            x=position_par_transformed,
            y=position_par_pred,
            n_sigma=settings_fit["n_sigma_clip"],
            list_cols_use=list_cols_use_outlier,
        )

        return select_keep

    def _handle_validation_stars(
        self,
        cat_fit,
        cat_clip,
        flags_fit,
        col_names_fit,
        position_weights_fit,
        settings_fit,
    ):
        """
        Handle validation star selection and final fitting.

        Selects a fraction of the outlier-cleaned stars as validation stars,
        flags them in ``flags_fit`` (mutated in place), refits the model on
        the remaining stars and logs validation metrics.

        :param cat_fit: Full fitting catalog.
        :param cat_clip: Outlier-cleaned catalog (row subset of ``cat_fit``).
        :param flags_fit: Per-star flag bits for ``cat_fit``; modified here.
        :param col_names_fit: Names of parameters to fit.
        :param position_weights_fit: Position weights aligned with ``cat_fit``.
        :param settings_fit: Fitting settings.
        :return: Regressor fitted on the non-validation stars.
        """
        fraction_validation = self.config.get("fraction_validation_stars", 0)

        # Select validation stars from the outlier-cleaned catalog
        indices_validation = psf_utils.select_validation_stars(
            len(cat_clip), fraction_validation
        )
        cat_validation = cat_clip[indices_validation]

        # Flag validation stars in the full fitting catalog.
        # NOTE(review): assumes "NUMBER" uniquely identifies a star and that
        # cat_clip preserves cat_fit row order, so the mask below aligns with
        # cat_validation — confirm against the catalog builder.
        # np.isin replaces the deprecated np.in1d (NumPy >= 1.25).
        select_val_outer = np.isin(cat_fit["NUMBER"], cat_validation["NUMBER"])
        set_flag_bit(
            flags=flags_fit,
            select=select_val_outer,
            field=star_sample_selection_cnn.FLAGBIT_VALIDATION_STAR,
        )

        # Final fit excluding validation stars
        select_final = flags_fit == 0
        LOGGER.info(
            f"Final fit: {len(cat_validation)} validation stars reserved, "
            f"{np.sum(select_final)} stars for training"
        )

        # Fit on non-validation stars
        regressor = self._fit_single_model(
            cat_fit[select_final],
            col_names_fit,
            position_weights_fit[select_final],
            settings_fit,
        )

        # Test on validation stars
        validation_metrics = self._test_validation_stars(
            regressor,
            cat_validation,
            col_names_fit,
            position_weights_fit[select_val_outer],
            settings_fit,
        )
        LOGGER.info(f"Validation test completed: {validation_metrics}")

        return regressor

    def _test_validation_stars(
        self,
        regressor,
        cat_validation,
        col_names_fit,
        position_weights_validation,
        settings_fit,
    ):
        """
        Test the fitted model on validation stars and compute metrics.

        :param regressor: Fitted polynomial regressor.
        :param cat_validation: Catalog of validation stars.
        :param col_names_fit: Names of parameters used in fitting.
        :param position_weights_validation: Position weights for validation stars.
        :param settings_fit: Fitting settings including scale parameters.
        :return: Dictionary with validation metrics (overall and per-parameter
            MAE, RMSE and bias, computed in normalized parameter space).
        """
        if len(cat_validation) == 0:
            LOGGER.warning("No validation stars available for testing")
            return {"n_validation": 0}

        # Transform validation data
        position_xy = np.stack(
            (cat_validation["X_IMAGE"] - 0.5, cat_validation["Y_IMAGE"] - 0.5), axis=-1
        )
        position_par_true = np.stack(
            [cat_validation[p] for p in col_names_fit], axis=-1
        )

        position_xy_transformed = psf_utils.transform_forward(
            position_xy, settings_fit["scale_pos"]
        )
        position_xy_transformed_weights = np.concatenate(
            (position_xy_transformed, position_weights_validation), axis=1
        )
        position_par_true_transformed = psf_utils.transform_forward(
            position_par_true, settings_fit["scale_par"]
        )

        # Get model predictions
        position_par_pred_transformed = regressor.predict(
            position_xy_transformed_weights
        )

        # Calculate residuals in transformed space
        residuals = position_par_pred_transformed - position_par_true_transformed

        # Compute validation metrics
        metrics = {
            "n_validation": len(cat_validation),
            "mae_mean": np.mean(np.abs(residuals)),  # Mean absolute error
            "rmse_mean": np.sqrt(np.mean(residuals**2)),  # Root mean square error
            "bias_mean": np.mean(residuals),  # Systematic bias
        }

        # Per-parameter metrics
        for i, param in enumerate(col_names_fit):
            if i < residuals.shape[1]:
                metrics[f"mae_{param}"] = np.mean(np.abs(residuals[:, i]))
                metrics[f"rmse_{param}"] = np.sqrt(np.mean(residuals[:, i] ** 2))
                metrics[f"bias_{param}"] = np.mean(residuals[:, i])

        return metrics

    def _get_outlier_removal_columns(self, par_fit):
        """
        Get column indices to use for outlier removal.

        Only FWHM and ellipticity (e1/e2) parameters are used for clipping.

        :param par_fit: Names of parameters used in fitting.
        :return: List of column indices.
        """
        list_cols = []
        for ip, par in enumerate(par_fit):
            if ("fwhm" in par) or ("e1" in par) or ("e2" in par):
                list_cols.append(ip)
        LOGGER.info(f"Using columns {list_cols} for outlier removal")
        return list_cols

    def _colnames_cnn_fit_to_ipt(self, col_names_cnn_fit):
        """Convert CNN fit column names ("*_cnn") to interpolation names ("*_ipt")."""
        return [p[:-4] + "_ipt" for p in col_names_cnn_fit]

    def _colnames_mom_fit_to_ipt(self, col_names_mom_fit):
        """Convert moment fit column names ("se_mom_*") to interpolation names."""
        return [p.replace("se_mom", "psf_mom") + "_ipt" for p in col_names_mom_fit]