Coverage for src/ufig/psf_estimation/psf_utils.py: 95%

89 statements  

« prev     ^ index     » next       coverage.py v7.10.2, created at 2025-08-07 15:17 +0000

1import os 

2 

3import h5py 

4import numpy as np 

5from cosmic_toolbox import arraytools as at 

6from cosmic_toolbox import logger 

7 

8from ufig import mask_utils 

9from ufig.psf_estimation import correct_brighter_fatter 

10 

# Module-level logger, obtained from cosmic_toolbox's get_logger with this
# module's file path.
LOGGER = logger.get_logger(__file__)

# NOTE(review): error/sentinel value; it is not referenced anywhere in the
# code visible here — presumably used by callers to flag invalid PSF values
# (confirm before relying on it).
ERR_VAL = 999

14 

15 

def transform_forward(vec, scale):
    """
    Standardise values feature-wise with per-feature (offset, divisor) pairs.

    Computes ``(vec - scale[:, 0]) / scale[:, 1]``; this is undone by
    ``transform_inverse``.

    :param vec: values to transform; its last axis must broadcast against
        ``scale.shape[0]``.
    :param scale: 2D array whose column 0 holds the offsets and column 1 the
        divisors, one row per feature.
    :return: the transformed values.
    """
    offsets = scale[:, 0]
    divisors = scale[:, 1]
    centred = vec - offsets
    return centred / divisors

19 

20 

def transform_inverse(vec_transformed, scale):
    """
    Undo ``transform_forward``: map standardised values back to raw values.

    Computes ``vec_transformed * scale[:, 1] + scale[:, 0]``.

    :param vec_transformed: standardised values; last axis must broadcast
        against ``scale.shape[0]``.
    :param scale: 2D array whose column 0 holds the offsets and column 1 the
        divisors, one row per feature.
    :return: the values in the original (untransformed) space.
    """
    divisors = scale[:, 1]
    offsets = scale[:, 0]
    rescaled = vec_transformed * divisors
    return rescaled + offsets

24 

25 

def position_weights_to_nexp(position_weights):
    """
    Count, for each row, how many exposures carry a strictly positive weight.

    :param position_weights: 2D array of per-exposure weights, one row per
        object (e.g. the output of ``get_position_weights``).
    :return: uint16 array with the number of positive entries in each row.
    """
    is_covered = position_weights > 0
    return np.count_nonzero(is_covered, axis=1).astype(np.uint16)

32 

33 

def postprocess_catalog(cat):
    """
    Clamp the CNN PSF flux ratio of a catalog to the interval [0, 1].

    Catalogs without a ``psf_flux_ratio_cnn`` column are left untouched.

    :param cat: structured catalog (exposing ``dtype.names``) with PSF
        parameters; updated in place.
    :return: None
    """
    column = "psf_flux_ratio_cnn"
    if column not in cat.dtype.names:
        return
    cat[column] = np.clip(cat[column], 0.0, 1.0)

46 

47 

def get_position_weights(x, y, pointings_maps):
    """
    Get the normalised per-exposure position weights for each star.

    Each star position is truncated to an integer pixel index (clamped to the
    map bounds). The pointing bitmaps at that pixel are decoded into one 0/1
    column per possible exposure, and each row is then normalised to sum to
    one. Stars not covered by any exposure get all-zero weights.

    :param x: x-coordinates of stars (pixel units)
    :param y: y-coordinates of stars (pixel units)
    :param pointings_maps: bitmap container with a 2D ``shape``
        ``(size_y, size_x)``, an ``attrs["n_pointings"]`` entry and fields
        ``bit1`` ... ``bit5``. Each field is decoded with 64 bits per pixel,
        presumably one bit per exposure (confirm against the writer).
    :return: float64 array with one row of position weights per star. Assuming
        ``decimal_integer_to_binary`` yields 64 columns per map, the shape is
        ``(n_stars, 64 * n_maps_read)``.
    """
    size_y, size_x = pointings_maps.shape

    # Truncate to integer pixel indices, then clamp so that stars on or beyond
    # the image edge read the nearest valid pixel.
    x_pix = np.clip(np.asarray(x).astype(np.int32), 0, size_x - 1)
    y_pix = np.clip(np.asarray(y).astype(np.int32), 0, size_y - 1)

    n_pointings = pointings_maps.attrs["n_pointings"]

    n_bit = 64
    max_maps = 5  # fields "bit1" ... "bit5" are supported

    # "bit1" is always read. Each further field is read only while pointings
    # remain beyond those already covered by the previous fields.
    decoded = []
    for n in range(1, max_maps + 1):
        if n > 1 and n_pointings <= n_bit * (n - 1):
            break
        decoded.append(
            mask_utils.decimal_integer_to_binary(
                n_bit,
                pointings_maps[f"bit{n}"][y_pix, x_pix],
                dtype_out=np.float64,
            )
        )

    if n_pointings > n_bit * max_maps:
        # Previously these exposures were dropped silently.
        LOGGER.warning(
            f"n_pointings={n_pointings} exceeds the {n_bit * max_maps} "
            f"exposures encoded in bit1..bit{max_maps}; the remaining "
            f"exposures are ignored in the position weights"
        )

    # Concatenate once instead of growing the array inside the loop.
    position_weights = np.concatenate(decoded, axis=1, dtype=np.float64)

    norm = position_weights.sum(axis=1, keepdims=True)
    # Divide only where coverage exists. This avoids 0/0 -> NaN (and its
    # RuntimeWarning) for stars outside every exposure; those rows stay 0.
    return np.divide(
        position_weights,
        norm,
        out=np.zeros_like(position_weights),
        where=norm != 0,
    )

98 

99 

def get_star_cube_filename(filepath_cat_out):
    """
    Derive the star-cube HDF5 path from a catalog output path.

    The catalog's file extension (if any) is replaced by ``_starcube.h5``,
    e.g. ``/out/cat.fits`` -> ``/out/cat_starcube.h5``.

    :param filepath_cat_out: path of the catalog output file (str or
        path-like).
    :return: path of the matching star-cube file, as a string.
    """
    root = os.path.splitext(filepath_cat_out)[0]
    return root + "_starcube.h5"

104 

105 

def write_star_cube(star_cube, cat, filepath_cat_out):
    """
    Store the star cube and its catalog in an LZF-compressed HDF5 file.

    The file is placed next to the catalog output (see
    ``get_star_cube_filename``) and holds the datasets ``star_cube`` and
    ``cat``, both converted with ``at.set_storing_dtypes`` before writing.
    An existing file at that path is overwritten.

    :param star_cube: array of star stamps; its ``nbytes`` is logged.
    :param cat: catalog of the stars in the cube.
    :param filepath_cat_out: path of the catalog output file.
    """
    path_cube = get_star_cube_filename(filepath_cat_out)
    datasets = {"star_cube": star_cube, "cat": cat}

    with h5py.File(path_cube, "w") as file_h5:
        for dataset_name, values in datasets.items():
            file_h5.create_dataset(
                name=dataset_name,
                data=at.set_storing_dtypes(values),
                compression="lzf",
            )

    size_mb = star_cube.nbytes / 1024**2
    LOGGER.info(
        f"created star cube file {path_cube} with {len(cat)} stars, "
        f"size {size_mb:.2f} MB"
    )

119 

120 

def apply_brighter_fatter_correction(cat_cnn, psfmodel_corr_brighter_fatter):
    """
    Remove the brighter-fatter effect from the CNN PSF size and ellipticity.

    The ``psf_fwhm_cnn``, ``psf_e1_cnn`` and ``psf_e2_cnn`` columns of
    ``cat_cnn`` are overwritten in place with the output of
    ``correct_brighter_fatter.brighter_fatter_remove``, which also receives
    ``MAG_AUTO``. The resulting shift of each column mean is logged.

    :param cat_cnn: catalog with ``MAG_AUTO`` and the CNN PSF columns;
        modified in place.
    :param psfmodel_corr_brighter_fatter: correction parameters, passed
        through as ``dict_corr``.
    :return: None
    """
    columns = ("psf_fwhm_cnn", "psf_e1_cnn", "psf_e2_cnn")
    means_before = [np.mean(cat_cnn[name]) for name in columns]

    fwhm_corr, e1_corr, e2_corr = correct_brighter_fatter.brighter_fatter_remove(
        col_mag=cat_cnn["MAG_AUTO"],
        col_fwhm=cat_cnn["psf_fwhm_cnn"],
        col_e1=cat_cnn["psf_e1_cnn"],
        col_e2=cat_cnn["psf_e2_cnn"],
        dict_corr=psfmodel_corr_brighter_fatter,
    )
    cat_cnn["psf_fwhm_cnn"] = fwhm_corr
    cat_cnn["psf_e1_cnn"] = e1_corr
    cat_cnn["psf_e2_cnn"] = e2_corr

    delta_fwhm, delta_e1, delta_e2 = (
        np.mean(cat_cnn[name]) - before
        for name, before in zip(columns, means_before)
    )
    LOGGER.info(
        f"applied brighter-fatter correction, difference in mean "
        f"fwhm={delta_fwhm:2.5f} "
        f"e1={delta_e1:2.5f} "
        f"e2={delta_e2:2.5f}"
    )

151 

152 

def adjust_psf_measurements(cat, psf_measurement_adjustment=None):
    """
    Apply linear corrections ``value * slope + offset`` to PSF catalog columns.

    :param cat: structured catalog with PSF parameters; updated in place.
    :param psf_measurement_adjustment: mapping from column name to an
        indexable ``(offset, slope)`` pair (``[0]`` is added, ``[1]``
        multiplies). A key that is not a column of ``cat`` is applied to the
        column ``key + "_cnn"`` instead. ``None`` disables the adjustment.
    :return: None
    """
    if psf_measurement_adjustment is None:
        LOGGER.info("NO PSF parameter adjustment")
        return

    for key, coeffs in psf_measurement_adjustment.items():
        target = key if key in cat.dtype.names else key + "_cnn"
        current = cat[target]
        adjusted = current * coeffs[1] + coeffs[0]
        LOGGER.warning(
            f"PSF parameter adjustment {key}: mean frac diff after "
            f"adjustment: {np.mean((adjusted - current) / current):2.4f}"
        )
        cat[target] = adjusted

175 

176 

class PSFEstError(ValueError):
    """Signals that PSF estimation was aborted because too few stars were found."""

181 

182 

def write_empty_output(filepath_out, filepath_cat_out=None, save_star_cube=False):
    """
    Write empty HDF5 placeholders for the outputs of a failed PSF estimation.

    Downstream processes that expect these files then still find them, even
    though no PSF model was produced.

    Parameters
    ----------
    filepath_out : str
        Path of the main PSF model output file; always created.
    filepath_cat_out : str, optional
        Path of the catalog output file; created only if given.
    save_star_cube : bool, optional
        If True, also create the star-cube file derived from
        ``filepath_cat_out``. Ignored when ``filepath_cat_out`` is None,
        because the star-cube path is derived from the catalog path.
    """
    LOGGER.warning("Creating empty output files for failed PSF estimation")

    def _placeholder_paths():
        # Lazily produce the paths so that each one is computed right before
        # it is written, keeping a strictly sequential order.
        yield filepath_out
        if filepath_cat_out is None:
            return
        yield filepath_cat_out
        if save_star_cube:
            yield get_star_cube_filename(filepath_cat_out)

    for placeholder in _placeholder_paths():
        write_empty_file(placeholder)

    LOGGER.info("Created empty output files")

215 

216 

def write_empty_file(path):
    """
    Create an HDF5 file at ``path`` that contains no datasets.

    The file is opened in h5py mode ``"w"``, so an existing file at ``path``
    is truncated.

    :param path: destination file path.
    """
    h5py.File(path, mode="w").close()

220 

221 

def select_validation_stars(n_stars, fraction_validation, seed=42):
    """
    Select a reproducible random subset of star indices for PSF validation.

    The subset is drawn with a private ``np.random.RandomState`` seeded with
    ``seed``. The result is reproducible across calls and leaves numpy's
    global random state untouched.

    :param n_stars: Total number of stars in the catalog.
    :param fraction_validation: Fraction of stars to select for validation;
        ``int(n_stars * fraction_validation)`` stars are selected.
    :param seed: Seed of the private random generator (default 42).
    :return: Array of unique indices in ``[0, n_stars)`` for validation
        stars; an empty integer array if the requested count rounds to zero.
    :raises ValueError: from ``RandomState.choice`` if more validation stars
        are requested than there are stars.
    """
    n_validation = int(n_stars * fraction_validation)
    if n_validation == 0:
        return np.array([], dtype=int)

    # Use a local generator instead of np.random.seed(). Re-seeding the global
    # RNG would silently reset the random state of every other numpy user in
    # the process. RandomState(42) reproduces exactly the indices previously
    # obtained via np.random.seed(42) + np.random.choice.
    rng = np.random.RandomState(seed)
    return rng.choice(n_stars, size=n_validation, replace=False)