Coverage for src/ufig/psf_estimation/core.py: 96%
51 statements
« prev ^ index » next coverage.py v7.10.2, created at 2025-08-07 15:17 +0000
« prev ^ index » next coverage.py v7.10.2, created at 2025-08-07 15:17 +0000
1# Copyright (C) 2024 ETH Zurich, Institute for Astronomy
3"""
4Core PSF estimation pipeline functionality.
6This module contains the main PSF estimation pipeline that orchestrates
7the complete process from image loading to model creation.
8"""
10from cosmic_toolbox import logger
12from . import psf_utils
# Numeric sentinel; it is not referenced in the visible part of this module.
# NOTE(review): presumably used by importers as an error/placeholder value --
# confirm before changing it.
ERR_VAL: float = 999.0
# Module-level logger created through cosmic_toolbox from this file's path.
LOGGER = logger.get_logger(__file__)
class PSFEstimationPipeline:
    """
    Main PSF estimation pipeline orchestrator.

    This class coordinates the complete PSF estimation process:

    1. Data loading and preparation
    2. Star selection and quality cuts
    3. CNN-based PSF parameter prediction
    4. Polynomial interpolation model fitting
    5. Model validation and output generation
    """

    def __init__(self, **kwargs):
        """
        Initialize the PSF estimation pipeline.

        :param kwargs: Configuration parameters for the pipeline. Any parameter
            that is not provided is filled with its default value
            (see :meth:`_setup_config`).
        :raises ValueError: If ``psfmodel_corr_brighter_fatter`` or
            ``psfmodel_ridge_alpha`` is given but is neither a dictionary
            nor None.
        """
        self._setup_config(**kwargs)
        # Not assigned anywhere else in this class.
        # NOTE(review): presumably kept for external callers -- confirm.
        self.position_weights = None

    def create_psf_model(
        self,
        filepath_image,
        filepath_sexcat,
        filepath_sysmaps,
        filepath_gaia,
        filepath_cnn,
        filepath_out_model,
        filepath_out_cat=None,
    ):
        """
        Estimates the PSF of the image and saves all necessary files for a later image
        simulation.

        :param filepath_image: Path to the input image file to estimate the PSF from.
        :param filepath_sexcat: Path to the SExtractor catalog file of the image.
        :param filepath_sysmaps: Path to the systematics maps file.
        :param filepath_gaia: Path to the Gaia catalog file.
        :param filepath_cnn: Path to the pretrained CNN model.
        :param filepath_out_model: Path to save the output PSF model file.
        :param filepath_out_cat: Path to save the enriched SExtractor catalog;
            if None, the catalog at filepath_sexcat will be enriched.
        :raises Exception: Any error raised by one of the pipeline steps is
            logged and re-raised unchanged. Before re-raising, empty output
            files are written via :meth:`_handle_failure`.
        """
        try:
            # Import dependent modules inside the method to avoid circular imports
            from .cnn_predictions import CNNPredictorPSF
            from .data_preparation import PSFDataPreparator
            from .polynomial_fitting import PolynomialPSFModel
            from .save_model import PSFSave

            # Step 1: Prepare input data
            LOGGER.info("Starting PSF model creation pipeline")
            data_prep = PSFDataPreparator(self.config)
            processed_data = data_prep.prepare_data(
                filepath_image=filepath_image,
                filepath_sexcat=filepath_sexcat,
                filepath_sysmaps=filepath_sysmaps,
                filepath_gaia=filepath_gaia,
            )

            # Step 2: Run CNN predictions
            # Keep only stars whose Gaia flags are zero. Presumably these are
            # the stars that passed the quality cuts during data preparation.
            select_cnn = processed_data["flags_gaia"] == 0
            cat_cnn_input = processed_data["cat_gaia"][select_cnn]
            cube_cnn_input = processed_data["cube_gaia"][select_cnn]
            # Every component constructed below receives self.config, so the
            # image shape is published there.
            self.config["image_shape"] = processed_data["image_shape"]

            LOGGER.info(
                f"Running CNN predictions on {len(cat_cnn_input)} selected stars"
            )
            predictor = CNNPredictorPSF(self.config)
            cnn_results = predictor.predict_psf_parameters(
                cat_gaia=cat_cnn_input,
                cube_gaia=cube_cnn_input,
                filepath_cnn=filepath_cnn,
            )

            # Step 3: Fit polynomial interpolation model
            # Apply the same selection to the weights so they stay aligned
            # with the CNN input catalog.
            position_weights_cnn = processed_data["position_weights"][select_cnn]

            poly_model = PolynomialPSFModel(self.config)
            model_results = poly_model.fit_model(
                cat=cnn_results["cat_cnn"],
                position_weights=position_weights_cnn,
            )

            # Step 4: Generate predictions and save model
            io_handler = PSFSave(self.config)
            io_handler.save_psf_model(
                filepath_out=filepath_out_model,
                processed_data=processed_data,
                cnn_results=cnn_results,
                model_results=model_results,
                filepath_sysmaps=filepath_sysmaps,
                filepath_cat_out=filepath_out_cat,
            )

            LOGGER.info(f"PSF model successfully created: {filepath_out_model}")

        except Exception as e:
            LOGGER.error(f"PSF model creation failed: {e}")
            # Write empty placeholder outputs, then propagate the original error.
            self._handle_failure(filepath_out_model, filepath_out_cat)
            raise

    @staticmethod
    def _merge_dict_option(kwargs, key, defaults):
        """
        Overlay a user-supplied dictionary option on top of its default values.

        :param kwargs: Keyword arguments passed to the pipeline.
        :param key: Name of the dictionary-valued option to look up in kwargs.
        :param defaults: Default values. This dict is updated in place with
            the user values and then returned.
        :return: The merged dictionary. User values take precedence.
        :raises ValueError: If the user value is neither a dictionary nor None.
        """
        user_values = kwargs.get(key)
        if user_values is None:
            # Option absent or explicitly None: keep the defaults unchanged.
            return defaults
        if not isinstance(user_values, dict):
            raise ValueError(f"{key} must be a dictionary")
        # The user's dict is only read, never modified.
        defaults.update(user_values)
        return defaults

    def _setup_config(self, **kwargs):
        """
        Creates a configuration dictionary for the pipeline from the kwargs provided.
        Additionally, sets default values if arguments are not provided.

        The result is stored in ``self.config``. The dict-valued options
        ``psfmodel_corr_brighter_fatter`` and ``psfmodel_ridge_alpha`` are
        merged key by key into their defaults, not replaced wholesale.

        :param kwargs: Configuration parameters for the pipeline.
        :raises ValueError: If ``psfmodel_corr_brighter_fatter`` or
            ``psfmodel_ridge_alpha`` is given but is neither a dictionary
            nor None.
        """
        # Fresh default dicts on every call, so instances never share them.
        default_ridge_alpha = dict(
            psf_flux_ratio_cnn=2.10634454232412,
            psf_fwhm_cnn=0.12252798573828638,
            psf_e1_cnn=0.5080218046913018,
            psf_e2_cnn=0.5080218046913018,
            psf_f1_cnn=2.10634454232412,
            psf_f2_cnn=2.10634454232412,
            psf_g1_cnn=1.311133937421563,
            psf_g2_cnn=1.311133937421563,
            se_mom_fwhm=0.12,
            se_mom_win=0.12,
            se_mom_e1=0.51,
            se_mom_e2=0.51,
            astrometry_diff_x=0.5,
            astrometry_diff_y=0.5,
        )

        default_corr_bf = {
            "c1r": 0.0,
            "c1e1": 0.0,
            "c1e2": 0.0,
            "mag_ref": 22,
            "apply_to_galaxies": False,
        }

        # Validation order matters for which error is reported first.
        corr_bf = self._merge_dict_option(
            kwargs, "psfmodel_corr_brighter_fatter", default_corr_bf
        )
        ridge_alpha = self._merge_dict_option(
            kwargs, "psfmodel_ridge_alpha", default_ridge_alpha
        )

        self.config = {
            "astrometry_errors": kwargs.get("astrometry_errors", False),
            "max_dist_gaia_arcsec": kwargs.get("max_dist_gaia_arcsec", 0.1),
            "cnn_variance_type": kwargs.get("cnn_variance_type", "constant"),
            "filepath_cnn_info": kwargs.get("filepath_cnn_info"),
            "poly_order": kwargs.get("poly_order", 4),
            "polynomial_type": kwargs.get("polynomial_type", "chebyshev"),
            "star_mag_range": kwargs.get("star_mag_range", (18, 22)),
            "min_n_exposures": kwargs.get("min_n_exposures", 0),
            "n_sigma_clip": kwargs.get("n_sigma_clip", 3),
            "fraction_validation_stars": kwargs.get("fraction_validation_stars", 0.15),
            "save_star_cube": kwargs.get("save_star_cube", False),
            "psfmodel_raise_undetermined_error": kwargs.get(
                "psfmodel_raise_undetermined_error", False
            ),
            "star_stamp_shape": kwargs.get("star_stamp_shape", (19, 19)),
            "sextractor_flags": kwargs.get("sextractor_flags", [0, 16]),
            "flag_coadd_boundaries": kwargs.get("flag_coadd_boundaries", True),
            "moments_lim": kwargs.get("moments_lim", (-99, 99)),
            "beta_lim": kwargs.get("beta_lim", (1.5, 10)),
            "fwhm_lim": kwargs.get("fwhm_lim", (1, 10)),
            "ellipticity_lim": kwargs.get("ellipticity_lim", (-0.3, 0.3)),
            "flexion_lim": kwargs.get("flexion_lim", (-0.3, 0.3)),
            "kurtosis_lim": kwargs.get("kurtosis_lim", (-1, 1)),
            "n_max_refit": kwargs.get("n_max_refit", 10),
            "psf_measurement_adjustment": kwargs.get("psf_measurement_adjustment"),
            "psfmodel_corr_brighter_fatter": corr_bf,
            "psfmodel_ridge_alpha": ridge_alpha,
            "precision": kwargs.get("precision", float),
        }

    def _handle_failure(self, filepath_out, filepath_cat_out=None):
        """
        Handle pipeline failures by creating empty output files.

        Best effort: an error while writing the empty outputs is logged and
        swallowed. This keeps the caller's original exception as the one
        that propagates.

        :param filepath_out: Path of the PSF model file to write empty.
        :param filepath_cat_out: Optional path of the output catalog.
        """
        try:
            psf_utils.write_empty_output(
                filepath_out,
                filepath_cat_out=filepath_cat_out,
                save_star_cube=self.config.get("save_star_cube", False),
            )
        except Exception as cleanup_error:
            LOGGER.error(f"Failed to write empty output: {cleanup_error}")