Coverage for src/ufig/psf_estimation/polynomial_fitting.py: 96%

134 statements  

coverage.py v7.10.2, created at 2025-08-07 15:17 +0000

1# Copyright (C) 2025 ETH Zurich 

2# Institute for Particle Physics and Astrophysics 

3# Author: Silvan Fischbacher 

4# created: Mon Jul 21 2025 

5 

6 

7import numpy as np 

8from cosmic_toolbox import logger 

9 

10from ufig.array_util import set_flag_bit 

11 

12from . import psf_utils, star_sample_selection_cnn 

13from .tiled_regressor import TiledRobustPolynomialRegressor as Regressor 

14 

15LOGGER = logger.get_logger(__file__) 

16 

17 

18class PolynomialPSFModel: 

19 """ 

20 Handles polynomial interpolation fitting for PSF spatial variation. 

21 

22 This class manages: 

23 - Polynomial model configuration and fitting 

24 - Iterative outlier removal during fitting 

25 - Cross-validation for regularization parameter selection 

26 - Model validation and quality assessment 

27 """ 

28 

29 def __init__(self, config): 

30 """ 

31 Initialize the polynomial PSF model with configuration parameters. 

32 

33 :param config: Configuration dictionary with parameters for polynomial fitting. 

34 """ 

35 self.config = config 

36 

37 def fit_model(self, cat, position_weights): 

38 """ 

39 Fit the polynomial interpolation model to the CNN predictions and return a dict with the fitted regressor, catalogs, column names, and fitting settings. 

40 

41 """ 

42 LOGGER.info("Starting polynomial model fitting") 

43 

44 # Select stars for fitting and prepare parameters 

45 cat_fit, flags_fit, position_weights_fit = self._select_fitting_stars( 

46 cat, position_weights 

47 ) 

48 

49 # Prepare parameter columns and settings 

50 col_names_fit, col_names_ipt, settings_fit = self._prepare_fitting_config( 

51 cat_fit 

52 ) 

53 

54 # Fit model with iterative outlier removal 

55 regressor, cat_clip = self._fit_with_outlier_removal( 

56 cat_fit, col_names_fit, position_weights_fit, settings_fit 

57 ) 

58 

59 # Handle validation stars if requested and refit if needed 

60 if self.config.get("fraction_validation_stars", 0) > 0: 

61 regressor = self._handle_validation_stars( 

62 cat_fit, 

63 cat_clip, 

64 flags_fit, 

65 col_names_fit, 

66 position_weights_fit, 

67 settings_fit, 

68 ) 

69 

70 LOGGER.info("Polynomial model fitting complete") 

71 

72 return { 

73 "regressor": regressor, 

74 "cat_fit": cat_fit, 

75 "cat_clip": cat_clip, 

76 "flags_fit": flags_fit, 

77 "col_names_fit": col_names_fit, 

78 "col_names_ipt": col_names_ipt, 

79 "settings_fit": settings_fit, 

80 } 

81 

82 def _select_fitting_stars(self, cat_cnn, position_weights): 

83 """ 

84 Select stars for polynomial fitting based on CNN predictions. 

85 

86 :param cat_cnn: Catalog of stars with CNN predictions. 

87 :param position_weights: Position weights for the stars. 

88 :return: Tuple of (selected catalog, flags, position weights). 

89 """ 

90 # Start with all stars that passed CNN quality cuts 

91 select_fit = np.ones(len(cat_cnn), dtype=bool) 

92 

93 # Filter out stars with NaN values in CNN predictions 

94 cnn_columns = [col for col in cat_cnn.dtype.names if col.endswith("_cnn")] 

95 for col in cnn_columns: 

96 if np.issubdtype(cat_cnn[col].dtype, np.floating):  [96 ↛ 95: line 96 didn't jump to line 95 because the condition on line 96 was always true]

97 select_fit &= np.isfinite(cat_cnn[col]) 

98 

99 n_valid = np.sum(select_fit) 

100 n_total = len(cat_cnn) 

101 LOGGER.info( 

102 f"CNN NaN filtering: {n_valid}/{n_total} stars have finite CNN predictions" 

103 ) 

104 

105 if n_valid == 0: 

106 raise ValueError( 

107 "No stars with finite CNN predictions available for polynomial fitting" 

108 ) 

109 

110 cat_fit = cat_cnn[select_fit] 

111 flags_fit = np.zeros(len(cat_fit), dtype=np.int32) 

112 position_weights_fit = position_weights[select_fit] 

113 

114 LOGGER.info(f"Selected {len(cat_fit)} stars for polynomial fitting") 

115 

116 return cat_fit, flags_fit, position_weights_fit 

117 
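The finite-value selection above can be exercised on its own. Below is a small standalone sketch with a hypothetical structured catalog; the column names are illustrative only, not taken from the pipeline.

import numpy as np

# Hypothetical catalog with two CNN-predicted columns and one position column.
cat = np.array(
    [(1.2, 0.01, 10.0), (np.nan, 0.02, 11.0), (1.4, np.nan, 12.0)],
    dtype=[("fwhm_cnn", "f8"), ("e1_cnn", "f8"), ("X_IMAGE", "f8")],
)

# Keep only rows where every floating-point "_cnn" column is finite,
# mirroring the selection in _select_fitting_stars.
select = np.ones(len(cat), dtype=bool)
for col in cat.dtype.names:
    if col.endswith("_cnn") and np.issubdtype(cat[col].dtype, np.floating):
        select &= np.isfinite(cat[col])

print(select)  # [ True False False]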

118 def _prepare_fitting_config(self, cat_fit): 

119 """ 

120 Prepare parameter columns and fitting configuration. 

121 

122 :param cat_fit: Catalog of stars selected for fitting. 

123 

124 :return: Tuple of (parameter names, interpolation names, fitting settings). 

125 """ 

126 # Define parameter columns for fitting 

127 col_names_cnn_fit = [ 

128 col 

129 for col in cat_fit.dtype.names 

130 if col.endswith("_cnn") 

131 and not col.startswith("x_") 

132 and not col.startswith("y_") 

133 ] 

134 col_names_mom_fit = ["se_mom_fwhm", "se_mom_win", "se_mom_e1", "se_mom_e2"] 

135 

136 col_names_fit = col_names_cnn_fit + col_names_mom_fit 

137 

138 # Convert to interpolation column names 

139 col_names_cnn_ipt = self._colnames_cnn_fit_to_ipt(col_names_cnn_fit) 

140 col_names_mom_ipt = self._colnames_mom_fit_to_ipt(col_names_mom_fit) 

141 col_names_ipt = col_names_cnn_ipt + col_names_mom_ipt 

142 

143 # Add astrometry columns if requested 

144 if self.config.get("astrometry_errors", False): 

145 col_names_astrometry_fit = ["astrometry_diff_x", "astrometry_diff_y"] 

146 col_names_astrometry_ipt = [ 

147 f"{col}_ipt" for col in col_names_astrometry_fit 

148 ] 

149 col_names_fit += col_names_astrometry_fit 

150 col_names_ipt += col_names_astrometry_ipt 

151 

152 # Add derivative columns 

153 for ax in ["x", "y"]: 

154 col_names_ipt += [f"{col}_dd{ax}" for col in col_names_astrometry_ipt] 

155 

156 # Setup regularization parameters 

157 ridge_alpha = self.config.get("psfmodel_ridge_alpha", 1e-6) 

158 if isinstance(ridge_alpha, dict):  [158 ↛ 163: line 158 didn't jump to line 163 because the condition on line 158 was always true]

159 ridge_alpha = [ridge_alpha[p] for p in col_names_fit] 

160 LOGGER.info(f"Using parameter-specific ridge alpha: {ridge_alpha}") 

161 

162 # Prepare fitting settings 

163 settings_fit = { 

164 "n_max_refit": self.config.get("n_max_refit", 10), 

165 "poly_order": self.config.get("poly_order", 5), 

166 "ridge_alpha": ridge_alpha, 

167 "polynomial_type": self.config.get("polynomial_type", "chebyshev"), 

168 "n_sigma_clip": self.config.get("n_sigma_clip", 3), 

169 "scale_pos": np.array( 

170 [ 

171 [ 

172 self.config["image_shape"][1] / 2, 

173 self.config["image_shape"][1] / 2, 

174 ], 

175 [ 

176 self.config["image_shape"][0] / 2, 

177 self.config["image_shape"][0] / 2, 

178 ], 

179 ] 

180 ), 

181 "scale_par": np.array( 

182 [ 

183 [np.mean(cat_fit[p]), np.std(cat_fit[p], ddof=1)] 

184 for p in col_names_fit 

185 ] 

186 ), 

187 "raise_underdetermined": self.config.get( 

188 "psfmodel_raise_underdetermined_error", True 

189 ), 

190 } 

191 # make sure that scale par has non-zero std and set them to 1 

192 if np.any(settings_fit["scale_par"][:, 1] == 0): 

193 LOGGER.warning("Some parameters have zero standard deviation, setting to 1") 

194 settings_fit["scale_par"][:, 1] = np.where( 

195 settings_fit["scale_par"][:, 1] == 0, 1, settings_fit["scale_par"][:, 1] 

196 ) 

197 

198 return col_names_fit, col_names_ipt, settings_fit 

199 
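The scale_pos and scale_par arrays built above hold [location, scale] pairs per dimension. The sketch below shows how such pairs could standardize inputs; the (x - loc) / scale form is an assumption about what psf_utils.transform_forward does, not taken from its source.

import numpy as np

def transform_forward_sketch(x, loc_scale):
    # Standardize each column of x with its [loc, scale] pair (assumed behavior,
    # not the actual psf_utils.transform_forward implementation).
    return (x - loc_scale[:, 0]) / loc_scale[:, 1]

# scale_pos as built above for a hypothetical image_shape = (4000, 2000):
# [[W/2, W/2], [H/2, H/2]] acts as [center, half-width] per axis.
scale_pos = np.array([[1000.0, 1000.0], [2000.0, 2000.0]])

xy = np.array([[0.0, 0.0], [1000.0, 2000.0], [2000.0, 4000.0]])
print(transform_forward_sketch(xy, scale_pos))  # maps pixel positions to roughly [-1, 1]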

200 def _fit_with_outlier_removal( 

201 self, cat_fit, col_names_fit, position_weights_fit, settings_fit 

202 ): 

203 """ 

204 Fit polynomial model with iterative outlier removal. 

205 :param cat_fit: Catalog of stars selected for fitting. 

206 :param col_names_fit: Names of parameters to fit. 

207 :param position_weights_fit: Position weights for the stars. 

208 :param settings_fit: Fitting settings including polynomial order and 

209 ridge alpha. 

210 :return: Tuple of (fitted regressor, outlier-clipped catalog). 

211 """ 

212 

213 LOGGER.info("Fitting polynomial model with outlier removal") 

214 

215 cat_select = cat_fit.copy() 

216 position_weights_select = position_weights_fit.copy() 

217 n_stars = len(cat_select) 

218 

219 # Get columns to use for outlier detection 

220 list_cols_use_outlier = self._get_outlier_removal_columns(col_names_fit) 

221 

222 # Iterative fitting with outlier removal 

223 for i_fit in range(settings_fit["n_max_refit"]):  [223 ↛ 260: line 223 didn't jump to line 260 because the loop on line 223 didn't complete]

224 LOGGER.debug(f"Fitting iteration {i_fit + 1}, n_stars: {n_stars}") 

225 

226 # Fit model 

227 regressor = self._fit_single_model( 

228 cat_select, 

229 col_names_fit, 

230 position_weights_select, 

231 settings_fit, 

232 ) 

233 

234 # Check for outliers 

235 select_keep = self._identify_outliers( 

236 regressor, 

237 cat_select, 

238 col_names_fit, 

239 position_weights_select, 

240 settings_fit, 

241 list_cols_use_outlier, 

242 ) 

243 

244 LOGGER.debug( 

245 f"Outlier removal: keeping {np.sum(select_keep)}/" 

246 f"{len(select_keep)} stars" 

247 ) 

248 

249 # Update selection 

250 cat_select = cat_select[select_keep] 

251 position_weights_select = position_weights_select[select_keep] 

252 

253 # Check for convergence 

254 if n_stars == len(cat_select):  [254 ↛ 258: line 254 didn't jump to line 258 because the condition on line 254 was always true]

255 LOGGER.info("Outlier removal converged") 

256 break 

257 

258 n_stars = len(cat_select) 

259 

260 return regressor, cat_select 

261 
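The loop above is a fit / clip / refit pattern that stops once no further stars are removed. Below is a simplified, self-contained illustration of the same convergence logic using an ordinary 1D least-squares polynomial instead of the tiled regressor; it is only a sketch, not the module's implementation.

import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(-1.0, 1.0, 200)
y = 2.0 * x**2 + 0.05 * rng.normal(size=x.size)
y[::25] += 1.0                                      # inject a few outliers

n_sigma_clip, n_max_refit = 3, 10
x_sel, y_sel = x.copy(), y.copy()
n_points = x_sel.size

for _ in range(n_max_refit):
    coeffs = np.polyfit(x_sel, y_sel, deg=2)        # fit on the current selection
    resid = y_sel - np.polyval(coeffs, x_sel)       # residuals of the kept points
    keep = np.abs(resid) < n_sigma_clip * np.std(resid, ddof=1)
    x_sel, y_sel = x_sel[keep], y_sel[keep]         # drop flagged outliers
    if x_sel.size == n_points:                      # converged: nothing was removed
        break
    n_points = x_sel.size

print(f"kept {x_sel.size}/{x.size} points after sigma clipping")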

262 def _fit_single_model(self, cat, col_names_fit, position_weights, settings_fit): 

263 """Fit a single polynomial model.""" 

264 position_xy = np.stack((cat["X_IMAGE"] - 0.5, cat["Y_IMAGE"] - 0.5), axis=-1) 

265 position_par = np.stack([cat[p] for p in col_names_fit], axis=-1) 

266 

267 position_xy_transformed = psf_utils.transform_forward( 

268 position_xy, 

269 settings_fit["scale_pos"], 

270 ) 

271 position_xy_transformed_weights = np.concatenate( 

272 (position_xy_transformed, position_weights), axis=1 

273 ) 

274 position_par_transformed = psf_utils.transform_forward( 

275 position_par, settings_fit["scale_par"] 

276 ) 

277 

278 # Create and fit regressor 

279 regressor = Regressor( 

280 poly_order=settings_fit["poly_order"], 

281 ridge_alpha=settings_fit["ridge_alpha"], 

282 polynomial_type=settings_fit["polynomial_type"], 

283 set_unseen_to_mean=True, 

284 ) 

285 

286 regressor.fit(position_xy_transformed_weights, position_par_transformed) 

287 return regressor 

288 

289 def _identify_outliers( 

290 self, 

291 regressor, 

292 cat, 

293 col_names_fit, 

294 position_weights, 

295 settings_fit, 

296 list_cols_use_outlier, 

297 ): 

298 """Identify outliers based on model residuals.""" 

299 # Get predictions 

300 position_xy = np.stack((cat["X_IMAGE"] - 0.5, cat["Y_IMAGE"] - 0.5), axis=-1) 

301 position_par = np.stack([cat[p] for p in col_names_fit], axis=-1) 

302 

303 position_xy_transformed = psf_utils.transform_forward( 

304 position_xy, settings_fit["scale_pos"] 

305 ) 

306 position_xy_transformed_weights = np.concatenate( 

307 (position_xy_transformed, position_weights), axis=1 

308 ) 

309 position_par_transformed = psf_utils.transform_forward( 

310 position_par, settings_fit["scale_par"] 

311 ) 

312 

313 # Get model predictions 

314 position_par_pred = regressor.predict(position_xy_transformed_weights) 

315 

316 # Apply outlier removal 

317 select_keep = star_sample_selection_cnn.remove_outliers( 

318 x=position_par_transformed, 

319 y=position_par_pred, 

320 n_sigma=settings_fit["n_sigma_clip"], 

321 list_cols_use=list_cols_use_outlier, 

322 ) 

323 

324 return select_keep 

325 

326 def _handle_validation_stars( 

327 self, 

328 cat_fit, 

329 cat_clip, 

330 flags_fit, 

331 col_names_fit, 

332 position_weights_fit, 

333 settings_fit, 

334 ): 

335 """Handle validation star selection and final fitting.""" 

336 fraction_validation = self.config.get("fraction_validation_stars", 0) 

337 

338 # Select validation stars from the outlier-cleaned catalog 

339 indices_validation = psf_utils.select_validation_stars( 

340 len(cat_clip), fraction_validation 

341 ) 

342 cat_validation = cat_clip[indices_validation] 

343 

344 # Flag validation stars in the full fitting catalog 

345 select_val_outer = np.in1d(cat_fit["NUMBER"], cat_validation["NUMBER"]) 

346 set_flag_bit( 

347 flags=flags_fit, 

348 select=select_val_outer, 

349 field=star_sample_selection_cnn.FLAGBIT_VALIDATION_STAR, 

350 ) 

351 

352 # Final fit excluding validation stars 

353 select_final = flags_fit == 0 

354 LOGGER.info( 

355 f"Final fit: {len(cat_validation)} validation stars reserved, " 

356 f"{np.sum(select_final)} stars for training" 

357 ) 

358 

359 # Fit on non-validation stars 

360 regressor = self._fit_single_model( 

361 cat_fit[select_final], 

362 col_names_fit, 

363 position_weights_fit[select_final], 

364 settings_fit, 

365 ) 

366 

367 # Test on validation stars 

368 validation_metrics = self._test_validation_stars( 

369 regressor, 

370 cat_validation, 

371 col_names_fit, 

372 position_weights_fit[select_val_outer], 

373 settings_fit, 

374 ) 

375 LOGGER.info(f"Validation test completed: {validation_metrics}") 

376 

377 return regressor 

378 
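Flagging validation stars with set_flag_bit amounts to setting one bit of an integer flag array for the selected rows. A generic sketch of that idea follows; the actual bit position of FLAGBIT_VALIDATION_STAR is defined in star_sample_selection_cnn and is not assumed here.

import numpy as np

flags = np.zeros(5, dtype=np.int32)
select = np.array([False, True, False, True, False])
FLAGBIT_EXAMPLE = 3                     # hypothetical bit position

flags[select] |= 1 << FLAGBIT_EXAMPLE   # set the bit only where selected
print(flags)                            # [0 8 0 8 0]
print(flags == 0)                       # unflagged rows, as in `select_final = flags_fit == 0`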

379 def _test_validation_stars( 

380 self, 

381 regressor, 

382 cat_validation, 

383 col_names_fit, 

384 position_weights_validation, 

385 settings_fit, 

386 ): 

387 """ 

388 Test the fitted model on validation stars and compute metrics. 

389 

390 :param regressor: Fitted polynomial regressor. 

391 :param cat_validation: Catalog of validation stars. 

392 :param col_names_fit: Names of parameters used in fitting. 

393 :param position_weights_validation: Position weights for validation stars. 

394 :param settings_fit: Fitting settings including scale parameters. 

395 :return: Dictionary with validation metrics. 

396 """ 

397 if len(cat_validation) == 0: 

398 LOGGER.warning("No validation stars available for testing") 

399 return {"n_validation": 0} 

400 

401 # Transform validation data 

402 position_xy = np.stack( 

403 (cat_validation["X_IMAGE"] - 0.5, cat_validation["Y_IMAGE"] - 0.5), axis=-1 

404 ) 

405 position_par_true = np.stack( 

406 [cat_validation[p] for p in col_names_fit], axis=-1 

407 ) 

408 

409 position_xy_transformed = psf_utils.transform_forward( 

410 position_xy, settings_fit["scale_pos"] 

411 ) 

412 position_xy_transformed_weights = np.concatenate( 

413 (position_xy_transformed, position_weights_validation), axis=1 

414 ) 

415 position_par_true_transformed = psf_utils.transform_forward( 

416 position_par_true, settings_fit["scale_par"] 

417 ) 

418 

419 # Get model predictions 

420 position_par_pred_transformed = regressor.predict( 

421 position_xy_transformed_weights 

422 ) 

423 

424 # Calculate residuals in transformed space 

425 residuals = position_par_pred_transformed - position_par_true_transformed 

426 

427 # Compute validation metrics 

428 metrics = { 

429 "n_validation": len(cat_validation), 

430 "mae_mean": np.mean(np.abs(residuals)), # Mean absolute error 

431 "rmse_mean": np.sqrt(np.mean(residuals**2)), # Root mean square error 

432 "bias_mean": np.mean(residuals), # Systematic bias 

433 } 

434 

435 # Per-parameter metrics 

436 for i, param in enumerate(col_names_fit): 

437 if i < residuals.shape[1]:  [437 ↛ 436: line 437 didn't jump to line 436 because the condition on line 437 was always true]

438 metrics[f"mae_{param}"] = np.mean(np.abs(residuals[:, i])) 

439 metrics[f"rmse_{param}"] = np.sqrt(np.mean(residuals[:, i] ** 2)) 

440 metrics[f"bias_{param}"] = np.mean(residuals[:, i]) 

441 

442 return metrics 

443 

444 def _get_outlier_removal_columns(self, par_fit): 

445 """Get columns to use for outlier removal.""" 

446 list_cols = [] 

447 for ip, par in enumerate(par_fit): 

448 if ("fwhm" in par) or ("e1" in par) or ("e2" in par): 

449 list_cols.append(ip) 

450 LOGGER.info(f"Using columns {list_cols} for outlier removal") 

451 return list_cols 

452 

453 def _colnames_cnn_fit_to_ipt(self, col_names_cnn_fit): 

454 """Convert CNN fit column names to interpolation column names.""" 

455 return [p[:-4] + "_ipt" for p in col_names_cnn_fit] 

456 

457 def _colnames_mom_fit_to_ipt(self, col_names_mom_fit): 

458 """Convert moment fit column names to interpolation column names.""" 

459 return [p.replace("se_mom", "psf_mom") + "_ipt" for p in col_names_mom_fit]
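
For reference, a minimal usage sketch of the module, assuming the package layout shown in the file path above; the config values, catalog, and weights below are placeholders, not working inputs.

from ufig.psf_estimation.polynomial_fitting import PolynomialPSFModel

# Illustrative configuration; keys match those read via self.config.get(...) above,
# values are placeholders.
config = {
    "image_shape": (4000, 2000),            # (rows, cols) of the image
    "poly_order": 3,
    "polynomial_type": "chebyshev",
    "psfmodel_ridge_alpha": 1e-6,
    "n_max_refit": 10,
    "n_sigma_clip": 3,
    "fraction_validation_stars": 0.1,
    "astrometry_errors": False,
}

model = PolynomialPSFModel(config)

# cat and position_weights are placeholders: cat is expected to be a structured
# array with X_IMAGE, Y_IMAGE, NUMBER, the *_cnn predictions and the se_mom_*
# columns; position_weights is a per-star weight array passed to the regressor.
result = model.fit_model(cat, position_weights)  # noqa: F821 (placeholders)
regressor = result["regressor"]                  # fitted TiledRobustPolynomialRegressor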