Coverage for src/ufig/psf_estimation/star_sample_selection_cnn.py: 97%

194 statements  

« prev     ^ index     » next       coverage.py v7.10.2, created at 2025-08-07 15:17 +0000

1# Copyright (c) 2017 ETH Zurich, Cosmology Research Group 

2""" 

3Created on Aug 03, 2017 

4@author: Joerg Herbel 

5""" 

6 

7import h5py 

8import numpy as np 

9import scipy.stats 

10from astropy import wcs 

11from astropy.io import fits 

12from cosmic_toolbox import arraytools as at 

13from cosmic_toolbox import logger 

14from sklearn.neighbors import BallTree 

15 

16from ufig.array_util import check_flag_bit, set_flag_bit 

17from ufig.plugins import add_generic_stamp_flags 

18from ufig.psf_estimation import psf_utils 

19from ufig.psf_estimation.cutouts_utils import get_cutouts 

20 

# Module-level logger, created from this file's path.
LOGGER = logger.get_logger(__file__)

# Bit identifiers for the per-star selection flags handled in this module.
# Each constant is handed to set_flag_bit as `field`; a star whose flag value
# is still 0 has not failed any cut (see set_flag_bit_and_log).

# Not referenced by the functions in this file.
FLAGBIT_GAIA = 1

# Set by get_stars_for_cnn (catalog / image based cuts).
FLAGBIT_MAG = 2
FLAGBIT_N_EXP = 3
FLAGBIT_SOURCEEXTRACTOR_FLAGS = 4
FLAGBIT_POSITION_WEIGHTS = 5
FLAGBIT_IMAGE_BOUNDARY = 6
FLAGBIT_COADD_BOUNDARY = 7
FLAGBIT_MAX_FLUX_CENTERED = 8

# Set by select_cnn_predictions (cuts on the CNN-predicted PSF parameters).
FLAGBIT_BETA = 9
FLAGBIT_FWHM = 10
FLAGBIT_ELLIPTICITY = 11
FLAGBIT_FLEXION = 12
FLAGBIT_KURTOSIS = 13

# Not referenced by the functions in this file.
FLAGBIT_OUTLIER = 14
FLAGBIT_VALIDATION_STAR = 15

# Set by get_stars_for_cnn (moments and FLAGS_STAMP based cuts).
FLAGBIT_MOM_ERROR = 16
FLAGBIT_SYSMAP_DELTA_WEIGHT = 17
FLAGBIT_NEARBY_BRIGHT_STAR = 18
FLAGBIT_SURVEY_MASK = 19

42 

43 

def inlims(x, lims):
    """
    Test which entries of x lie strictly inside an interval.

    :param x: Values to test.
    :param lims: Sequence whose first two entries are the lower and upper limit.
    :return: Boolean mask that is True where lims[0] < x < lims[1]; both
        bounds are exclusive.
    """
    lower = lims[0]
    upper = lims[1]
    above_lower = x > lower
    below_upper = x < upper
    return above_lower & below_upper

54 

55 

def get_gaia_match(cat_in, filepath_gaia, max_dist_arcsec):
    """
    Cross-match a SExtractor catalog against a GAIA catalog stored in HDF5.

    The GAIA table is read from the "data" dataset of the file and matched to
    the input catalog with get_match_vector. The result is the input catalog
    with four match columns filled in.

    :param cat_in: SExtractor catalog data.
    :param filepath_gaia: Path to the GAIA catalog file (HDF5).
    :param max_dist_arcsec: Maximum distance in arcseconds for matching.
    :return: Catalog with the columns match_gaia, gaia_ra_match,
        gaia_dec_match and gaia_id_match set from the match result.
    """
    # Copy the dataset into memory so nothing depends on the open file handle.
    with h5py.File(filepath_gaia, mode="r") as gaia_file:
        cat_gaia = np.array(gaia_file["data"])
    LOGGER.info(f"found {len(cat_gaia)} stars in the Gaia catalogue")

    is_matched, ra_match, dec_match, id_match = get_match_vector(
        cat_in, cat_gaia, max_dist_arcsec
    )

    match_columns = [
        "match_gaia:i2",
        "gaia_ra_match",
        "gaia_dec_match",
        "gaia_id_match:i8",
    ]
    cat_out = at.ensure_cols(cat_in, names=match_columns)
    cat_out["gaia_id_match"] = id_match
    cat_out["match_gaia"] = is_matched
    cat_out["gaia_ra_match"] = ra_match
    cat_out["gaia_dec_match"] = dec_match

    return cat_out

86 

87 

def get_gaia_image_coords(filepath_image, cat):
    """
    Add pixel coordinates of the matched GAIA positions to the catalog.

    The WCS is built from the image header and used to transform the
    gaia_ra_match / gaia_dec_match columns to pixel coordinates (origin
    argument 1).

    :param filepath_image: Path to the image file for coordinate conversion.
    :param cat: Catalog containing the gaia_ra_match and gaia_dec_match columns.
    :return: Catalog with the additional columns gaia_x_match and gaia_y_match.
    """
    # Use the header of extension 1 when available, otherwise fall back to
    # the primary HDU.
    try:
        header = fits.getheader(filepath_image, ext=1)
    except IndexError:
        header = fits.getheader(filepath_image, ext=0)

    sky_stacked = np.vstack([cat["gaia_ra_match"], cat["gaia_dec_match"]])
    pixel_xy = wcs.WCS(header).all_world2pix(sky_stacked.T, 1)

    cat = at.ensure_cols(cat, names=["gaia_x_match", "gaia_y_match"])
    cat["gaia_x_match"] = pixel_xy[:, 0]
    cat["gaia_y_match"] = pixel_xy[:, 1]

    # Rows whose x pixel coordinate came out as NaN (unmatched objects) get
    # the stored RA/Dec match values copied into the pixel columns instead.
    no_pixel = np.isnan(cat["gaia_x_match"])
    cat["gaia_x_match"][no_pixel] = cat["gaia_ra_match"][no_pixel]
    cat["gaia_y_match"][no_pixel] = cat["gaia_dec_match"][no_pixel]
    return cat

110 

111 

def cut_magnitude(cat, mag_min, mag_max):
    """
    Build the magnitude selection from the MAG_AUTO column.

    :param cat: Catalog with SExtractor measurements (needs MAG_AUTO).
    :param mag_min: Lower magnitude limit (exclusive).
    :param mag_max: Upper magnitude limit (exclusive).
    :return: Boolean array, True where mag_min < MAG_AUTO < mag_max.
    """
    mag = cat["MAG_AUTO"]
    above_min = mag > mag_min
    below_max = mag < mag_max
    return above_min & below_max

123 

124 

def cut_n_exp(weights, n_exp_min):
    """
    Require a minimum number of exposures per star.

    The number of exposures is derived from the position weights with
    psf_utils.position_weights_to_nexp.

    :param weights: Position weights for each star.
    :param n_exp_min: Minimum number of exposures required (inclusive).
    :return: Boolean array, True where the number of exposures is at least
        n_exp_min.
    """
    return psf_utils.position_weights_to_nexp(weights) >= n_exp_min

136 

137 

def cut_sextractor_flag(cat, flags=None):
    """
    Select objects whose SExtractor FLAGS value is one of the accepted values.

    The comparison is made on the complete FLAGS value, not on individual bits.

    :param cat: Catalog with SExtractor measurements (needs a FLAGS column).
    :param flags: Accepted FLAGS value(s): a single integer (Python int or
        numpy integer scalar) or a list, tuple or numpy array of values.
        If None, no cut is applied and every object is selected.
    :return: Boolean array indicating which objects have an accepted FLAGS value.
    :raises ValueError: If flags has an unsupported type.
    """
    if flags is None:
        return np.ones(len(cat), dtype=bool)

    # numpy integer scalars (e.g. values taken from an array) are not
    # instances of int, so they are accepted explicitly.
    if isinstance(flags, (int, np.integer)):
        return cat["FLAGS"] == flags

    if isinstance(flags, (list, tuple, np.ndarray)):
        # Membership test against all accepted values at once; an empty
        # collection selects nothing.
        return np.isin(cat["FLAGS"], flags)

    raise ValueError("flags must be an int, list, tuple or numpy array")

157 

158 

def cut_position_weights(weights):
    """
    Keep only stars whose position weights are all finite.

    :param weights: Position weights, one row per star.
    :return: Boolean array, True for rows without any NaN or infinite entry.
    """
    is_finite = np.isfinite(weights)
    return is_finite.all(axis=1)

168 

169 

def cut_nearby_bright_star(cat):
    """
    Select objects that are not marked as having a nearby bright star.

    The mark is read from the FLAGS_STAMP column using the
    FLAGBIT_NEARBY_BRIGHT_STAR bit defined in add_generic_stamp_flags. If the
    column is missing, a warning is logged and every object is selected.

    :param cat: Catalog, optionally containing a FLAGS_STAMP column.
    :return: Boolean array, True for objects for which check_flag_bit does not
        report the bit.
    """
    if "FLAGS_STAMP" not in cat.dtype.names:
        LOGGER.warning(
            "column FLAGS_STAMP not found, not cutting on FLAGBIT_NEARBY_BRIGHT_STAR"
        )
        return np.ones(len(cat), dtype=bool)

    has_bright_neighbour = check_flag_bit(
        cat["FLAGS_STAMP"], add_generic_stamp_flags.FLAGBIT_NEARBY_BRIGHT_STAR
    )
    return ~has_bright_neighbour

182 

183 

def cut_sysmaps_delta_weight(cat):
    """
    Select objects that are not marked with the sysmap delta weight flag.

    The mark is read from the FLAGS_STAMP column using the
    FLAGBIT_SYSMAP_DELTA_WEIGHT bit defined in add_generic_stamp_flags. If the
    column is missing, a warning is logged and every object is selected.

    :param cat: Catalog, optionally containing a FLAGS_STAMP column.
    :return: Boolean array, True for objects for which check_flag_bit does not
        report the bit.
    """
    if "FLAGS_STAMP" not in cat.dtype.names:
        LOGGER.warning(
            "column FLAGS_STAMP not found, not cutting on FLAGBIT_SYSMAP_DELTA_WEIGHT"
        )
        return np.ones(len(cat), dtype=bool)

    has_delta_weight_flag = check_flag_bit(
        cat["FLAGS_STAMP"], add_generic_stamp_flags.FLAGBIT_SYSMAP_DELTA_WEIGHT
    )
    return ~has_delta_weight_flag

196 

197 

def cut_sysmaps_survey_mask(cat):
    """
    Select objects that are not marked with the survey mask flag.

    The mark is read from the FLAGS_STAMP column using the FLAGBIT_SURVEY_MASK
    bit defined in add_generic_stamp_flags. If the column is missing, a
    warning is logged and every object is selected.

    :param cat: Catalog, optionally containing a FLAGS_STAMP column.
    :return: Boolean array, True for objects for which check_flag_bit does not
        report the bit.
    """
    if "FLAGS_STAMP" not in cat.dtype.names:
        LOGGER.warning(
            "column FLAGS_STAMP not found, not cutting on FLAGBIT_SURVEY_MASK"
        )
        return np.ones(len(cat), dtype=bool)

    is_masked = check_flag_bit(
        cat["FLAGS_STAMP"], add_generic_stamp_flags.FLAGBIT_SURVEY_MASK
    )
    return ~is_masked

210 

211 

def cut_boundaries(cat, image, pointings_maps, star_stamp_shape):
    """
    Cut out the star stamps and collect the position-related selections.

    Delegates to get_cutouts, using the catalog peak positions shifted by one
    pixel, and returns its results in a different order (cube last).

    :param cat: Catalog with SExtractor measurements (needs XPEAK_IMAGE and
        YPEAK_IMAGE).
    :param image: Image data for cutout extraction.
    :param pointings_maps: Pointing maps for exposure coverage.
    :param star_stamp_shape: Shape of the cutout stamps for stars.
    :return: Tuple (select_image_boundary, select_coadd_boundary,
        select_max_flux_centered, cube) as produced by get_cutouts.
    """
    # NOTE(review): the -1 presumably converts 1-based SExtractor pixel
    # positions to 0-based array indices - confirm against get_cutouts.
    x_peak = cat["XPEAK_IMAGE"] - 1
    y_peak = cat["YPEAK_IMAGE"] - 1

    cutout_result = get_cutouts(
        x_peak, y_peak, image, pointings_maps, star_stamp_shape
    )
    cube, in_image, in_coadd, flux_centered = cutout_result
    return in_image, in_coadd, flux_centered, cube

236 

237 

def set_flag_bit_and_log(select, flags, flagbit, msg):
    """
    Flag the selected entries and log a one-line summary of the cut.

    :param select: Boolean array of the entries to flag; the callers in this
        module pass the stars that FAILED a cut.
    :param flags: Flag array, handed to set_flag_bit together with select.
    :param flagbit: Flag bit identifier passed to set_flag_bit as `field`.
    :param msg: Description of the entries that were NOT selected; it is
        logged next to their count.
    """
    set_flag_bit(flags=flags, select=select, field=flagbit)

    n_not_selected = np.count_nonzero(~select)
    n_unflagged = np.count_nonzero(flags == 0)
    LOGGER.info(
        f"star cuts: found {n_not_selected}/{len(flags)} {msg}, "
        f"current n_star={n_unflagged}"
    )

244 

245 

def get_stars_for_cnn(
    cat,
    image,
    star_stamp_shape,
    pointings_maps,
    position_weights,
    star_mag_range=(18, 22),
    min_n_exposures=1,
    sextractor_flags=None,
    flag_coadd_boundaries=False,
    moments_lim=(-99, 99),
):
    """
    Run the quality cuts that precede the CNN and cut out the star stamps.

    Every cut is evaluated on the full catalog. Objects failing a cut are
    flagged with the matching FLAGBIT_* via set_flag_bit_and_log, so an object
    whose flag value is 0 passed all cuts.

    :param cat: Catalog with SExtractor measurements.
    :param image: Image data for cutout extraction.
    :param star_stamp_shape: Shape of the cutout stamps for stars.
    :param pointings_maps: Pointing maps for exposure coverage.
    :param position_weights: Position weights for each star.
    :param star_mag_range: Tuple defining the magnitude range for star selection.
    :param min_n_exposures: Minimum number of exposures required for a star to be
        selected.
    :param sextractor_flags: Accepted SExtractor FLAGS values; None is replaced
        by [0, 16].
    :param flag_coadd_boundaries: If True, objects failing the coadd boundary
        check are flagged as well.
    :param moments_lim: (min, max) limits handed to cut_moments.
    :return: flags, cube: uint32 flag array with one entry per catalog object
        and the cutout cube returned by cut_boundaries.
    """
    flags = np.zeros(len(cat), dtype=np.uint32)

    # Magnitude
    mag_lower, mag_upper = star_mag_range[0], star_mag_range[1]
    in_mag_range = cut_magnitude(cat, mag_lower, mag_upper)
    set_flag_bit_and_log(
        ~in_mag_range, flags, FLAGBIT_MAG, msg="stars in accepted mag range"
    )

    # Number of exposures
    enough_exposures = cut_n_exp(position_weights, min_n_exposures)
    set_flag_bit_and_log(
        ~enough_exposures,
        flags,
        FLAGBIT_N_EXP,
        msg="stars with accepted number of exposures",
    )

    # SExtractor flags
    if sextractor_flags is None:
        sextractor_flags = [0, 16]
    good_se_flags = cut_sextractor_flag(cat, sextractor_flags)
    set_flag_bit_and_log(
        ~good_se_flags,
        flags,
        FLAGBIT_SOURCEEXTRACTOR_FLAGS,
        msg="stars with accepted SourceExtractor flags",
    )

    # NaN weights
    finite_weights = cut_position_weights(position_weights)
    set_flag_bit_and_log(
        ~finite_weights,
        flags,
        FLAGBIT_POSITION_WEIGHTS,
        msg="stars without any NaN position weights",
    )

    # Cutouts and the position-related selections that come with them
    boundary_result = cut_boundaries(cat, image, pointings_maps, star_stamp_shape)
    inside_image, inside_coadd, flux_centered, cube = boundary_result
    set_flag_bit_and_log(
        ~inside_image,
        flags,
        FLAGBIT_IMAGE_BOUNDARY,
        msg="stars far enough from image boundary",
    )
    # The coadd boundary cut is optional; the other two are always applied.
    if flag_coadd_boundaries:
        set_flag_bit_and_log(
            ~inside_coadd,
            flags,
            FLAGBIT_COADD_BOUNDARY,
            msg="stars far enough from coadd boundaries",
        )
    set_flag_bit_and_log(
        ~flux_centered,
        flags,
        FLAGBIT_MAX_FLUX_CENTERED,
        msg="stars centered on maximum flux",
    )

    # Moments
    good_moments = cut_moments(cat, moments_lim)
    set_flag_bit_and_log(
        ~good_moments, flags, FLAGBIT_MOM_ERROR, msg="stars with good moments"
    )

    # Stamp flags: sysmap delta weight
    ok_delta_weight = cut_sysmaps_delta_weight(cat)
    set_flag_bit_and_log(
        ~ok_delta_weight,
        flags,
        FLAGBIT_SYSMAP_DELTA_WEIGHT,
        msg="stars with matching sysmap delta weight",
    )

    # Stamp flags: nearby bright star
    no_bright_star = cut_nearby_bright_star(cat)
    set_flag_bit_and_log(
        ~no_bright_star,
        flags,
        FLAGBIT_NEARBY_BRIGHT_STAR,
        msg="stars far enough from bright stars",
    )

    # Stamp flags: survey mask
    ok_survey_mask = cut_sysmaps_survey_mask(cat)
    set_flag_bit_and_log(
        ~ok_survey_mask,
        flags,
        FLAGBIT_SURVEY_MASK,
        msg="stars with OK survey mask",
    )

    return flags, cube

368 

369 

def beta_cut(beta, beta_lim=(1.5, 10)):
    """
    Apply the beta limits to a set of beta values.

    :param beta: Beta parameter for each star.
    :param beta_lim: (min, max) limits for beta, both exclusive.
    :return: Boolean array, True where beta lies inside the limits.
    """
    return inlims(beta, beta_lim)

380 

381 

def fwhm_cut(fwhm, fwhm_lim=(1, 10)):
    """
    Apply the FWHM limits to a set of FWHM values.

    :param fwhm: FWHM parameter for each star.
    :param fwhm_lim: (min, max) limits for the FWHM, both exclusive.
    :return: Boolean array, True where the FWHM lies inside the limits.
    """
    return inlims(fwhm, fwhm_lim)

392 

393 

def ellipticity_cut(e1, e2, ellipticity_lim=(-0.3, 0.3)):
    """
    Apply the same limits to both ellipticity components.

    :param e1: First ellipticity component for each star.
    :param e2: Second ellipticity component for each star.
    :param ellipticity_lim: (min, max) limits, both exclusive, applied to e1
        and e2 separately.
    :return: Boolean array, True where both components lie inside the limits.
    """
    e1_ok = inlims(e1, ellipticity_lim)
    e2_ok = inlims(e2, ellipticity_lim)
    return e1_ok & e2_ok

406 

407 

def flexion_cut(f1, f2, g1, g2, flexion_lim=(-0.3, 0.3)):
    """
    Apply the same limits to all four flexion components.

    :param f1: First flexion component for each star.
    :param f2: Second flexion component for each star.
    :param g1: Third flexion component for each star.
    :param g2: Fourth flexion component for each star.
    :param flexion_lim: (min, max) limits, both exclusive, applied to every
        component separately.
    :return: Boolean array, True where all four components lie inside the
        limits.
    """
    select = inlims(f1, flexion_lim)
    for component in (f2, g1, g2):
        select = select & inlims(component, flexion_lim)
    return select

427 

428 

def kurtosis_cut(kurtosis, kurtosis_lim=(-1, 1)):
    """
    Apply the kurtosis limits to a set of kurtosis values.

    :param kurtosis: Kurtosis parameter for each star.
    :param kurtosis_lim: (min, max) limits for the kurtosis, both exclusive.
    :return: Boolean array, True where the kurtosis lies inside the limits.
    """
    return inlims(kurtosis, kurtosis_lim)

440 

441 

def cut_moments(cat, moments_lim):
    """
    Select objects whose moment-based measurements lie inside the limits.

    The same (min, max) pair is applied to the se_mom_fwhm, se_mom_e1 and
    se_mom_e2 columns.

    :param cat: Catalog containing the se_mom_fwhm, se_mom_e1 and se_mom_e2
        columns.
    :param moments_lim: (min, max) limits, both exclusive.
    :return: Boolean array, True where all three columns lie inside the limits.
    """
    fwhm_ok = inlims(cat["se_mom_fwhm"], moments_lim)
    e1_ok = inlims(cat["se_mom_e1"], moments_lim)
    e2_ok = inlims(cat["se_mom_e2"], moments_lim)
    return fwhm_ok & e1_ok & e2_ok

449 

450 

def select_cnn_predictions(
    flags,
    pred,
    beta_lim=(1.5, 10),
    fwhm_lim=(1, 10),
    ellipticity_lim=(-0.3, 0.3),
    flexion_lim=(-0.3, 0.3),
    kurtosis_lim=(-1, 1),
):
    """
    Flag stars whose CNN-predicted PSF parameters fall outside the limits.

    The FWHM, ellipticity and flexion cuts are always applied. The beta and
    kurtosis cuts are applied only when the corresponding columns exist in
    pred.

    :param flags: Flags array indicating the quality of each star; it is
        passed to set_flag_bit_and_log for every cut.
    :param pred: Predictions from the CNN model containing PSF parameters.
    :param beta_lim: (min, max) limits for psf_beta_1_cnn and, if present,
        psf_beta_2_cnn.
    :param fwhm_lim: (min, max) limits for psf_fwhm_cnn.
    :param ellipticity_lim: (min, max) limits for psf_e1_cnn and psf_e2_cnn.
    :param flexion_lim: (min, max) limits for psf_f1_cnn, psf_f2_cnn,
        psf_g1_cnn and psf_g2_cnn.
    :param kurtosis_lim: (min, max) limits for psf_kurtosis_cnn.
    :return: The flags array after applying the cuts.
    """
    pred_columns = pred.dtype.names

    # Beta (optional; the second beta column is only checked with the first)
    if "psf_beta_1_cnn" in pred_columns:
        beta_ok = beta_cut(pred["psf_beta_1_cnn"], beta_lim=beta_lim)

        if "psf_beta_2_cnn" in pred_columns:
            beta_ok &= beta_cut(pred["psf_beta_2_cnn"], beta_lim=beta_lim)

        set_flag_bit_and_log(~beta_ok, flags, FLAGBIT_BETA, "stars with accepted beta")

    # FWHM
    fwhm_ok = fwhm_cut(pred["psf_fwhm_cnn"], fwhm_lim=fwhm_lim)
    set_flag_bit_and_log(~fwhm_ok, flags, FLAGBIT_FWHM, "stars with accepted FWHM")

    # Ellipticity
    ellipticity_ok = ellipticity_cut(
        pred["psf_e1_cnn"], pred["psf_e2_cnn"], ellipticity_lim=ellipticity_lim
    )
    set_flag_bit_and_log(
        ~ellipticity_ok, flags, FLAGBIT_ELLIPTICITY, "stars with accepted ellipticity"
    )

    # Flexion
    flexion_ok = flexion_cut(
        pred["psf_f1_cnn"],
        pred["psf_f2_cnn"],
        pred["psf_g1_cnn"],
        pred["psf_g2_cnn"],
        flexion_lim=flexion_lim,
    )
    set_flag_bit_and_log(
        ~flexion_ok, flags, FLAGBIT_FLEXION, "stars with accepted flexion"
    )

    # Kurtosis (optional)
    if "psf_kurtosis_cnn" in pred_columns:
        kurtosis_ok = kurtosis_cut(pred["psf_kurtosis_cnn"], kurtosis_lim=kurtosis_lim)
        set_flag_bit_and_log(
            ~kurtosis_ok, flags, FLAGBIT_KURTOSIS, "stars with accepted kurtosis"
        )

    return flags

514 

515 

def remove_outliers(x, y, n_sigma, list_cols_use="all"):
    """
    Flag samples whose residual y - x is an outlier in any of the used columns.

    For every used column the residuals are sigma-clipped with
    scipy.stats.sigmaclip; a sample is kept only if its residual lies within
    the final clipping bounds of every used column.

    :param x: Predicted values (e.g., CNN predictions), shape
        (n_samples, n_dim).
    :param y: Actual values (e.g., SExtractor measurements), same shape as x.
    :param n_sigma: Number of standard deviations for clipping (used for both
        the lower and the upper bound).
    :param list_cols_use: Column indices to use for outlier detection (any
        iterable of indices, including a numpy array), or the string 'all'
        to use all columns.
    :return: Boolean array indicating which samples are not outliers.
    """
    n_samples, n_dim = x.shape

    # Only the string "all" is the keyword. Checking the type first avoids an
    # elementwise (and thus ambiguous) comparison when an index array is given.
    if isinstance(list_cols_use, str) and list_cols_use == "all":
        list_cols_use = range(n_dim)

    select_clip = np.ones(n_samples, dtype=bool)
    res = y - x

    for i_dim in list_cols_use:
        res_col = res[:, i_dim]
        _, lo, up = scipy.stats.sigmaclip(res_col, low=n_sigma, high=n_sigma)
        # sigmaclip keeps values that sit exactly on the bounds, so the
        # bounds are inclusive here as well. With strict comparisons a column
        # without any spread (lo == up) would reject every sample.
        select_clip &= (res_col >= lo) & (res_col <= up)

    return select_clip

542 

543 

def get_match_vector(cat, cat_gaia, max_dist_arcsec):
    """
    Match the SExtractor catalog with the GAIA catalog based on RA/Dec coordinates.

    For every catalog object the nearest GAIA source is looked up with a
    BallTree using the haversine metric. An object counts as matched if that
    source is closer than max_dist_arcsec and no other catalog object was
    matched to the same GAIA source.

    :param cat: SExtractor catalog data (needs ALPHAWIN_J2000 and
        DELTAWIN_J2000).
    :param cat_gaia: GAIA catalog data (needs ra, dec and id).
    :param max_dist_arcsec: Maximum distance in arcseconds for matching.
    :return: Tuple (select_in_gaia, gaia_ra_match, gaia_dec_match,
        gaia_id_match): boolean match selection plus the RA, Dec and id of
        the matched GAIA source; unmatched objects hold -200 in the last
        three arrays.
    """
    # The haversine metric works on (latitude, longitude) pairs in radians,
    # hence Dec comes first and RA second.
    gaia_ang = np.column_stack(
        [
            cat_gaia["dec"] * np.pi / 180.0,
            cat_gaia["ra"] * np.pi / 180.0,
        ]
    )
    cat_ang = np.column_stack(
        [
            cat["DELTAWIN_J2000"] * np.pi / 180.0,
            cat["ALPHAWIN_J2000"] * np.pi / 180.0,
        ]
    )

    # Nearest GAIA neighbour of every catalog object
    tree = BallTree(gaia_ang, metric="haversine")
    dist, ind = tree.query(cat_ang, k=1)
    nearest_dist_arcsec = dist[:, 0] / np.pi * 180.0 * 3600.0
    nearest_ind = ind[:, 0]
    select_in_gaia = nearest_dist_arcsec < max_dist_arcsec

    # GAIA indices of the objects within the matching radius
    ind_match = nearest_ind[select_in_gaia]

    # Drop conflicts: a GAIA source claimed by more than one catalog object
    # is not assigned to any of them.
    _, inverse, counts = np.unique(ind_match, return_counts=True, return_inverse=True)
    is_unique = counts[inverse] == 1
    select_in_gaia[select_in_gaia] &= is_unique
    ind_match = ind_match[is_unique]

    # Output arrays, pre-filled with -200 for unmatched objects
    gaia_ra_match = np.full_like(cat["ALPHAWIN_J2000"], -200)
    gaia_dec_match = np.full_like(cat["DELTAWIN_J2000"], -200)
    gaia_id_match = np.full(len(cat), -200, dtype=int)

    gaia_ra_match[select_in_gaia] = cat_gaia["ra"][ind_match]
    gaia_dec_match[select_in_gaia] = cat_gaia["dec"][ind_match]
    gaia_id_match[select_in_gaia] = cat_gaia["id"][ind_match]

    return select_in_gaia, gaia_ra_match, gaia_dec_match, gaia_id_match