Coverage for src/ufig/psf_estimation/core.py: 96%

51 statements  

« prev     ^ index     » next       coverage.py v7.10.2, created at 2025-08-07 15:17 +0000

1# Copyright (C) 2024 ETH Zurich, Institute for Astronomy 

2 

3""" 

4Core PSF estimation pipeline functionality. 

5 

6This module contains the main PSF estimation pipeline that orchestrates 

7the complete process from image loading to model creation. 

8""" 

9 

10from cosmic_toolbox import logger 

11 

12from . import psf_utils 

13 

# Module-level logger from cosmic_toolbox, keyed by this file's path (``__file__``).
LOGGER = logger.get_logger(__file__)
# Numeric placeholder value (999.0). It is not referenced anywhere in this
# module; presumably it is imported by other psf_estimation modules as an
# error/sentinel value — NOTE(review): verify importers before changing it.
ERR_VAL = 999.0

16 

17 

class PSFEstimationPipeline:
    """
    Main PSF estimation pipeline orchestrator.

    This class coordinates the complete PSF estimation process:
    1. Data loading and preparation
    2. Star selection and quality cuts
    3. CNN-based PSF parameter prediction
    4. Polynomial interpolation model fitting
    5. Model validation and output generation
    """

    def __init__(self, **kwargs):
        """
        Initialize the PSF estimation pipeline.

        All configuration is passed as keyword arguments and collected into
        ``self.config`` by :meth:`_setup_config`. Options that are not given
        fall back to the defaults defined there. Keyword arguments that
        ``_setup_config`` does not read are silently ignored.

        :param kwargs: Configuration parameters for the pipeline.
        :raises ValueError: If ``psfmodel_corr_brighter_fatter`` or
            ``psfmodel_ridge_alpha`` is given and is neither ``None`` nor a
            dictionary.
        """
        self._setup_config(**kwargs)
        # Not assigned anywhere else in this class; kept as a public attribute.
        # NOTE(review): confirm whether external callers still read/set it.
        self.position_weights = None

    def create_psf_model(
        self,
        filepath_image,
        filepath_sexcat,
        filepath_sysmaps,
        filepath_gaia,
        filepath_cnn,
        filepath_out_model,
        filepath_out_cat=None,
    ):
        """
        Estimates the PSF of the image and saves all necessary files for a later image
        simulation.

        :param filepath_image: Path to the input image file to estimate the PSF from.
        :param filepath_sexcat: Path to the SExtractor catalog file of the image.
        :param filepath_sysmaps: Path to the systematics maps file.
        :param filepath_gaia: Path to the Gaia catalog file.
        :param filepath_cnn: Path to the pretrained CNN model.
        :param filepath_out_model: Path to save the output PSF model file.
        :param filepath_out_cat: Path to save the enriched SExtractor catalog;
            if None, the catalog at filepath_sexcat will be enriched.
        :raises Exception: Any error raised by a pipeline step is logged,
            empty output files are written (best effort), and the original
            exception is re-raised.
        """
        try:
            # Import dependent modules inside the method to avoid circular imports
            from .cnn_predictions import CNNPredictorPSF
            from .data_preparation import PSFDataPreparator
            from .polynomial_fitting import PolynomialPSFModel
            from .save_model import PSFSave

            # Step 1: Prepare input data
            LOGGER.info("Starting PSF model creation pipeline")
            data_prep = PSFDataPreparator(self.config)
            processed_data = data_prep.prepare_data(
                filepath_image=filepath_image,
                filepath_sexcat=filepath_sexcat,
                filepath_sysmaps=filepath_sysmaps,
                filepath_gaia=filepath_gaia,
            )

            # Step 2: Run CNN predictions
            # Keep only the Gaia stars whose data-preparation flag is 0
            # (i.e. no flag was set); the CNN only sees this subset.
            select_cnn = processed_data["flags_gaia"] == 0
            cat_cnn_input = processed_data["cat_gaia"][select_cnn]
            cube_cnn_input = processed_data["cube_gaia"][select_cnn]
            # Downstream steps read the image shape from the shared config.
            self.config["image_shape"] = processed_data["image_shape"]

            LOGGER.info(
                f"Running CNN predictions on {len(cat_cnn_input)} selected stars"
            )
            predictor = CNNPredictorPSF(self.config)
            cnn_results = predictor.predict_psf_parameters(
                cat_gaia=cat_cnn_input,
                cube_gaia=cube_cnn_input,
                filepath_cnn=filepath_cnn,
            )

            # Step 3: Fit polynomial interpolation model
            # Position weights must be subset with the same mask as the
            # catalog so that rows stay aligned with the CNN input.
            position_weights_cnn = processed_data["position_weights"][select_cnn]

            poly_model = PolynomialPSFModel(self.config)
            model_results = poly_model.fit_model(
                cat=cnn_results["cat_cnn"],
                position_weights=position_weights_cnn,
            )

            # Step 4: Generate predictions and save model
            io_handler = PSFSave(self.config)
            io_handler.save_psf_model(
                filepath_out=filepath_out_model,
                processed_data=processed_data,
                cnn_results=cnn_results,
                model_results=model_results,
                filepath_sysmaps=filepath_sysmaps,
                filepath_cat_out=filepath_out_cat,
            )

            LOGGER.info(f"PSF model successfully created: {filepath_out_model}")

        except Exception as e:
            LOGGER.error(f"PSF model creation failed: {str(e)}")
            self._handle_failure(filepath_out_model, filepath_out_cat)
            raise

    @staticmethod
    def _merge_dict_option(kwargs, key, defaults):
        """
        Return a new dict of ``defaults`` overridden by the user dict ``kwargs[key]``.

        A missing key or an explicit ``None`` means "no overrides".

        :param kwargs: Keyword arguments passed to the pipeline.
        :param key: Name of the dict-valued option to read from ``kwargs``.
        :param defaults: Default values for the option; not modified.
        :return: Merged dictionary (user values win over defaults).
        :raises ValueError: If the user value is neither ``None`` nor a dict.
        """
        user_value = kwargs.get(key)
        if user_value is None:
            user_value = {}
        if not isinstance(user_value, dict):
            raise ValueError(f"{key} must be a dictionary")
        # Build a fresh dict so neither the defaults nor the caller's dict
        # are mutated.
        return {**defaults, **user_value}

    def _setup_config(self, **kwargs):
        """
        Creates a configuration dictionary for the pipeline from the kwargs provided.
        Additionally, sets default values if arguments are not provided.

        The result is stored in ``self.config``. The dict-valued options
        ``psfmodel_corr_brighter_fatter`` and ``psfmodel_ridge_alpha`` are
        merged key-by-key into their defaults rather than replacing them.

        :param kwargs: Configuration parameters for the pipeline.
        :raises ValueError: If a dict-valued option is neither ``None`` nor a
            dictionary.
        """
        default_ridge_alpha = dict(
            psf_flux_ratio_cnn=2.10634454232412,
            psf_fwhm_cnn=0.12252798573828638,
            psf_e1_cnn=0.5080218046913018,
            psf_e2_cnn=0.5080218046913018,
            psf_f1_cnn=2.10634454232412,
            psf_f2_cnn=2.10634454232412,
            psf_g1_cnn=1.311133937421563,
            psf_g2_cnn=1.311133937421563,
            se_mom_fwhm=0.12,
            se_mom_win=0.12,
            se_mom_e1=0.51,
            se_mom_e2=0.51,
            astrometry_diff_x=0.5,
            astrometry_diff_y=0.5,
        )

        default_corr_bf = {
            "c1r": 0.0,
            "c1e1": 0.0,
            "c1e2": 0.0,
            "mag_ref": 22,
            "apply_to_galaxies": False,
        }

        # Validation order matches the original: brighter-fatter first.
        corr_bf = self._merge_dict_option(
            kwargs, "psfmodel_corr_brighter_fatter", default_corr_bf
        )
        ridge_alpha = self._merge_dict_option(
            kwargs, "psfmodel_ridge_alpha", default_ridge_alpha
        )

        self.config = {
            "astrometry_errors": kwargs.get("astrometry_errors", False),
            "max_dist_gaia_arcsec": kwargs.get("max_dist_gaia_arcsec", 0.1),
            "cnn_variance_type": kwargs.get("cnn_variance_type", "constant"),
            "filepath_cnn_info": kwargs.get("filepath_cnn_info"),
            "poly_order": kwargs.get("poly_order", 4),
            "polynomial_type": kwargs.get("polynomial_type", "chebyshev"),
            "star_mag_range": kwargs.get("star_mag_range", (18, 22)),
            "min_n_exposures": kwargs.get("min_n_exposures", 0),
            "n_sigma_clip": kwargs.get("n_sigma_clip", 3),
            "fraction_validation_stars": kwargs.get("fraction_validation_stars", 0.15),
            "save_star_cube": kwargs.get("save_star_cube", False),
            "psfmodel_raise_undetermined_error": kwargs.get(
                "psfmodel_raise_undetermined_error", False
            ),
            "star_stamp_shape": kwargs.get("star_stamp_shape", (19, 19)),
            "sextractor_flags": kwargs.get("sextractor_flags", [0, 16]),
            "flag_coadd_boundaries": kwargs.get("flag_coadd_boundaries", True),
            "moments_lim": kwargs.get("moments_lim", (-99, 99)),
            "beta_lim": kwargs.get("beta_lim", (1.5, 10)),
            "fwhm_lim": kwargs.get("fwhm_lim", (1, 10)),
            "ellipticity_lim": kwargs.get("ellipticity_lim", (-0.3, 0.3)),
            "flexion_lim": kwargs.get("flexion_lim", (-0.3, 0.3)),
            "kurtosis_lim": kwargs.get("kurtosis_lim", (-1, 1)),
            "n_max_refit": kwargs.get("n_max_refit", 10),
            "psf_measurement_adjustment": kwargs.get("psf_measurement_adjustment"),
            "psfmodel_corr_brighter_fatter": corr_bf,
            "psfmodel_ridge_alpha": ridge_alpha,
            "precision": kwargs.get("precision", float),
        }

    def _handle_failure(self, filepath_out, filepath_cat_out=None):
        """
        Handle pipeline failures by creating empty output files.

        This is best effort: any error raised while writing the empty outputs
        is logged and swallowed, so the caller can re-raise the original
        pipeline error instead.

        :param filepath_out: Path of the PSF model output file.
        :param filepath_cat_out: Optional path of the output catalog.
        """
        try:
            psf_utils.write_empty_output(
                filepath_out,
                filepath_cat_out=filepath_cat_out,
                save_star_cube=self.config.get("save_star_cube", False),
            )
        except Exception as cleanup_error:
            LOGGER.error(f"Failed to write empty output: {cleanup_error}")