Coverage for src/ufig/psf_estimation/cnn_predictions.py: 98%
53 statements
« prev ^ index » next coverage.py v7.10.2, created at 2025-08-07 15:17 +0000
« prev ^ index » next coverage.py v7.10.2, created at 2025-08-07 15:17 +0000
1# Copyright (C) 2025 ETH Zurich
2# Institute for Particle Physics and Astrophysics
3# Author: Silvan Fischbacher
4# created: Mon Jul 21 2025
7import numpy as np
8from cosmic_toolbox import arraytools as at
9from cosmic_toolbox import logger
11from . import cnn_util, psf_utils, star_sample_selection_cnn
# Module-level logger created through cosmic_toolbox's logger helper.
LOGGER = logger.get_logger(__file__)
# NOTE(review): not referenced by any code visible in this module; presumably
# a sentinel marking failed/invalid measurements -- confirm against its users.
ERR_VAL = 999.0
class CNNPredictorPSF:
    """
    Handles CNN-based PSF parameter prediction.

    This class manages:
    - CNN model loading and configuration
    - Batch prediction on star stamp images
    - Post-processing and quality assessment of predictions
    - Application of systematic corrections
    """

    def __init__(self, config):
        """
        Initialize the PSF predictor with configuration parameters.

        :param config: Configuration dictionary with parameters for PSF prediction.
            Keys read by this class (all optional, via ``config.get``):
            ``psf_measurement_adjustment``, ``psfmodel_corr_brighter_fatter``,
            ``beta_lim``, ``fwhm_lim``, ``ellipticity_lim``, ``flexion_lim``,
            ``kurtosis_lim``.
        """
        self.config = config
        # Set by _load_cnn_model(); None until a model has been loaded.
        self.cnn_pred = None

    def predict_psf_parameters(self, cat_gaia, cube_gaia, filepath_cnn):
        """
        Predict PSF parameters for a catalog of stars using a CNN model.

        :param cat_gaia: Input catalog of stars with GAIA matching.
        :param cube_gaia: Cube of star stamp images.
        :param filepath_cnn: Path to the CNN model file.
        :return: Dictionary with the keys ``cat_cnn`` (catalog with the merged
            predictions), ``flags_cnn`` (int32 flags, 0 meaning the prediction
            is counted as valid), ``cube_cnn`` (the stamp cube that was passed
            in) and ``col_names_cnn`` (the model's ``param_names``).
        """
        LOGGER.info("Starting CNN PSF parameter prediction")

        # Load CNN model
        self._load_cnn_model(filepath_cnn)

        # No pre-selection happens here: every input star is passed to the
        # CNN, and all flags start at 0. Cuts are applied after prediction.
        cat_cnn = cat_gaia
        cube_cnn = cube_gaia
        flags_cnn = np.zeros(len(cat_cnn), dtype=np.int32)

        # Run CNN predictions
        cnn_predictions = self._run_cnn_inference(cube_cnn)

        # Process and merge predictions
        cat_cnn = self._process_predictions(cat_cnn, cnn_predictions)

        # Apply corrections and additional quality cuts
        cat_cnn, flags_cnn = self._apply_corrections_and_cuts(cat_cnn, flags_cnn)

        LOGGER.info(
            f"CNN prediction complete. Valid predictions: {np.sum(flags_cnn == 0)}"
        )

        return {
            "cat_cnn": cat_cnn,
            "flags_cnn": flags_cnn,
            "cube_cnn": cube_cnn,
            "col_names_cnn": self.cnn_pred.config["param_names"],
        }

    def _load_cnn_model(self, filepath_cnn):
        """
        Load the CNN model for PSF parameter prediction.

        Stores the resulting ``cnn_util.CNNPredictor`` in ``self.cnn_pred``.

        :param filepath_cnn: Path to the CNN model file.
        """
        LOGGER.debug(f"Loading CNN model: {filepath_cnn}")
        self.cnn_pred = cnn_util.CNNPredictor(filepath_cnn)
        LOGGER.info(
            f"CNN model loaded. Parameters: {self.cnn_pred.config['param_names']}"
        )

    def _run_cnn_inference(self, cube_cnn, batchsize=500):
        """
        Run CNN inference on the star stamp images.

        :param cube_cnn: Cube of star stamp images.
        :param batchsize: Batch size forwarded to the CNN predictor
            (default 500, the previously hard-coded value).
        :return: Array of predicted PSF parameters.
        """
        LOGGER.info(f"Running CNN inference on {len(cube_cnn)} stars")

        # Run CNN prediction in batches
        cnn_predictions = self.cnn_pred(cube_cnn, batchsize=batchsize)

        LOGGER.debug(f"CNN inference complete. Output shape: {cnn_predictions.shape}")

        return cnn_predictions

    def _process_predictions(self, cat_cnn, cnn_predictions):
        """
        Process and merge CNN predictions with catalog.

        :param cat_cnn: Catalog of stars with GAIA matching.
        :param cnn_predictions: Array of predicted PSF parameters; its first
            axis must have one entry per catalog row.
        :return: Updated catalog with CNN predictions (columns are named
            ``<param_name>_cnn``).
        :raises ValueError: If the number of predictions differs from the
            number of catalog rows.
        """
        # numpy.lib.recfunctions is not loaded by a plain ``import numpy``,
        # so import it explicitly instead of relying on another module
        # having done so.
        from numpy.lib import recfunctions as rfn

        # merge_arrays pads the shorter input rather than failing, so a
        # length mismatch would silently corrupt the catalog. Fail early.
        n_stars = len(cat_cnn)
        n_predictions = len(cnn_predictions)
        if n_predictions != n_stars:
            raise ValueError(
                f"Number of CNN predictions ({n_predictions}) does not match "
                f"number of catalog entries ({n_stars})"
            )

        LOGGER.debug("Processing CNN predictions")

        # Get parameter names and create CNN column names
        col_names = self.cnn_pred.config["param_names"]
        col_names_cnn = [f"{name}_cnn" for name in col_names]

        # Convert predictions to structured array: transposing gives one row
        # per parameter, i.e. one array per output column. np.rec is the
        # public home of fromarrays (np.core is deprecated in NumPy 2).
        cnn_predictions_structured = np.rec.fromarrays(
            cnn_predictions.T, names=col_names_cnn
        )

        # Remove any existing CNN columns to avoid conflicts
        cat_cnn = at.delete_columns(cat_cnn, cnn_predictions_structured.dtype.names)

        # Merge predictions with catalog
        cat_cnn = rfn.merge_arrays(
            (cat_cnn, cnn_predictions_structured), flatten=True
        )

        # Apply post-processing (unit conversions, etc.); the return value is
        # not used, so this is presumably an in-place update of cat_cnn.
        psf_utils.postprocess_catalog(cat_cnn)

        return cat_cnn

    def _apply_corrections_and_cuts(self, cat_cnn, flags_cnn):
        """
        Apply corrections and quality cuts to the CNN predictions.

        :param cat_cnn: Catalog of stars with CNN predictions.
        :param flags_cnn: Flags for the CNN predictions.
        :return: Updated catalog and flags after corrections and cuts.
        """
        LOGGER.debug("Applying corrections and quality cuts")

        # Apply PSF measurement adjustments if specified
        psf_measurement_adjustment = self.config.get("psf_measurement_adjustment")
        if psf_measurement_adjustment is not None:
            LOGGER.info("Applying PSF measurement adjustments")
            psf_utils.adjust_psf_measurements(cat_cnn, psf_measurement_adjustment)

        # Apply Brighter-Fatter correction if specified
        psfmodel_corr_brighter_fatter = self.config.get("psfmodel_corr_brighter_fatter")
        if psfmodel_corr_brighter_fatter is not None:
            LOGGER.info("Applying Brighter-Fatter correction")
            psf_utils.apply_brighter_fatter_correction(
                cat_cnn, psfmodel_corr_brighter_fatter
            )

        # Apply CNN prediction quality cuts; limits fall back to the defaults
        # below when the corresponding key is missing from the config.
        flags_cnn = star_sample_selection_cnn.select_cnn_predictions(
            flags_cnn,
            cat_cnn,
            beta_lim=self.config.get("beta_lim", (1.5, 10)),
            fwhm_lim=self.config.get("fwhm_lim", (1, 10)),
            ellipticity_lim=self.config.get("ellipticity_lim", (-0.3, 0.3)),
            flexion_lim=self.config.get("flexion_lim", (-0.3, 0.3)),
            kurtosis_lim=self.config.get("kurtosis_lim", (-1, 1)),
        )

        n_good = np.sum(flags_cnn == 0)
        LOGGER.info(f"After corrections and cuts: {n_good} valid predictions")

        return cat_cnn, flags_cnn