Coverage for src/ufig/psf_estimation/data_preparation.py: 100%
55 statements
« prev ^ index » next coverage.py v7.10.2, created at 2025-08-07 15:17 +0000
« prev ^ index » next coverage.py v7.10.2, created at 2025-08-07 15:17 +0000
1# Copyright (C) 2025 ETH Zurich
2# Institute for Particle Physics and Astrophysics
3# Author: Silvan Fischbacher
4# created: Mon Jul 21 2025
7import h5py
8import numpy as np
9from astropy.io import fits
10from cosmic_toolbox import arraytools as at
11from cosmic_toolbox import logger
13from ufig.array_util import set_flag_bit
15from ..se_moment_util import get_se_cols
16from . import psf_utils, star_sample_selection_cnn
18LOGGER = logger.get_logger(__file__)
class PSFDataPreparator:
    """
    Handles data preparation for PSF estimation.

    This class is responsible for:
    - Loading and preprocessing image data
    - GAIA catalog matching for stellar sample selection
    - Initial quality cuts and flag management
    - Preparation of data structures for CNN processing
    """

    def __init__(self, config):
        """
        Initialize the data preparator with configuration parameters.

        :param config: Configuration dictionary with parameters for data preparation.
            ``max_dist_gaia_arcsec`` is read with ``config[...]`` and is therefore
            required; all other options are read with ``config.get`` and have
            defaults (see the individual methods).
        """
        self.config = config

    def prepare_data(
        self,
        filepath_image,
        filepath_sexcat,
        filepath_sysmaps,
        filepath_gaia,
    ):
        """
        Prepare the data for the PSF estimation. This includes
        - loading the image and the SExtractor catalog,
        - matching with GAIA catalog,
        - applying initial quality cuts.

        The configuration is taken from ``self.config``.

        :param filepath_image: Path to the input image file to estimate the PSF from.
        :param filepath_sexcat: Path to the SExtractor catalog file of the image.
        :param filepath_sysmaps: Path to the systematics maps file.
        :param filepath_gaia: Path to the GAIA catalog file.
        :return: Dictionary with the keys
            - ``cat_original``: the SExtractor catalog as loaded from
              ``filepath_sexcat``,
            - ``cat_gaia``: the GAIA-matched sources after ``get_se_cols`` was
              applied,
            - ``flags_all``: flag array created with one entry per source of the
              loaded catalog,
            - ``flags_gaia``: flag array returned by the CNN sample selection
              (entries equal to 0 are counted as stars for the CNN),
            - ``cube_gaia``: star stamps returned by the CNN sample selection,
            - ``position_weights``: position weights of the GAIA-matched sources,
            - ``image_shape``: shape of the loaded image array.
        """
        LOGGER.info("Starting data preparation")

        # Load image data and catalog
        img = self._load_image(filepath_image)
        cat = at.load_hdf_cols(filepath_sexcat)

        # Match with GAIA and apply initial cuts
        cat_gaia, flags_all = self._select_gaia_stars(
            cat, filepath_gaia, filepath_image
        )

        # Get exposure information and position weights
        position_weights, pointings_maps = self._get_exposure_info(
            cat_gaia, filepath_sysmaps
        )

        # Apply moment measurements and quality cuts for CNN
        cat_gaia, flags_gaia, cube_gaia = self._prepare_cnn_sample(
            cat_gaia, img, position_weights, pointings_maps
        )

        LOGGER.info(
            f"Data preparation complete. Stars for CNN: {np.sum(flags_gaia == 0)}"
        )

        return {
            "cat_original": cat,  # Full SExtractor catalog for PSF predictions
            "cat_gaia": cat_gaia,
            "flags_all": flags_all,
            "flags_gaia": flags_gaia,
            "cube_gaia": cube_gaia,
            "position_weights": position_weights,
            "image_shape": img.shape,
        }

    def _load_image(self, filepath_image):
        """
        Load astronomical image from FITS file.

        :param filepath_image: Path to the FITS image file.
        :return: Image data as a numpy array of dtype float.
        """
        LOGGER.debug(f"Loading image: {filepath_image}")
        img = np.array(fits.getdata(filepath_image), dtype=float)
        LOGGER.info(f"Loaded image with shape: {img.shape}")
        return img

    def _select_gaia_stars(self, cat, filepath_gaia, filepath_image):
        """
        Select the objects from the SExtractor catalog that are in the GAIA catalog.

        Uses ``self.config["max_dist_gaia_arcsec"]`` as matching distance and, if
        ``self.config["astrometry_errors"]`` is set (default False), adds the
        astrometric difference columns.

        :param cat: SExtractor catalog data.
        :param filepath_gaia: Path to the GAIA catalog file.
        :param filepath_image: Path to the image file for coordinate conversion.
        :return: Tuple ``(cat_gaia, flags_all)``: the rows of the matched catalog
            with a GAIA match, and the flag array (one entry per input source)
            in which ``FLAGBIT_GAIA`` is set for sources without a GAIA match.
        """
        LOGGER.info("Matching with GAIA catalog")

        # Initialize flags for all sources
        flags_all = np.zeros(len(cat), dtype=np.int32)

        # Cross-match with GAIA
        # NOTE(review): the flags are sized from the input catalog while the
        # selection below comes from the returned catalog; this assumes
        # get_gaia_match preserves the number and order of rows - confirm.
        cat = star_sample_selection_cnn.get_gaia_match(
            cat, filepath_gaia, self.config["max_dist_gaia_arcsec"]
        )

        # Convert GAIA coordinates to image pixels
        cat = star_sample_selection_cnn.get_gaia_image_coords(filepath_image, cat)

        # Add astrometric differences if requested
        if self.config.get("astrometry_errors", False):
            cat = self._add_astrometry_diff(cat)

        # Select GAIA matches
        select_gaia = cat["match_gaia"].astype(bool)

        # Flag non-GAIA sources
        set_flag_bit(
            flags=flags_all,
            select=~select_gaia,
            field=star_sample_selection_cnn.FLAGBIT_GAIA,
        )

        LOGGER.info(
            f"GAIA matching: {np.count_nonzero(select_gaia)}/{len(flags_all)} "
            f"sources matched"
        )

        # Keep only GAIA-matched sources
        cat_gaia = cat[select_gaia]

        return cat_gaia, flags_all

    def _get_exposure_info(self, cat_gaia, filepath_sys):
        """
        Get exposure information and calculate position weights.

        :param cat_gaia: Catalog with GAIA matches; the columns ``XWIN_IMAGE``
            and ``YWIN_IMAGE`` are read.
        :param filepath_sys: Path to the HDF5 systematics maps file; the dataset
            ``map_pointings`` is read.
        :return: Tuple ``(position_weights, pointings_maps)`` where
            ``pointings_maps`` is the content of the ``map_pointings`` dataset
            read into memory.
        """
        LOGGER.debug("Getting exposure information")

        # Everything that touches the dataset handle has to happen while the
        # file is open; only the in-memory copy leaves the context.
        with h5py.File(filepath_sys, mode="r") as fh5_maps:
            pointings_maps = fh5_maps["map_pointings"]

            # Calculate position weights based on exposure coverage
            position_weights = psf_utils.get_position_weights(
                x=cat_gaia["XWIN_IMAGE"] - 0.5,  # Convert to 0-indexed
                y=cat_gaia["YWIN_IMAGE"] - 0.5,
                pointings_maps=pointings_maps,
            )

            # Copy pointing maps data for later use
            pointings_maps_data = pointings_maps[...]

        return position_weights, pointings_maps_data

    def _prepare_cnn_sample(self, cat_gaia, img, position_weights, pointings_maps):
        """
        Prepare the sample for CNN processing. Enrich catalog with moment measurements
        and apply quality cuts.

        The selection is configured through ``self.config`` with the defaults
        ``star_stamp_shape=(19, 19)``, ``star_mag_range=(18, 22)``,
        ``min_n_exposures=1``, ``sextractor_flags=[0, 16]``,
        ``flag_coadd_boundaries=True`` and ``moments_lim=(-99, 99)``.

        :param cat_gaia: Catalog with GAIA matches.
        :param img: Image data for cutout extraction.
        :param position_weights: Position weights for each star.
        :param pointings_maps: Pointing maps for exposure coverage.
        :return: Tuple ``(cat_gaia, flags_gaia, cube_gaia)``: the catalog returned
            by ``get_se_cols`` and the flags and star stamps returned by
            ``get_stars_for_cnn``.
        """
        LOGGER.debug("Preparing CNN sample")

        # Add moment measurements
        cat_gaia = get_se_cols(cat_gaia)

        # Apply quality cuts and extract star stamps
        flags_gaia, cube_gaia = star_sample_selection_cnn.get_stars_for_cnn(
            cat=cat_gaia,
            image=img,
            star_stamp_shape=self.config.get("star_stamp_shape", (19, 19)),
            pointings_maps=pointings_maps,
            position_weights=position_weights,
            star_mag_range=self.config.get("star_mag_range", (18, 22)),
            min_n_exposures=self.config.get("min_n_exposures", 1),
            sextractor_flags=self.config.get("sextractor_flags", [0, 16]),
            flag_coadd_boundaries=self.config.get("flag_coadd_boundaries", True),
            moments_lim=self.config.get("moments_lim", (-99, 99)),
        )

        return cat_gaia, flags_gaia, cube_gaia

    def _add_astrometry_diff(self, cat):
        """
        Add astrometric difference measurements.

        :param cat: Catalog with the columns ``X_IMAGE``, ``Y_IMAGE``,
            ``gaia_x_match`` and ``gaia_y_match``.
        :return: Catalog with the columns ``astrometry_diff_x`` and
            ``astrometry_diff_y`` set to the SExtractor position minus the
            matched GAIA position.
        """
        cat = at.ensure_cols(cat, names=["astrometry_diff_x", "astrometry_diff_y"])
        cat["astrometry_diff_x"] = cat["X_IMAGE"] - cat["gaia_x_match"]
        cat["astrometry_diff_y"] = cat["Y_IMAGE"] - cat["gaia_y_match"]
        return cat