Coverage for src/ufig/psf_estimation/cnn_util.py: 96%

124 statements  

coverage.py v7.10.2, created at 2025-08-07 15:17 +0000

1# Copyright (C) 2018 ETH Zurich, Institute for Particle Physics and Astrophysics 

2 

3""" 

4Created on May 25, 2018 

5author: Joerg Herbel 

6Original Source: ethz_des_mccl.psf_estimation.cnn_psf.cnn_util 

7""" 

8 

9import h5py 

10import numpy as np 

11import tensorflow.compat.v1 as tf 

12import yaml 

13from cosmic_toolbox import logger 

14 

15tf.disable_v2_behavior() 

16 

17LOGGER = logger.get_logger(__file__) 

18 

19 

20def res_layer(tensor_in, kernel_shape, activation): 

21 """ 

22 Standard ResNet layer (from Tiny ImageNet) 

23 """ 

24 with tf.variable_scope("conv_1"): 

25 weights = tf.get_variable( 

26 "weights", 

27 kernel_shape, 

28 initializer=tf.truncated_normal_initializer(0, 0.01), 

29 ) 

30 biases = tf.get_variable( 

31 "biases", kernel_shape[-1], initializer=tf.zeros_initializer() 

32 ) 

33 conv = ( 

34 tf.nn.conv2d(tensor_in, weights, strides=[1, 1, 1, 1], padding="SAME") 

35 + biases 

36 ) 

37 

38 # Batch normalization with proper variable creation to match checkpoint 

39 with tf.variable_scope("batch_normalization"): 

40 # Create variables that match the checkpoint format 

41 mean = tf.get_variable( 

42 "moving_mean", 

43 [kernel_shape[-1]], 

44 initializer=tf.zeros_initializer(), 

45 trainable=False, 

46 ) 

47 variance = tf.get_variable( 

48 "moving_variance", 

49 [kernel_shape[-1]], 

50 initializer=tf.ones_initializer(), 

51 trainable=False, 

52 ) 

53 beta = tf.get_variable( 

54 "beta", [kernel_shape[-1]], initializer=tf.zeros_initializer() 

55 ) 

56 gamma = tf.get_variable( 

57 "gamma", [kernel_shape[-1]], initializer=tf.ones_initializer() 

58 ) 

59 

60 conv = tf.nn.batch_normalization( 

61 conv, 

62 mean=mean, 

63 variance=variance, 

64 offset=beta, 

65 scale=gamma, 

66 variance_epsilon=1e-3, 

67 ) 

68 conv = activation(conv) 

69 

70 with tf.variable_scope("conv_2"): 

71 weights = tf.get_variable( 

72 "weights", 

73 kernel_shape, 

74 initializer=tf.truncated_normal_initializer(0, 0.01), 

75 ) 

76 biases = tf.get_variable( 

77 "biases", 

78 kernel_shape[-1], 

79 initializer=tf.truncated_normal_initializer(0, 0.01), 

80 ) 

81 conv = ( 

82 tf.nn.conv2d(conv, weights, strides=[1, 1, 1, 1], padding="SAME") + biases 

83 ) 

84 

85 # Batch normalization with proper variable creation to match checkpoint 

86 with tf.variable_scope("batch_normalization"): 

87 # Create variables that match the checkpoint format 

88 mean = tf.get_variable( 

89 "moving_mean", 

90 [kernel_shape[-1]], 

91 initializer=tf.zeros_initializer(), 

92 trainable=False, 

93 ) 

94 variance = tf.get_variable( 

95 "moving_variance", 

96 [kernel_shape[-1]], 

97 initializer=tf.ones_initializer(), 

98 trainable=False, 

99 ) 

100 beta = tf.get_variable( 

101 "beta", [kernel_shape[-1]], initializer=tf.zeros_initializer() 

102 ) 

103 gamma = tf.get_variable( 

104 "gamma", [kernel_shape[-1]], initializer=tf.ones_initializer() 

105 ) 

106 

107 conv = tf.nn.batch_normalization( 

108 conv, 

109 mean=mean, 

110 variance=variance, 

111 offset=beta, 

112 scale=gamma, 

113 variance_epsilon=1e-3, 

114 ) 

115 conv = activation(conv + tensor_in) 

116 

117 return conv 

118 

119 

120def create_cnn( 

121 input, 

122 filter_sizes, 

123 n_filters_start, 

124 n_resnet_layers, 

125 resnet_layers_kernel_size, 

126 n_fc, 

127 dropout_rate, 

128 n_out, 

129 activation_function="relu", 

130 apply_dropout=True, 

131 downsampling_method="max_pool", 

132 padding="same", 

133): 
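
# Builds the CNN graph: one stride-1 convolution per entry in filter_sizes
# (with downsampling_method="off", the only supported mode), followed by
# n_resnet_layers residual blocks, the fully connected layers in n_fc
# (with optional dropout), and a final dense layer with n_out units.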

134 activation = getattr(tf.nn, activation_function) 

135 x_tensor = tf.reshape(input, [-1] + input.get_shape().as_list()[1:] + [1]) 

136 

137 # Convolutional layers 

138 current_n_channels = 1 

139 current_height, current_width = x_tensor.get_shape().as_list()[1:3] 

140 x_conv = x_tensor 

141 

142 for layer_ind in range(len(filter_sizes)): 

143 current_n_channels = int(n_filters_start * 2**layer_ind) 

144 

145 if downsampling_method == "off":   [145 ↛ 174: line 145 didn't jump to line 174 because the condition on line 145 was always true]

146 # Use variable names compatible with tf.layers.conv2d 

147 scope_name = "conv2d" if layer_ind == 0 else f"conv2d_{layer_ind}" 

148 

149 with tf.variable_scope(scope_name): 

150 kernel = tf.get_variable( 

151 "kernel", 

152 [ 

153 filter_sizes[layer_ind], 

154 filter_sizes[layer_ind], 

155 x_conv.get_shape()[-1], 

156 current_n_channels, 

157 ], 

158 initializer=tf.truncated_normal_initializer(0, 0.01), 

159 ) 

160 bias = tf.get_variable( 

161 "bias", [current_n_channels], initializer=tf.zeros_initializer() 

162 ) 

163 x_conv = ( 

164 tf.nn.conv2d( 

165 input=x_conv, 

166 filters=kernel, 

167 strides=[1, 1, 1, 1], 

168 padding=padding.upper(), 

169 ) 

170 + bias 

171 ) 

172 x_conv = activation(x_conv) 

173 else: 

174 raise ValueError( 

175 f"Unsupported downsampling method: {downsampling_method}. " 

176 "Currently only 'off' is supported within UFig. For other methods, " 

177 "refer to the original implementation in the ethz_des_mccl repo." 

178 ) 

179 

180 # Shrink the tracked height/width when 'valid' padding is used 

181 if padding == "valid":   [181 ↛ 142: line 181 didn't jump to line 142 because the condition on line 181 was always true]

182 current_height -= filter_sizes[layer_ind] - 1 

183 current_width -= filter_sizes[layer_ind] - 1 

184 

185 # ResNet layers 

186 resnet_kernel_shape = [ 

187 resnet_layers_kernel_size, 

188 resnet_layers_kernel_size, 

189 current_n_channels, 

190 current_n_channels, 

191 ] 

192 

193 for i_res in range(n_resnet_layers): 

194 with tf.variable_scope(f"resnet_layer_{i_res + 1}"): 

195 x_conv = res_layer(x_conv, resnet_kernel_shape, activation) 

196 

197 # Fully connected layers 

198 x_conv_flat = tf.reshape( 

199 x_conv, [-1, current_height * current_width * current_n_channels] 

200 ) 

201 

202 x_fc = x_conv_flat 

203 for fc_ind in range(len(n_fc)): 

204 # Use variable names compatible with tf.layers.dense 

205 scope_name = "dense" if fc_ind == 0 else f"dense_{fc_ind}" 

206 

207 with tf.variable_scope(scope_name): 

208 kernel = tf.get_variable( 

209 "kernel", 

210 [x_fc.get_shape()[-1], n_fc[fc_ind]], 

211 initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1), 

212 ) 

213 bias = tf.get_variable( 

214 "bias", [n_fc[fc_ind]], initializer=tf.constant_initializer(value=0.1) 

215 ) 

216 x_fc = activation(tf.matmul(x_fc, kernel) + bias) 

217 

218 # Dropout 

219 x_do = tf.nn.dropout(x_fc, rate=dropout_rate if apply_dropout else 0.0) 

220 

221 # Map the fully connected features to output variables 

222 # The output layer is just another dense layer in the checkpoint 

223 final_dense_ind = len(n_fc) 

224 scope_name = "dense" if final_dense_ind == 0 else f"dense_{final_dense_ind}" 

225 

226 with tf.variable_scope(scope_name): 

227 kernel_out = tf.get_variable( 

228 "kernel", 

229 [x_fc.get_shape()[-1], n_out], 

230 initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1), 

231 ) 

232 bias_out = tf.get_variable( 

233 "bias", [n_out], initializer=tf.constant_initializer(value=0.1) 

234 ) 

235 x_out = tf.matmul(x_do, kernel_out) + bias_out 

236 

237 return x_out 

238 

239 

240class CNNPredictor: 
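
# Wraps a trained CNN checkpoint: reads the stored configuration and
# normalization constants from the accompanying HDF5 file, rebuilds the graph,
# and maps the raw network output back to parameter values.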

241 def __init__(self, path_trained_cnn): 

242 LOGGER.debug(f"tensorflow {str(tf)} version {tf.__version__}") 

243 

244 self.path_trained_cnn = path_trained_cnn 

245 

246 # Reset graph 

247 tf.reset_default_graph() 

248 

249 # Load configuration 

250 with h5py.File( 

251 get_path_output_config(path_trained_cnn), mode="r" 

252 ) as fh5_config: 

253 self.config = yaml.safe_load(fh5_config.attrs["config"]) 

254 self.config_training_data = yaml.safe_load( 

255 fh5_config.attrs["config_training_data"] 

256 ) 

257 self.input_shape = tuple(fh5_config["input_shape"]) 

258 self.means = fh5_config["means"][...] 

259 self.scales = fh5_config["scales"][...] 

260 

261 self.param_names = self.config["param_names"] 

262 

263 # For backwards compatibility 

264 self.config.setdefault("n_resnet_layers", 0) 

265 self.config.setdefault("resnet_layers_kernel_size", 3) 

266 self.config.setdefault("activation_function", "relu") 

267 self.config.setdefault("padding", "same") 

268 

269 # Get loss function used for training 

270 if "loss_function_kwargs" in self.config:   [270 ↛ 273: line 270 didn't jump to line 273 because the condition on line 270 was always true]

271 self.config["loss_function_kwargs"]["is_training"] = False 

272 

273 n_pred = 2 * len(self.config["param_names"]) 

274 self._transform_predictions = self._apply_means_scales 

275 

276 # Setup network 

277 self.input_tensor = tf.placeholder(tf.float32, (None,) + self.input_shape) 

278 input_tensor_norm = normalize_stamps(self.input_tensor) 

279 self.pred = create_cnn( 

280 input=input_tensor_norm, 

281 filter_sizes=self.config["filter_sizes"], 

282 n_filters_start=self.config["n_filters_start"], 

283 n_resnet_layers=self.config["n_resnet_layers"], 

284 resnet_layers_kernel_size=self.config["resnet_layers_kernel_size"], 

285 n_fc=self.config["n_fully_connected"], 

286 dropout_rate=self.config["dropout_rate"], 

287 n_out=n_pred, 

288 activation_function=self.config["activation_function"], 

289 apply_dropout=False, 

290 downsampling_method=self.config["downsampling_method"], 

291 padding=self.config["padding"], 

292 ) 

293 

294 # Transform predicted parameters 

295 self.par_pred_transformed = self._transform_predictions( 

296 self.pred[:, : len(self.param_names)] 

297 ) 

298 

299 def _apply_means_scales(self, pred): 
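
# Map the normalized network output back to parameter units using the stored
# scales and means (multiply first, then shift).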

300 pred *= self.scales 

301 pred += self.means 

302 return pred 

303 

304 def __call__(self, cube, batchsize=None): 
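
# Restore the checkpoint in a fresh session and predict on `cube` (an array of
# postage stamps), either in a single pass or in batches of at most `batchsize`.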

305 with tf.Session() as sess: 

306 sess.run(tf.global_variables_initializer()) 

307 tf.train.Saver().restore(sess, self.path_trained_cnn) 

308 

309 if batchsize is None or len(cube) == 0: 

310 if len(cube) == 0:   [310 ↛ 313: line 310 didn't jump to line 313 because the condition on line 310 was always true]

311 LOGGER.warning("Predicting on empty cube, output will be empty!") 

312 

313 pred = self.pred.eval(feed_dict={self.input_tensor: cube}) 

314 par_transformed = self.par_pred_transformed.eval( 

315 feed_dict={self.pred: pred} 

316 ) 

317 

318 else: 
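
# Process the cube in slices of at most `batchsize` stamps; ind_batch holds
# the slice boundaries, including the final (possibly partial) batch.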

319 ind_batch = list(range(0, len(cube), batchsize)) + [len(cube)] 

320 par_transformed = [None] * (len(ind_batch) - 1) 

321 

322 for i in range(len(ind_batch) - 1): 

323 LOGGER.info(f"Predicting on batch {i + 1} / {len(ind_batch) - 1}") 

324 cube_current = cube[ind_batch[i] : ind_batch[i + 1]] 

325 

326 pred = self.pred.eval(feed_dict={self.input_tensor: cube_current}) 

327 par_transformed[i] = self.par_pred_transformed.eval( 

328 feed_dict={self.pred: pred} 

329 ) 

330 

331 par_transformed = np.concatenate(par_transformed) 

332 

333 return par_transformed 

334 

335 

336def get_path_output_config(path_cnn): 
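
# The HDF5 file holding the training configuration sits next to the checkpoint,
# with ".h5" appended to the checkpoint path.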

337 return path_cnn + ".h5" 

338 

339 

340def normalize_stamps(stamps): 
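
# Min-max normalize each stamp individually to the range [0, 1].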

341 stamps_min = tf.reduce_min(stamps, axis=(1, 2), keepdims=True) 

342 

343 stamps -= stamps_min 

344 

345 stamps_max = tf.reduce_max(stamps, axis=(1, 2), keepdims=True) 

346 

347 stamps /= stamps_max 

348 

349 return stamps
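
# Example usage (a minimal sketch; the checkpoint path and input array below
# are hypothetical placeholders, not part of this module):
#   predictor = CNNPredictor("/path/to/cnn_checkpoint")
#   psf_params = predictor(stamp_cube, batchsize=1024)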