Coverage for src/ufig/psf_estimation/cnn_predictions.py: 98%

53 statements  

« prev     ^ index     » next       coverage.py v7.10.2, created at 2025-08-07 15:17 +0000

1# Copyright (C) 2025 ETH Zurich 

2# Institute for Particle Physics and Astrophysics 

3# Author: Silvan Fischbacher 

4# created: Mon Jul 21 2025 

5 

6 

7import numpy as np 

8from cosmic_toolbox import arraytools as at 

9from cosmic_toolbox import logger 

10 

11from . import cnn_util, psf_utils, star_sample_selection_cnn 

12 

# Module-level logger used by CNNPredictorPSF for progress and diagnostics.
LOGGER = logger.get_logger(__file__)
# NOTE(review): ERR_VAL is not referenced anywhere in this module; presumably an
# error/sentinel value consumed by other modules of the package — confirm
# before removing or changing it.
ERR_VAL = 999.0

15 

16 

class CNNPredictorPSF:
    """
    Handles CNN-based PSF parameter prediction.

    This class manages:
    - CNN model loading and configuration
    - Batch prediction on star stamp images
    - Post-processing and quality assessment of predictions
    - Application of systematic corrections (PSF measurement adjustments and
      Brighter-Fatter correction, when configured)
    """

    def __init__(self, config):
        """
        Initialize the PSF predictor with configuration parameters.

        :param config: Configuration dictionary with parameters for PSF prediction.
            Optional keys read by this class: ``psf_measurement_adjustment``,
            ``psfmodel_corr_brighter_fatter``, ``beta_lim``, ``fwhm_lim``,
            ``ellipticity_lim``, ``flexion_lim`` and ``kurtosis_lim``.
        """
        self.config = config
        # Populated by _load_cnn_model(); stays None until
        # predict_psf_parameters() has been called.
        self.cnn_pred = None

    def predict_psf_parameters(self, cat_gaia, cube_gaia, filepath_cnn):
        """
        Predict PSF parameters for a catalog of stars using a CNN model.

        :param cat_gaia: Input catalog of stars with GAIA matching.
        :param cube_gaia: Cube of star stamp images, one stamp per catalog row.
        :param filepath_cnn: Path to the CNN model file.
        :return: Dictionary with keys

            - ``cat_cnn``: input catalog merged with the ``<param>_cnn`` columns
              and with corrections applied,
            - ``flags_cnn``: integer quality flags, 0 means a valid prediction,
            - ``cube_cnn``: the stamp cube that was fed to the CNN,
            - ``col_names_cnn``: the model's parameter names *without* the
              ``_cnn`` suffix used for the catalog columns.
        """
        LOGGER.info("Starting CNN PSF parameter prediction")

        # Load CNN model
        self._load_cnn_model(filepath_cnn)

        # All input stars are passed to the CNN; quality cuts are applied
        # afterwards by setting flags rather than by removing rows.
        cat_cnn = cat_gaia
        cube_cnn = cube_gaia
        flags_cnn = np.zeros(len(cat_cnn), dtype=np.int32)

        # Run CNN predictions
        cnn_predictions = self._run_cnn_inference(cube_cnn)

        # Process and merge predictions
        cat_cnn = self._process_predictions(cat_cnn, cnn_predictions)

        # Apply corrections and additional quality cuts
        cat_cnn, flags_cnn = self._apply_corrections_and_cuts(cat_cnn, flags_cnn)

        LOGGER.info(
            f"CNN prediction complete. Valid predictions: {np.sum(flags_cnn == 0)}"
        )

        return {
            "cat_cnn": cat_cnn,
            "flags_cnn": flags_cnn,
            "cube_cnn": cube_cnn,
            "col_names_cnn": self.cnn_pred.config["param_names"],
        }

    def _load_cnn_model(self, filepath_cnn):
        """
        Load the CNN model for PSF parameter prediction into ``self.cnn_pred``.

        :param filepath_cnn: Path to the CNN model file.
        """
        LOGGER.debug(f"Loading CNN model: {filepath_cnn}")
        self.cnn_pred = cnn_util.CNNPredictor(filepath_cnn)
        LOGGER.info(
            f"CNN model loaded. Parameters: {self.cnn_pred.config['param_names']}"
        )

    def _run_cnn_inference(self, cube_cnn):
        """
        Run CNN inference on the star stamp images.

        :param cube_cnn: Cube of star stamp images.
        :return: Array of predicted PSF parameters (must expose ``.shape``).
        """
        LOGGER.info(f"Running CNN inference on {len(cube_cnn)} stars")

        # Run CNN prediction in batches of 500 stamps
        cnn_predictions = self.cnn_pred(cube_cnn, batchsize=500)

        LOGGER.debug(f"CNN inference complete. Output shape: {cnn_predictions.shape}")

        return cnn_predictions

    def _process_predictions(self, cat_cnn, cnn_predictions):
        """
        Process and merge CNN predictions with catalog.

        Each model parameter ``<name>`` becomes a catalog column ``<name>_cnn``.
        Pre-existing columns with those names are dropped first.

        :param cat_cnn: Catalog of stars with GAIA matching.
        :param cnn_predictions: Array of predicted PSF parameters; assumed to be
            shaped (n_stars, n_params) with columns in ``param_names`` order
            (TODO confirm against cnn_util.CNNPredictor).
        :return: Updated catalog with CNN predictions.
        """
        # Explicit import: ``numpy.lib.recfunctions`` is not guaranteed to be
        # loaded by a bare ``import numpy``, so ``np.lib.recfunctions`` only
        # resolved if some other module happened to import it first.
        from numpy.lib import recfunctions as rfn

        LOGGER.debug("Processing CNN predictions")

        # Get parameter names and create CNN column names
        col_names = self.cnn_pred.config["param_names"]
        col_names_cnn = [name + "_cnn" for name in col_names]

        # Transpose so each parameter becomes one field array. ``np.rec`` is the
        # public records namespace; ``np.core`` is deprecated since NumPy 2.0.
        # Names are passed as a list, which avoids splitting on commas.
        cnn_predictions_structured = np.rec.fromarrays(
            cnn_predictions.T, names=col_names_cnn
        )

        # Remove any existing CNN columns to avoid conflicts
        cat_cnn = at.delete_columns(cat_cnn, cnn_predictions_structured.dtype.names)

        # Merge predictions with catalog (row-wise, fields flattened)
        cat_cnn = rfn.merge_arrays((cat_cnn, cnn_predictions_structured), flatten=True)

        # Apply post-processing; called for its effect on cat_cnn (the return
        # value is not used), presumably in-place unit conversions etc.
        psf_utils.postprocess_catalog(cat_cnn)

        return cat_cnn

    def _apply_corrections_and_cuts(self, cat_cnn, flags_cnn):
        """
        Apply corrections and quality cuts to the CNN predictions.

        Corrections are only applied when the corresponding config key is set
        (not None). The quality cuts use the ``*_lim`` config keys, falling back
        to the defaults shown below.

        :param cat_cnn: Catalog of stars with CNN predictions.
        :param flags_cnn: Flags for the CNN predictions (0 = valid).
        :return: Updated catalog and flags after corrections and cuts.
        """
        LOGGER.debug("Applying corrections and quality cuts")

        # Apply PSF measurement adjustments if specified; the helper's return
        # value is not used, so it is expected to modify cat_cnn in place.
        psf_measurement_adjustment = self.config.get("psf_measurement_adjustment")
        if psf_measurement_adjustment is not None:
            LOGGER.info("Applying PSF measurement adjustments")
            psf_utils.adjust_psf_measurements(cat_cnn, psf_measurement_adjustment)

        # Apply Brighter-Fatter correction if specified (also in place).
        psfmodel_corr_brighter_fatter = self.config.get("psfmodel_corr_brighter_fatter")
        if psfmodel_corr_brighter_fatter is not None:
            LOGGER.info("Applying Brighter-Fatter correction")
            psf_utils.apply_brighter_fatter_correction(
                cat_cnn, psfmodel_corr_brighter_fatter
            )

        # Apply CNN prediction quality cuts
        flags_cnn = star_sample_selection_cnn.select_cnn_predictions(
            flags_cnn,
            cat_cnn,
            beta_lim=self.config.get("beta_lim", (1.5, 10)),
            fwhm_lim=self.config.get("fwhm_lim", (1, 10)),
            ellipticity_lim=self.config.get("ellipticity_lim", (-0.3, 0.3)),
            flexion_lim=self.config.get("flexion_lim", (-0.3, 0.3)),
            kurtosis_lim=self.config.get("kurtosis_lim", (-1, 1)),
        )

        n_good = np.sum(flags_cnn == 0)
        LOGGER.info(f"After corrections and cuts: {n_good} valid predictions")

        return cat_cnn, flags_cnn