Coverage for src/ufig/psf_estimation/star_sample_selection_cnn.py: 97%
194 statements
« prev ^ index » next coverage.py v7.10.2, created at 2025-08-07 15:17 +0000
« prev ^ index » next coverage.py v7.10.2, created at 2025-08-07 15:17 +0000
1# Copyright (c) 2017 ETH Zurich, Cosmology Research Group
2"""
3Created on Aug 03, 2017
4@author: Joerg Herbel
5"""
7import h5py
8import numpy as np
9import scipy.stats
10from astropy import wcs
11from astropy.io import fits
12from cosmic_toolbox import arraytools as at
13from cosmic_toolbox import logger
14from sklearn.neighbors import BallTree
16from ufig.array_util import check_flag_bit, set_flag_bit
17from ufig.plugins import add_generic_stamp_flags
18from ufig.psf_estimation import psf_utils
19from ufig.psf_estimation.cutouts_utils import get_cutouts
# Module-level logger obtained from cosmic_toolbox's logger helper.
LOGGER = logger.get_logger(__file__)

# Flag-bit identifiers for the individual star-selection cuts. Functions in this
# module hand them to `set_flag_bit` as the `field` argument; stars whose flag
# value is still 0 are the ones counted as "current n_star" in the cut logs.
# NOTE(review): presumably each value is a bit position inside the uint32 flags
# array rather than a bit mask -- confirm against ufig.array_util.set_flag_bit.
FLAGBIT_GAIA = 1
FLAGBIT_MAG = 2
FLAGBIT_N_EXP = 3
FLAGBIT_SOURCEEXTRACTOR_FLAGS = 4
FLAGBIT_POSITION_WEIGHTS = 5
FLAGBIT_IMAGE_BOUNDARY = 6
FLAGBIT_COADD_BOUNDARY = 7
FLAGBIT_MAX_FLUX_CENTERED = 8
# Cuts on the CNN-predicted PSF parameters (see select_cnn_predictions).
FLAGBIT_BETA = 9
FLAGBIT_FWHM = 10
FLAGBIT_ELLIPTICITY = 11
FLAGBIT_FLEXION = 12
FLAGBIT_KURTOSIS = 13
FLAGBIT_OUTLIER = 14
FLAGBIT_VALIDATION_STAR = 15
FLAGBIT_MOM_ERROR = 16
# Cuts derived from the FLAGS_STAMP catalog column.
FLAGBIT_SYSMAP_DELTA_WEIGHT = 17
FLAGBIT_NEARBY_BRIGHT_STAR = 18
FLAGBIT_SURVEY_MASK = 19
def inlims(x, lims):
    """
    Test element-wise whether values lie strictly inside an open interval.

    :param x: Array of values to check.
    :param lims: Sequence whose first two entries are the (lower, upper) bounds;
        both bounds are exclusive.
    :return: Boolean array, True where lower < x < upper.
    """
    above_lower = x > lims[0]
    below_upper = x < lims[1]
    return above_lower & below_upper
def get_gaia_match(cat_in, filepath_gaia, max_dist_arcsec):
    """
    Cross-match a SExtractor catalog against a GAIA catalog stored in HDF5.

    The GAIA catalog is read from the "data" dataset of the file and matched
    with :func:`get_match_vector`. The result is the input catalog extended by
    the columns ``match_gaia``, ``gaia_ra_match``, ``gaia_dec_match`` and
    ``gaia_id_match``.

    :param cat_in: SExtractor catalog data.
    :param filepath_gaia: Path to the GAIA catalog file (HDF5).
    :param max_dist_arcsec: Maximum distance in arcseconds for matching.
    :return: Catalog enriched with GAIA matching information.
    """
    with h5py.File(filepath_gaia, mode="r") as gaia_file:
        cat_gaia = np.array(gaia_file["data"])
    LOGGER.info(f"found {len(cat_gaia)} stars in the Gaia catalogue")

    is_matched, ra_match, dec_match, id_match = get_match_vector(
        cat_in, cat_gaia, max_dist_arcsec
    )

    # Make sure the output columns exist before filling them.
    new_columns = [
        "match_gaia:i2",
        "gaia_ra_match",
        "gaia_dec_match",
        "gaia_id_match:i8",
    ]
    cat_out = at.ensure_cols(cat_in, names=new_columns)
    cat_out["gaia_id_match"] = id_match
    cat_out["match_gaia"] = is_matched
    cat_out["gaia_ra_match"] = ra_match
    cat_out["gaia_dec_match"] = dec_match

    return cat_out
def get_gaia_image_coords(filepath_image, cat):
    """
    Add image pixel coordinates of the matched GAIA positions to the catalog.

    The WCS is built from the header of extension 1 of the image; if reading
    that extension raises ``IndexError``, extension 0 is used instead.

    :param filepath_image: Path to the FITS image providing the WCS.
    :param cat: Catalog with columns ``gaia_ra_match`` and ``gaia_dec_match``.
    :return: Catalog with the additional columns ``gaia_x_match`` and
        ``gaia_y_match``.
    """
    try:
        header = fits.getheader(filepath_image, ext=1)
    except IndexError:
        header = fits.getheader(filepath_image, ext=0)

    world_to_pixel = wcs.WCS(header).all_world2pix
    sky = np.vstack([cat["gaia_ra_match"], cat["gaia_dec_match"]]).T
    # Second argument is the pixel origin handed to all_world2pix.
    pix = world_to_pixel(sky, 1)

    cat = at.ensure_cols(cat, names=["gaia_x_match", "gaia_y_match"])
    cat["gaia_x_match"] = pix[:, 0]
    cat["gaia_y_match"] = pix[:, 1]

    # Where the x pixel coordinate is NaN, copy the sky values into both pixel
    # columns instead (the original code marks these entries as "unmatched").
    no_pixel = np.isnan(cat["gaia_x_match"])
    for pixel_col, sky_col in (
        ("gaia_x_match", "gaia_ra_match"),
        ("gaia_y_match", "gaia_dec_match"),
    ):
        cat[pixel_col][no_pixel] = cat[sky_col][no_pixel]
    return cat
def cut_magnitude(cat, mag_min, mag_max):
    """
    Select objects whose ``MAG_AUTO`` lies strictly between two magnitudes.

    :param cat: Catalog with SExtractor measurements (needs ``MAG_AUTO``).
    :param mag_min: Lower magnitude bound (exclusive).
    :param mag_max: Upper magnitude bound (exclusive).
    :return: Boolean array, True for objects inside the magnitude range.
    """
    mag = cat["MAG_AUTO"]
    bright_enough = mag < mag_max
    faint_enough = mag > mag_min
    return faint_enough & bright_enough
def cut_n_exp(weights, n_exp_min):
    """
    Select stars covered by at least a minimum number of exposures.

    The exposure count is derived from the position weights by
    ``psf_utils.position_weights_to_nexp``.

    :param weights: Position weights for each star.
    :param n_exp_min: Minimum number of exposures required (inclusive).
    :return: Boolean array, True for stars with enough exposures.
    """
    return psf_utils.position_weights_to_nexp(weights) >= n_exp_min
def cut_sextractor_flag(cat, flags=None):
    """
    Select stars whose SExtractor ``FLAGS`` value is one of the accepted values.

    The comparison is an exact equality on the whole ``FLAGS`` value, not a test
    of individual bits.

    :param cat: Catalog with SExtractor measurements (needs ``FLAGS``).
    :param flags: Accepted ``FLAGS`` value(s): a single integer (Python ``int``
        or NumPy integer scalar) or a list, tuple or numpy array of values.
        If None, no cut is applied and every object is selected.
    :return: Boolean array indicating which stars have acceptable SExtractor flags.
    :raises ValueError: If ``flags`` has an unsupported type.
    """
    if flags is None:
        return np.ones(len(cat), dtype=bool)
    # NumPy integer scalars (e.g. np.int64) are not instances of `int`, so they
    # have to be accepted explicitly.
    if isinstance(flags, (int, np.integer)):
        return cat["FLAGS"] == flags
    if isinstance(flags, (list, tuple, np.ndarray)):
        # True where FLAGS equals any accepted value; all False for an empty list.
        return np.isin(cat["FLAGS"], flags)
    raise ValueError("flags must be an int, list, tuple or numpy array")
def cut_position_weights(weights):
    """
    Select stars whose position weights are all finite.

    :param weights: Position weights, one row per star.
    :return: Boolean array, True for rows without any NaN or infinite entry.
    """
    finite = np.isfinite(weights)
    return finite.all(axis=1)
def cut_nearby_bright_star(cat):
    """
    Select stars that are not flagged as being close to a bright star.

    The decision is based on the ``FLAGS_STAMP`` column: a star is selected when
    ``check_flag_bit`` does not report
    ``add_generic_stamp_flags.FLAGBIT_NEARBY_BRIGHT_STAR`` for it. If the column
    is missing, a warning is logged and every star is selected.

    :param cat: Catalog, optionally containing the ``FLAGS_STAMP`` column.
    :return: Boolean array, True for stars passing the cut.
    """
    if "FLAGS_STAMP" not in cat.dtype.names:
        LOGGER.warning(
            "column FLAGS_STAMP not found, not cutting on FLAGBIT_NEARBY_BRIGHT_STAR"
        )
        return np.ones(len(cat), dtype=bool)

    has_bright_neighbour = check_flag_bit(
        cat["FLAGS_STAMP"], add_generic_stamp_flags.FLAGBIT_NEARBY_BRIGHT_STAR
    )
    return ~has_bright_neighbour
def cut_sysmaps_delta_weight(cat):
    """
    Select stars that do not carry the sysmap delta-weight stamp flag.

    The decision is based on the ``FLAGS_STAMP`` column: a star is selected when
    ``check_flag_bit`` does not report
    ``add_generic_stamp_flags.FLAGBIT_SYSMAP_DELTA_WEIGHT`` for it. If the column
    is missing, a warning is logged and every star is selected.

    :param cat: Catalog, optionally containing the ``FLAGS_STAMP`` column.
    :return: Boolean array, True for stars passing the cut.
    """
    if "FLAGS_STAMP" not in cat.dtype.names:
        LOGGER.warning(
            "column FLAGS_STAMP not found, not cutting on FLAGBIT_SYSMAP_DELTA_WEIGHT"
        )
        return np.ones(len(cat), dtype=bool)

    has_delta_weight_flag = check_flag_bit(
        cat["FLAGS_STAMP"], add_generic_stamp_flags.FLAGBIT_SYSMAP_DELTA_WEIGHT
    )
    return ~has_delta_weight_flag
def cut_sysmaps_survey_mask(cat):
    """
    Select stars that do not carry the survey-mask stamp flag.

    The decision is based on the ``FLAGS_STAMP`` column: a star is selected when
    ``check_flag_bit`` does not report
    ``add_generic_stamp_flags.FLAGBIT_SURVEY_MASK`` for it. If the column is
    missing, a warning is logged and every star is selected.

    :param cat: Catalog, optionally containing the ``FLAGS_STAMP`` column.
    :return: Boolean array, True for stars passing the cut.
    """
    if "FLAGS_STAMP" not in cat.dtype.names:
        LOGGER.warning(
            "column FLAGS_STAMP not found, not cutting on FLAGBIT_SURVEY_MASK"
        )
        return np.ones(len(cat), dtype=bool)

    is_masked = check_flag_bit(
        cat["FLAGS_STAMP"], add_generic_stamp_flags.FLAGBIT_SURVEY_MASK
    )
    return ~is_masked
def cut_boundaries(cat, image, pointings_maps, star_stamp_shape):
    """
    Extract star cutouts and report the boundary/centering checks per star.

    Thin wrapper around ``get_cutouts`` that is called with the catalog's peak
    positions and returns its four results in a different order (cube last).

    :param cat: Catalog with SExtractor measurements (needs ``XPEAK_IMAGE`` and
        ``YPEAK_IMAGE``).
    :param image: Image data for cutout extraction.
    :param pointings_maps: Pointing maps for exposure coverage.
    :param star_stamp_shape: Shape of the cutout stamps for stars.
    :return: Tuple ``(select_image_boundary, select_coadd_boundary,
        select_max_flux_centered, cube)``.
    """
    # Peak positions are shifted by one pixel before the lookup.
    # NOTE(review): presumably converts SExtractor's 1-based pixel coordinates
    # to 0-based array indices -- confirm against get_cutouts.
    x_peak = cat["XPEAK_IMAGE"] - 1
    y_peak = cat["YPEAK_IMAGE"] - 1

    cutout_results = get_cutouts(
        x_peak, y_peak, image, pointings_maps, star_stamp_shape
    )
    cube, in_image, in_coadd, flux_centered = cutout_results
    return in_image, in_coadd, flux_centered, cube
def set_flag_bit_and_log(select, flags, flagbit, msg):
    """
    Set a flag bit for the selected stars and log the running star count.

    :param select: Boolean array marking the stars that get ``flagbit`` set.
    :param flags: Flags array, forwarded to ``set_flag_bit``.
    :param flagbit: Flag bit identifier, forwarded as ``field``.
    :param msg: Description of the stars that are NOT selected, used in the
        log line.
    """
    set_flag_bit(flags=flags, select=select, field=flagbit)

    # The log counts the complement of `select` plus the stars without any flag.
    message = (
        f"star cuts: found {np.count_nonzero(~select)}/{len(flags)} {msg}, "
        f"current n_star={np.count_nonzero(flags == 0)}"
    )
    LOGGER.info(message)
def get_stars_for_cnn(
    cat,
    image,
    star_stamp_shape,
    pointings_maps,
    position_weights,
    star_mag_range=(18, 22),
    min_n_exposures=1,
    sextractor_flags=None,
    flag_coadd_boundaries=False,
    moments_lim=(-99, 99),
):
    """
    Run the pre-CNN quality cuts on a star catalog and extract the star cutouts.

    Every cut sets its own flag bit for the stars that fail it, and each step is
    logged together with the number of stars that are still unflagged.

    :param cat: Catalog with SExtractor measurements.
    :param image: Image data for cutout extraction.
    :param star_stamp_shape: Shape of the cutout stamps for stars.
    :param pointings_maps: Pointing maps for exposure coverage.
    :param position_weights: Position weights for each star.
    :param star_mag_range: Tuple defining the magnitude range for star selection.
    :param min_n_exposures: Minimum number of exposures required for a star to be
        selected.
    :param sextractor_flags: Accepted SExtractor FLAGS values; defaults to
        ``[0, 16]`` when None.
    :param flag_coadd_boundaries: If True, stars failing the coadd-boundary
        check get ``FLAGBIT_COADD_BOUNDARY`` set; otherwise that check is ignored.
    :param moments_lim: Limits handed to ``cut_moments``.
    :return: flags, cube: uint32 flags array with one entry per catalog row and
        the cutout cube returned by ``cut_boundaries``.
    """
    flags = np.zeros(len(cat), dtype=np.uint32)

    def flag_failed(passed, flagbit, msg):
        # Set `flagbit` for every star that did NOT pass the cut and log it.
        set_flag_bit_and_log(~passed, flags, flagbit, msg=msg)

    # Magnitude
    flag_failed(
        cut_magnitude(cat, star_mag_range[0], star_mag_range[1]),
        FLAGBIT_MAG,
        "stars in accepted mag range",
    )

    # Number of exposures
    flag_failed(
        cut_n_exp(position_weights, min_n_exposures),
        FLAGBIT_N_EXP,
        "stars with accepted number of exposures",
    )

    # SExtractor flags
    if sextractor_flags is None:
        sextractor_flags = [0, 16]
    flag_failed(
        cut_sextractor_flag(cat, sextractor_flags),
        FLAGBIT_SOURCEEXTRACTOR_FLAGS,
        "stars with accepted SourceExtractor flags",
    )

    # NaN weights
    flag_failed(
        cut_position_weights(position_weights),
        FLAGBIT_POSITION_WEIGHTS,
        "stars without any NaN position weights",
    )

    # Cutouts and the boundary/centering checks that come with them
    in_image, in_coadd, flux_centered, cube = cut_boundaries(
        cat, image, pointings_maps, star_stamp_shape
    )
    flag_failed(
        in_image, FLAGBIT_IMAGE_BOUNDARY, "stars far enough from image boundary"
    )
    if flag_coadd_boundaries:
        flag_failed(
            in_coadd, FLAGBIT_COADD_BOUNDARY, "stars far enough from coadd boundaries"
        )
    flag_failed(
        flux_centered, FLAGBIT_MAX_FLUX_CENTERED, "stars centered on maximum flux"
    )

    # Moments
    flag_failed(
        cut_moments(cat, moments_lim), FLAGBIT_MOM_ERROR, "stars with good moments"
    )

    # Cuts based on the FLAGS_STAMP column
    flag_failed(
        cut_sysmaps_delta_weight(cat),
        FLAGBIT_SYSMAP_DELTA_WEIGHT,
        "stars with matching sysmap delta weight",
    )
    flag_failed(
        cut_nearby_bright_star(cat),
        FLAGBIT_NEARBY_BRIGHT_STAR,
        "stars far enough from bright stars",
    )
    flag_failed(
        cut_sysmaps_survey_mask(cat),
        FLAGBIT_SURVEY_MASK,
        "stars with OK survey mask",
    )

    return flags, cube
def beta_cut(beta, beta_lim=(1.5, 10)):
    """
    Select stars whose beta parameter lies strictly inside ``beta_lim``.

    :param beta: Beta parameter for each star.
    :param beta_lim: (lower, upper) limits, both exclusive.
    :return: Boolean array, True where beta is within the limits.
    """
    return inlims(beta, beta_lim)
def fwhm_cut(fwhm, fwhm_lim=(1, 10)):
    """
    Select stars whose FWHM lies strictly inside ``fwhm_lim``.

    :param fwhm: FWHM parameter for each star.
    :param fwhm_lim: (lower, upper) limits, both exclusive.
    :return: Boolean array, True where the FWHM is within the limits.
    """
    return inlims(fwhm, fwhm_lim)
def ellipticity_cut(e1, e2, ellipticity_lim=(-0.3, 0.3)):
    """
    Select stars whose two ellipticity components both lie inside the limits.

    :param e1: First ellipticity component for each star.
    :param e2: Second ellipticity component for each star.
    :param ellipticity_lim: (lower, upper) limits applied to each component,
        both exclusive.
    :return: Boolean array, True where e1 and e2 are both within the limits.
    """
    e1_ok = inlims(e1, ellipticity_lim)
    e2_ok = inlims(e2, ellipticity_lim)
    return e1_ok & e2_ok
def flexion_cut(f1, f2, g1, g2, flexion_lim=(-0.3, 0.3)):
    """
    Select stars whose four flexion components all lie inside the limits.

    :param f1: First flexion component for each star.
    :param f2: Second flexion component for each star.
    :param g1: Third flexion component for each star.
    :param g2: Fourth flexion component for each star.
    :param flexion_lim: (lower, upper) limits applied to each component,
        both exclusive.
    :return: Boolean array, True where all four components are within the limits.
    """
    select = inlims(f1, flexion_lim)
    for component in (f2, g1, g2):
        select = select & inlims(component, flexion_lim)
    return select
def kurtosis_cut(kurtosis, kurtosis_lim=(-1, 1)):
    """
    Select stars whose kurtosis lies strictly inside ``kurtosis_lim``.

    :param kurtosis: Kurtosis parameter for each star.
    :param kurtosis_lim: (lower, upper) limits, both exclusive.
    :return: Boolean array, True where the kurtosis is within the limits.
    """
    return inlims(kurtosis, kurtosis_lim)
def cut_moments(cat, moments_lim):
    """
    Select stars whose moment columns all lie strictly inside ``moments_lim``.

    The same limits are applied to ``se_mom_fwhm``, ``se_mom_e1`` and
    ``se_mom_e2``.

    :param cat: Catalog containing the three moment columns.
    :param moments_lim: (lower, upper) limits, both exclusive.
    :return: Boolean array, True where all three columns are within the limits.
    """
    select = inlims(cat["se_mom_fwhm"], moments_lim)
    for column in ("se_mom_e1", "se_mom_e2"):
        select = select & inlims(cat[column], moments_lim)
    return select
def select_cnn_predictions(
    flags,
    pred,
    beta_lim=(1.5, 10),
    fwhm_lim=(1, 10),
    ellipticity_lim=(-0.3, 0.3),
    flexion_lim=(-0.3, 0.3),
    kurtosis_lim=(-1, 1),
):
    """
    Flag stars whose CNN-predicted PSF parameters fall outside the limits.

    FWHM, ellipticity and flexion are always checked. The beta and kurtosis
    cuts are only applied when the corresponding columns exist in ``pred``.

    :param flags: Flags array indicating the quality of each star; passed to
        ``set_flag_bit_and_log`` for every cut.
    :param pred: Predictions from the CNN model containing PSF parameters.
    :param beta_lim: Limits for ``psf_beta_1_cnn`` (and ``psf_beta_2_cnn``).
    :param fwhm_lim: Limits for ``psf_fwhm_cnn``.
    :param ellipticity_lim: Limits for ``psf_e1_cnn`` and ``psf_e2_cnn``.
    :param flexion_lim: Limits for ``psf_f1_cnn``, ``psf_f2_cnn``,
        ``psf_g1_cnn`` and ``psf_g2_cnn``.
    :param kurtosis_lim: Limits for ``psf_kurtosis_cnn``.
    :return: The ``flags`` array after applying the cuts.
    """
    columns = pred.dtype.names

    # Beta (optional; the second beta column is only used together with the first)
    if "psf_beta_1_cnn" in columns:
        beta_ok = beta_cut(pred["psf_beta_1_cnn"], beta_lim=beta_lim)
        if "psf_beta_2_cnn" in columns:
            beta_ok &= beta_cut(pred["psf_beta_2_cnn"], beta_lim=beta_lim)
        set_flag_bit_and_log(~beta_ok, flags, FLAGBIT_BETA, "stars with accepted beta")

    # FWHM
    fwhm_ok = fwhm_cut(pred["psf_fwhm_cnn"], fwhm_lim=fwhm_lim)
    set_flag_bit_and_log(~fwhm_ok, flags, FLAGBIT_FWHM, "stars with accepted FWHM")

    # Ellipticity
    ellip_ok = ellipticity_cut(
        pred["psf_e1_cnn"], pred["psf_e2_cnn"], ellipticity_lim=ellipticity_lim
    )
    set_flag_bit_and_log(
        ~ellip_ok, flags, FLAGBIT_ELLIPTICITY, "stars with accepted ellipticity"
    )

    # Flexion
    flexion_ok = flexion_cut(
        pred["psf_f1_cnn"],
        pred["psf_f2_cnn"],
        pred["psf_g1_cnn"],
        pred["psf_g2_cnn"],
        flexion_lim=flexion_lim,
    )
    set_flag_bit_and_log(
        ~flexion_ok, flags, FLAGBIT_FLEXION, "stars with accepted flexion"
    )

    # Kurtosis (optional)
    if "psf_kurtosis_cnn" in columns:
        kurtosis_ok = kurtosis_cut(
            pred["psf_kurtosis_cnn"], kurtosis_lim=kurtosis_lim
        )
        set_flag_bit_and_log(
            ~kurtosis_ok, flags, FLAGBIT_KURTOSIS, "stars with accepted kurtosis"
        )

    return flags
def remove_outliers(x, y, n_sigma, list_cols_use="all"):
    """
    Flag samples whose residual ``y - x`` survives sigma clipping in every column.

    Each selected column of the residuals is sigma-clipped independently with
    ``scipy.stats.sigmaclip``; a sample is kept only if it lies within the
    final clipping bounds of all selected columns.

    :param x: Predicted values (e.g., CNN predictions), shape (n_samples, n_dim).
    :param y: Actual values (e.g., SExtractor measurements), same shape as x.
    :param n_sigma: Number of standard deviations for clipping (used for both
        the lower and the upper bound).
    :param list_cols_use: Iterable of column indices (list, tuple, range or
        numpy array) to use for outlier detection, or the string 'all' to use
        all columns.
    :return: Boolean array indicating which samples are not outliers.
    """
    n_samples, n_dim = x.shape

    # Only a string can be the "all" sentinel. Comparing an index array with
    # `== "all"` is element-wise and cannot be used as an `if` condition.
    if isinstance(list_cols_use, str) and list_cols_use == "all":
        list_cols_use = range(n_dim)

    select_clip = np.ones(n_samples, dtype=bool)
    res = y - x

    for i_dim in list_cols_use:
        column = res[:, i_dim]
        _, lo, up = scipy.stats.sigmaclip(column, low=n_sigma, high=n_sigma)
        # sigmaclip keeps values on the closed interval [lo, up]; use the same
        # inclusive bounds so that a zero-variance column (lo == up) does not
        # reject every sample.
        select_clip &= (column >= lo) & (column <= up)

    return select_clip
def get_match_vector(cat, cat_gaia, max_dist_arcsec):
    """
    Match a SExtractor catalog to a GAIA catalog by nearest neighbour on the sky.

    For every catalog object the closest GAIA entry is looked up with a
    haversine ball tree. A match is accepted when it is closer than
    ``max_dist_arcsec`` and the GAIA entry is not claimed by more than one
    catalog object; all objects sharing a GAIA entry are treated as unmatched.

    :param cat: SExtractor catalog data (needs ``ALPHAWIN_J2000`` and
        ``DELTAWIN_J2000``, in degrees).
    :param cat_gaia: GAIA catalog data (needs ``ra``, ``dec`` in degrees and ``id``).
    :param max_dist_arcsec: Maximum distance in arcseconds for matching.
    :return: Tuple ``(select_in_gaia, gaia_ra_match, gaia_dec_match,
        gaia_id_match)``: the boolean match selection and the matched GAIA
        RA, Dec and id per catalog object. Unmatched entries hold -200.
    """
    # Coordinates in radians, ordered (dec, ra): the haversine metric expects
    # (latitude, longitude).
    gaia_ang = np.stack(
        [cat_gaia["dec"] * np.pi / 180.0, cat_gaia["ra"] * np.pi / 180.0], axis=1
    )
    cat_ang = np.stack(
        [
            cat["DELTAWIN_J2000"] * np.pi / 180.0,
            cat["ALPHAWIN_J2000"] * np.pi / 180.0,
        ],
        axis=1,
    )

    # Nearest GAIA neighbour of every catalog object
    tree = BallTree(gaia_ang, metric="haversine")
    dist, ind = tree.query(cat_ang, k=1)
    nearest_dist_arcsec = dist[:, 0] / np.pi * 180.0 * 3600.0
    nearest_ind = ind[:, 0]
    select_in_gaia = nearest_dist_arcsec < max_dist_arcsec

    # GAIA indices of the candidates within the matching radius
    ind_match = nearest_ind[select_in_gaia]

    # Drop candidates whose GAIA entry is shared with another catalog object
    _, inverse, counts = np.unique(
        ind_match, return_inverse=True, return_counts=True
    )
    is_unique = counts[inverse] == 1
    select_in_gaia[select_in_gaia] = is_unique
    ind_match = ind_match[is_unique]

    # Fill the outputs; -200 marks objects without an accepted match
    gaia_ra_match = np.full_like(cat["ALPHAWIN_J2000"], -200)
    gaia_dec_match = np.full_like(cat["DELTAWIN_J2000"], -200)
    gaia_id_match = np.full(len(cat), -200, dtype=int)
    gaia_ra_match[select_in_gaia] = cat_gaia["ra"][ind_match]
    gaia_dec_match[select_in_gaia] = cat_gaia["dec"][ind_match]
    gaia_id_match[select_in_gaia] = cat_gaia["id"][ind_match]

    return select_in_gaia, gaia_ra_match, gaia_dec_match, gaia_id_match