Coverage for src/ufig/psf_estimation/psf_predictions.py: 92%

77 statements  

coverage.py v7.10.2, created at 2025-08-07 15:17 +0000

1# Copyright (C) 2025 ETH Zurich 

2# Institute for Particle Physics and Astrophysics 

3# Author: Silvan Fischbacher 

4# created: Mon Jul 28 2025 

5 

6import h5py 

7import numpy as np 

8from cosmic_toolbox import arraytools as at 

9from cosmic_toolbox import logger 

10 

11from . import correct_brighter_fatter, psf_utils 

12from .tiled_regressor import TiledRobustPolynomialRegressor as Regressor 

13 

14LOGGER = logger.get_logger(__file__) 

15ERR_VAL = 999.0 

16 

17 

18def colnames_derivative(cols, ax): 

19 """ 

20 Create column names for derivatives with respect to an axis. 

21 

22 :param cols: List of column names to compute derivatives for 

23 :param ax: Axis for the derivative ('x' or 'y') 

24 :return: List of new column names for derivatives 

25 """ 

26 return [f"{c}_dd{ax}" for c in cols] 
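
For illustration, a quick sketch of what colnames_derivative returns; the import path is assumed from the file path in the header, and the column names are just examples (they match parameter names used elsewhere in this module):

from ufig.psf_estimation.psf_predictions import colnames_derivative

cols = ["psf_fwhm_ipt", "psf_e1_ipt"]
print(colnames_derivative(cols, "x"))
# -> ['psf_fwhm_ipt_ddx', 'psf_e1_ipt_ddx']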

27 

28 

29def get_model_derivatives( 

30 cat, filepath_psfmodel, cols, psfmodel_corr_brighter_fatter, delta=1e-2 

31): 

32 """ 

33 Calculate spatial derivatives of PSF model parameters. 

34 

35 :param cat: Input catalog with X_IMAGE, Y_IMAGE columns 

36 :param filepath_psfmodel: Path to the PSF model file 

37 :param cols: List of column names to calculate derivatives for 

38 :param psfmodel_corr_brighter_fatter: Parameters for brighter-fatter correction 

39 :param delta: Step size for finite difference calculation 

40 :return: Catalog with additional columns for derivatives 

41 """ 

42 dims = ["x", "y"] 

43 for d in dims: 

44 col_names_der = colnames_derivative(cols, d) 

45 f = f"{d}_IMAGE".upper() 

46 

47 # Calculate model at position - delta/2 

48 cat_dm = cat.copy() 

49 cat_dm[f] -= delta / 2.0 

50 cat_dm = predict_psf_for_catalogue_storing( 

51 cat_dm, filepath_psfmodel, psfmodel_corr_brighter_fatter 

52 )[0] 

53 

54 # Calculate model at position + delta/2 

55 cat_dp = cat.copy() 

56 cat_dp[f] += delta / 2.0 

57 cat_dp = predict_psf_for_catalogue_storing( 

58 cat_dp, filepath_psfmodel, psfmodel_corr_brighter_fatter 

59 )[0] 

60 

61 # Calculate centered finite difference 

62 cat = at.add_cols(cat, names=col_names_der) 

63 for cv, cd in zip(cols, col_names_der): 

64 cat[cd] = (cat_dp[cv] - cat_dm[cv]) / delta 

65 

66 return cat 
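
The derivative columns are centred finite differences: each parameter is re-predicted with X_IMAGE (or Y_IMAGE) shifted by -delta/2 and +delta/2, and the difference is divided by delta. A minimal, self-contained sketch of the same scheme, using a toy function in place of the PSF model:

import numpy as np

def central_difference(f, x, delta=1e-2):
    # Same centred scheme as get_model_derivatives: evaluate at x -/+ delta/2.
    return (f(x + delta / 2.0) - f(x - delta / 2.0)) / delta

x = np.linspace(0.0, 10.0, 5)
toy_model = np.sin                       # stand-in for a predicted PSF parameter
print(central_difference(toy_model, x))  # close to np.cos(x)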

67 

68 

69def predict_psf(position_xy, position_weights, regressor, settings, n_per_chunk=1000): 

70 """ 

71 Predict PSF parameters at given positions using a fitted regressor. 

72 

73 :param position_xy: Array of (x,y) positions, shape (n, 2) 

74 :param position_weights: Weights for each position, shape (n, m) 

75 :param regressor: Fitted regressor object 

76 :param settings: Dictionary with model settings, including scale factors 

77 :param n_per_chunk: Number of samples to process in each batch 

78 :return: Tuple of predicted parameters and a mask for positions with no coverage 

79 """ 

80 position_xy_transformed = psf_utils.transform_forward( 

81 position_xy, scale=settings["scale_pos"] 

82 ) 

83 

84 position_xy_transformed_weights = np.concatenate( 

85 [position_xy_transformed, position_weights], axis=1 

86 ) 

87 

88 position_par_transformed = regressor.predict( 

89 position_xy_transformed_weights, batch_size=n_per_chunk 

90 ) 

91 

92 position_par_post = psf_utils.transform_inverse( 

93 position_par_transformed, settings["scale_par"] 

94 ) 

95 

96 select_no_coverage = (position_weights.sum(axis=1) == 0) | np.any( 

97 ~np.isfinite(position_par_post), axis=1 

98 ) 

99 position_par_post[select_no_coverage] = 0 

100 

101 return position_par_post, select_no_coverage 
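
The no-coverage mask combines two conditions: positions whose pointing weights sum to zero (no exposure covers them) and positions where any predicted parameter is non-finite; those rows are zeroed before being returned. A small self-contained illustration of that masking step with made-up arrays:

import numpy as np

# Made-up weights (3 positions x 2 pointings) and predicted parameters.
position_weights = np.array([[0.5, 0.5],
                             [0.0, 0.0],      # no coverage at all
                             [1.0, 0.0]])
position_par_post = np.array([[1.2, 0.01],
                              [1.1, 0.02],
                              [np.nan, 0.03]])  # non-finite prediction

select_no_coverage = (position_weights.sum(axis=1) == 0) | np.any(
    ~np.isfinite(position_par_post), axis=1
)
position_par_post[select_no_coverage] = 0
print(select_no_coverage)  # [False  True  True]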

102 

103 

104def predict_psf_with_file(position_xy, filepath_psfmodel, id_pointing="all"): 

105 """ 

106 Predict PSF parameters at given positions using a saved PSF model file. 

107 

108 :param position_xy: Array of (x,y) positions, shape (n, 2) 

109 :param filepath_psfmodel: Path to the PSF model file 

110 :param id_pointing: ID of the pointing to use, or 'all' 

111 :return: Generator yielding predicted parameters, a no-coverage mask, and the number of exposures 

112 """ 

113 if position_xy.shape[1] != 2:  (113 ↛ 114: line 113 didn't jump to line 114 because the condition was never true)

114 raise ValueError( 

115 f"Invalid position_xy shape (should be n_obj x 2) {position_xy.shape}" 

116 ) 

117 

118 # Setup interpolator 

119 with h5py.File(filepath_psfmodel, "r") as fh5: 

120 par_names = at.set_loading_dtypes(fh5["par_names"][...]) 

121 pointings_maps = fh5["map_pointings"] 

122 position_weights = psf_utils.get_position_weights( 

123 position_xy[:, 0], position_xy[:, 1], pointings_maps 

124 ) 

125 poly_coeffs = fh5["arr_pointings_polycoeffs"][...] 

126 unseen_pointings = fh5["unseen_pointings"][...] 

127 settings = { 

128 key: at.set_loading_dtypes(fh5["settings"][key][...]) 

129 for key in fh5["settings"] 

130 } 

131 settings.setdefault("polynomial_type", "chebyshev") 

132 LOGGER.debug(f"polynomial_type={settings['polynomial_type']}") 

133 

134 # Add debugging information 

135 if "scale_par" in settings: 135 ↛ 139line 135 didn't jump to line 139

136 LOGGER.debug(f"Loaded scale_par shape: {settings['scale_par'].shape}") 

137 LOGGER.debug(f"First few scale_par values: {settings['scale_par'][:3]}") 

138 

139 regressor = Regressor( 

140 poly_order=settings["poly_order"], 

141 ridge_alpha=settings["ridge_alpha"], 

142 polynomial_type=settings["polynomial_type"], 

143 poly_coefficients=poly_coeffs, 

144 unseen_pointings=unseen_pointings, 

145 ) 

146 

147 if id_pointing == "all":  (147 ↛ 164: line 147 didn't jump to line 164 because the condition was always true)

148 LOGGER.info( 

149 f"prediction for cnn models n_pos={position_xy.shape[0]} id_pointing=all" 

150 ) 

151 

152 position_par_post, select_no_coverage = predict_psf( 

153 position_xy, position_weights, regressor, settings 

154 ) 

155 

156 position_par_post = np.rec.fromarrays( 

157 position_par_post.T, names=",".join(par_names) 

158 ) 

159 psf_utils.postprocess_catalog(position_par_post) 

160 n_exposures = psf_utils.position_weights_to_nexp(position_weights) 

161 yield position_par_post, select_no_coverage, n_exposures 

162 

163 else: 

164 raise NotImplementedError( 

165 "This feature is not yet implemented due to the polynomial coefficients" 

166 " covariances which are a tiny bit tricky but definitely doable" 

167 ) 
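
A hedged usage sketch for the generator interface; the model path and positions below are placeholders, and the import path is assumed from the file path in the header. A real PSF model HDF5 file (containing par_names, map_pointings, arr_pointings_polycoeffs, unseen_pointings and settings, as read above) is required for this to run:

import numpy as np
from ufig.psf_estimation import psf_predictions

filepath_psfmodel = "path/to/psf_model.h5"                  # placeholder
position_xy = np.array([[100.0, 200.0], [1500.5, 320.25]])  # (n_obj, 2) pixel positions

for par, no_coverage, n_exposures in psf_predictions.predict_psf_with_file(
    position_xy, filepath_psfmodel, id_pointing="all"
):
    print(par.dtype.names)                # PSF parameter columns stored in the model
    print(no_coverage.sum(), "objects without coverage")
    print(n_exposures)                    # exposures contributing to each object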

168 

169 

170def predict_psf_for_catalogue( 

171 cat, filepath_psfmodel, id_pointing="all", psfmodel_corr_brighter_fatter=None 

172): 

173 """ 

174 Predict PSF parameters for a given catalog. 

175 

176 :param cat: Input catalog with X_IMAGE, Y_IMAGE (and MAG_AUTO if applicable) columns 

177 :param filepath_psfmodel: Path to the PSF model file 

178 :param id_pointing: ID of the pointing to use, or 'all' 

179 :param psfmodel_corr_brighter_fatter: Parameters for brighter-fatter correction 

180 :return: Generator yielding predicted parameters and number of exposures 

181 """ 

182 position_xy = np.stack((cat["X_IMAGE"] - 0.5, cat["Y_IMAGE"] - 0.5), axis=-1) 

183 

184 for position_par_post, select_no_coverage, n_exposures in predict_psf_with_file(  (184 ↛ exit: didn't return from function 'predict_psf_for_catalogue' because the loop on line 184 didn't complete)

185 position_xy, filepath_psfmodel, id_pointing=id_pointing 

186 ): 

187 # Set points without coverage to error value 

188 for par_name in position_par_post.dtype.names: 

189 position_par_post[par_name][select_no_coverage] = ERR_VAL 

190 

191 # Apply brighter-fatter correction 

192 if ( 

193 psfmodel_corr_brighter_fatter is not None 

194 and "apply_to_galaxies" in psfmodel_corr_brighter_fatter 

195 and psfmodel_corr_brighter_fatter["apply_to_galaxies"] 

196 ): 

197 LOGGER.info("added brighter-fatter to PSF prediction") 

198 ( 

199 position_par_post["psf_fwhm_ipt"], 

200 position_par_post["psf_e1_ipt"], 

201 position_par_post["psf_e2_ipt"], 

202 ) = correct_brighter_fatter.brighter_fatter_add( 

203 col_mag=cat["MAG_AUTO"], 

204 col_fwhm=position_par_post["psf_fwhm_ipt"], 

205 col_e1=position_par_post["psf_e1_ipt"], 

206 col_e2=position_par_post["psf_e2_ipt"], 

207 dict_corr=psfmodel_corr_brighter_fatter, 

208 ) 

209 

210 yield position_par_post, n_exposures 
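
A sketch of the catalogue-level generator; the structured array stands in for a SExtractor-like catalogue and the model path is a placeholder. With psfmodel_corr_brighter_fatter=None the brighter-fatter branch is skipped, so MAG_AUTO is not needed here:

import numpy as np
from ufig.psf_estimation import psf_predictions

# Minimal stand-in catalogue with the required position columns.
cat = np.zeros(3, dtype=[("X_IMAGE", "f8"), ("Y_IMAGE", "f8")])
cat["X_IMAGE"] = [10.5, 200.0, 4096.0]
cat["Y_IMAGE"] = [12.0, 300.0, 4096.0]

for par, n_exposures in psf_predictions.predict_psf_for_catalogue(
    cat, "path/to/psf_model.h5", psfmodel_corr_brighter_fatter=None
):
    # Objects outside the model coverage carry ERR_VAL (999.0) in every column.
    print(par.dtype.names, n_exposures)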

211 

212 

213def predict_psf_for_catalogue_storing( 

214 cat_in, filepath_psf_model, psfmodel_corr_brighter_fatter 

215): 

216 """ 

217 Predict PSF parameters for a catalog and format for storage. 

218 

219 :param cat_in: Input catalog with X_IMAGE, Y_IMAGE (and MAG_AUTO if applicable) columns 

220 :param filepath_psf_model: Path to the PSF model file 

221 :param psfmodel_corr_brighter_fatter: Parameters for brighter-fatter correction 

222 :return: Tuple of output catalog with predicted parameters and number of exposures 

223 """ 

224 # Predict using the generator function 

225 position_par_post, n_exposures = next( 

226 predict_psf_for_catalogue( 

227 cat=cat_in, 

228 filepath_psfmodel=filepath_psf_model, 

229 psfmodel_corr_brighter_fatter=psfmodel_corr_brighter_fatter, 

230 ) 

231 ) 

232 

233 # Create output catalog with proper structure 

234 if "id" in cat_in.dtype.names: 234 ↛ 236line 234 didn't jump to line 236 because the condition on line 234 was never true

235 # if it is a simulated catalog, we also want to add the id of the galaxy 

236 dtypes = at.get_dtype( 

237 ("X_IMAGE", "Y_IMAGE", "id") + position_par_post.dtype.names 

238 ) 

239 else: 

240 dtypes = at.get_dtype(("X_IMAGE", "Y_IMAGE") + position_par_post.dtype.names) 

241 

242 cat_out = np.empty(len(cat_in), dtype=dtypes) 

243 

244 # Copy predicted parameters 

245 for par_name in position_par_post.dtype.names: 

246 cat_out[par_name] = position_par_post[par_name] 

247 

248 # Copy position information 

249 cat_out["X_IMAGE"] = cat_in["X_IMAGE"] 

250 cat_out["Y_IMAGE"] = cat_in["Y_IMAGE"] 

251 

252 return cat_out, n_exposures