Coverage for src/ufig/psf_estimation/cnn_util.py: 96%
124 statements
# Copyright (C) 2018 ETH Zurich, Institute for Particle Physics and Astrophysics

"""
Created on May 25, 2018
author: Joerg Herbel
Original Source: ethz_des_mccl.psf_estimation.cnn_psf.cnn_util
"""

import h5py
import numpy as np
import tensorflow.compat.v1 as tf
import yaml
from cosmic_toolbox import logger

tf.disable_v2_behavior()

LOGGER = logger.get_logger(__file__)

def res_layer(tensor_in, kernel_shape, activation):
    """
    Standard ResNet layer (from Tiny ImageNet)
    """
    with tf.variable_scope("conv_1"):
        weights = tf.get_variable(
            "weights",
            kernel_shape,
            initializer=tf.truncated_normal_initializer(0, 0.01),
        )
        biases = tf.get_variable(
            "biases", kernel_shape[-1], initializer=tf.zeros_initializer()
        )
        conv = (
            tf.nn.conv2d(tensor_in, weights, strides=[1, 1, 1, 1], padding="SAME")
            + biases
        )

        # Batch normalization with proper variable creation to match checkpoint
        with tf.variable_scope("batch_normalization"):
            # Create variables that match the checkpoint format
            mean = tf.get_variable(
                "moving_mean",
                [kernel_shape[-1]],
                initializer=tf.zeros_initializer(),
                trainable=False,
            )
            variance = tf.get_variable(
                "moving_variance",
                [kernel_shape[-1]],
                initializer=tf.ones_initializer(),
                trainable=False,
            )
            beta = tf.get_variable(
                "beta", [kernel_shape[-1]], initializer=tf.zeros_initializer()
            )
            gamma = tf.get_variable(
                "gamma", [kernel_shape[-1]], initializer=tf.ones_initializer()
            )

            conv = tf.nn.batch_normalization(
                conv,
                mean=mean,
                variance=variance,
                offset=beta,
                scale=gamma,
                variance_epsilon=1e-3,
            )
        conv = activation(conv)

    with tf.variable_scope("conv_2"):
        weights = tf.get_variable(
            "weights",
            kernel_shape,
            initializer=tf.truncated_normal_initializer(0, 0.01),
        )
        biases = tf.get_variable(
            "biases",
            kernel_shape[-1],
            initializer=tf.truncated_normal_initializer(0, 0.01),
        )
        conv = (
            tf.nn.conv2d(conv, weights, strides=[1, 1, 1, 1], padding="SAME") + biases
        )

        # Batch normalization with proper variable creation to match checkpoint
        with tf.variable_scope("batch_normalization"):
            # Create variables that match the checkpoint format
            mean = tf.get_variable(
                "moving_mean",
                [kernel_shape[-1]],
                initializer=tf.zeros_initializer(),
                trainable=False,
            )
            variance = tf.get_variable(
                "moving_variance",
                [kernel_shape[-1]],
                initializer=tf.ones_initializer(),
                trainable=False,
            )
            beta = tf.get_variable(
                "beta", [kernel_shape[-1]], initializer=tf.zeros_initializer()
            )
            gamma = tf.get_variable(
                "gamma", [kernel_shape[-1]], initializer=tf.ones_initializer()
            )

            conv = tf.nn.batch_normalization(
                conv,
                mean=mean,
                variance=variance,
                offset=beta,
                scale=gamma,
                variance_epsilon=1e-3,
            )
        # Residual connection: add the layer input before the final activation
        conv = activation(conv + tensor_in)

    return conv

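
def _example_res_layer():
    # Hypothetical sketch (not part of the original module): res_layer expects
    # kernel_shape = [size, size, n_channels_in, n_channels_out] with equal
    # input and output channel counts, since SAME padding and stride 1 preserve
    # the spatial shape and the residual addition `conv + tensor_in` requires
    # matching tensor shapes. The stamp size (16) and channel count (64) below
    # are made-up values for illustration.
    x = tf.placeholder(tf.float32, [None, 16, 16, 64])
    with tf.variable_scope("example_res_layer"):
        return res_layer(x, kernel_shape=[3, 3, 64, 64], activation=tf.nn.relu)
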
def create_cnn(
    input,
    filter_sizes,
    n_filters_start,
    n_resnet_layers,
    resnet_layers_kernel_size,
    n_fc,
    dropout_rate,
    n_out,
    activation_function="relu",
    apply_dropout=True,
    downsampling_method="max_pool",
    padding="same",
):
    activation = getattr(tf.nn, activation_function)
    x_tensor = tf.reshape(input, [-1] + input.get_shape().as_list()[1:] + [1])

    # Convolutional layers
    current_n_channels = 1
    current_height, current_width = x_tensor.get_shape().as_list()[1:3]
    x_conv = x_tensor

    for layer_ind in range(len(filter_sizes)):
        current_n_channels = int(n_filters_start * 2**layer_ind)

        if downsampling_method == "off":  # coverage: condition always true in tests
            # Use variable names compatible with tf.layers.conv2d
            scope_name = "conv2d" if layer_ind == 0 else f"conv2d_{layer_ind}"

            with tf.variable_scope(scope_name):
                kernel = tf.get_variable(
                    "kernel",
                    [
                        filter_sizes[layer_ind],
                        filter_sizes[layer_ind],
                        x_conv.get_shape()[-1],
                        current_n_channels,
                    ],
                    initializer=tf.truncated_normal_initializer(0, 0.01),
                )
                bias = tf.get_variable(
                    "bias", [current_n_channels], initializer=tf.zeros_initializer()
                )
                x_conv = (
                    tf.nn.conv2d(
                        input=x_conv,
                        filters=kernel,
                        strides=[1, 1, 1, 1],
                        padding=padding.upper(),
                    )
                    + bias
                )
                x_conv = activation(x_conv)
        else:
            raise ValueError(
                f"Unsupported downsampling method: {downsampling_method}. "
                "Currently only 'off' is supported within UFig. For other methods, "
                "refer to the original implementation in the ethz_des_mccl repo."
            )

        # subtract in case of valid padding
        if padding == "valid":  # coverage: condition always true in tests
            current_height -= filter_sizes[layer_ind] - 1
            current_width -= filter_sizes[layer_ind] - 1

    # ResNet layers
    resnet_kernel_shape = [
        resnet_layers_kernel_size,
        resnet_layers_kernel_size,
        current_n_channels,
        current_n_channels,
    ]

    for i_res in range(n_resnet_layers):
        with tf.variable_scope(f"resnet_layer_{i_res + 1}"):
            x_conv = res_layer(x_conv, resnet_kernel_shape, activation)

    # Fully connected layers
    x_conv_flat = tf.reshape(
        x_conv, [-1, current_height * current_width * current_n_channels]
    )

    x_fc = x_conv_flat
    for fc_ind in range(len(n_fc)):
        # Use variable names compatible with tf.layers.dense
        scope_name = "dense" if fc_ind == 0 else f"dense_{fc_ind}"

        with tf.variable_scope(scope_name):
            kernel = tf.get_variable(
                "kernel",
                [x_fc.get_shape()[-1], n_fc[fc_ind]],
                initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1),
            )
            bias = tf.get_variable(
                "bias", [n_fc[fc_ind]], initializer=tf.constant_initializer(value=0.1)
            )
            x_fc = activation(tf.matmul(x_fc, kernel) + bias)

    # Dropout
    x_do = tf.nn.dropout(x_fc, rate=dropout_rate if apply_dropout else 0.0)

    # Map the fully connected features to output variables
    # The output layer is just another dense layer in the checkpoint
    final_dense_ind = len(n_fc)
    scope_name = "dense" if final_dense_ind == 0 else f"dense_{final_dense_ind}"

    with tf.variable_scope(scope_name):
        kernel_out = tf.get_variable(
            "kernel",
            [x_fc.get_shape()[-1], n_out],
            initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1),
        )
        bias_out = tf.get_variable(
            "bias", [n_out], initializer=tf.constant_initializer(value=0.1)
        )
        x_out = tf.matmul(x_do, kernel_out) + bias_out

    return x_out

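
def _example_build_cnn():
    # Hypothetical sketch (not part of the original module): builds a small
    # graph with create_cnn. All hyperparameter values here are made up; in
    # UFig they come from the configuration stored alongside the trained
    # checkpoint (see CNNPredictor below). With padding="valid" and two 3x3
    # layers, a 28x28 stamp shrinks to 24x24 before flattening.
    stamps = tf.placeholder(tf.float32, [None, 28, 28])
    return create_cnn(
        input=stamps,
        filter_sizes=[3, 3],
        n_filters_start=16,
        n_resnet_layers=2,
        resnet_layers_kernel_size=3,
        n_fc=[128],
        dropout_rate=0.0,
        n_out=10,
        apply_dropout=False,
        downsampling_method="off",
        padding="valid",
    )
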
class CNNPredictor:
    def __init__(self, path_trained_cnn):
        LOGGER.debug(f"tensorflow {str(tf)} version {tf.__version__}")

        self.path_trained_cnn = path_trained_cnn

        # Reset graph
        tf.reset_default_graph()

        # Load configuration
        with h5py.File(
            get_path_output_config(path_trained_cnn), mode="r"
        ) as fh5_config:
            self.config = yaml.safe_load(fh5_config.attrs["config"])
            self.config_training_data = yaml.safe_load(
                fh5_config.attrs["config_training_data"]
            )
            self.input_shape = tuple(fh5_config["input_shape"])
            self.means = fh5_config["means"][...]
            self.scales = fh5_config["scales"][...]

        self.param_names = self.config["param_names"]

        # For backwards compatibility
        self.config.setdefault("n_resnet_layers", 0)
        self.config.setdefault("resnet_layers_kernel_size", 3)
        self.config.setdefault("activation_function", "relu")
        self.config.setdefault("padding", "same")

        # Get loss function used for training
        if "loss_function_kwargs" in self.config:  # coverage: always true in tests
            self.config["loss_function_kwargs"]["is_training"] = False

        n_pred = 2 * len(self.config["param_names"])
        self._transform_predictions = self._apply_means_scales

        # Setup network
        self.input_tensor = tf.placeholder(tf.float32, (None,) + self.input_shape)
        input_tensor_norm = normalize_stamps(self.input_tensor)
        self.pred = create_cnn(
            input=input_tensor_norm,
            filter_sizes=self.config["filter_sizes"],
            n_filters_start=self.config["n_filters_start"],
            n_resnet_layers=self.config["n_resnet_layers"],
            resnet_layers_kernel_size=self.config["resnet_layers_kernel_size"],
            n_fc=self.config["n_fully_connected"],
            dropout_rate=self.config["dropout_rate"],
            n_out=n_pred,
            activation_function=self.config["activation_function"],
            apply_dropout=False,
            downsampling_method=self.config["downsampling_method"],
            padding=self.config["padding"],
        )

        # Transform predicted parameters
        self.par_pred_transformed = self._transform_predictions(
            self.pred[:, : len(self.param_names)]
        )

    def _apply_means_scales(self, pred):
        pred *= self.scales
        pred += self.means
        return pred
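
    # Note (added explanation, not in the original source): the network
    # predicts standardized targets, so _apply_means_scales maps a raw
    # prediction p back to physical units as p * scales + means, i.e. the
    # inverse of the standardization (x - means) / scales.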
    def __call__(self, cube, batchsize=None):
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            tf.train.Saver().restore(sess, self.path_trained_cnn)

            if batchsize is None or len(cube) == 0:
                if len(cube) == 0:  # coverage: always true in tests
                    LOGGER.warning("Predicting on empty cube, output will be empty!")

                pred = self.pred.eval(feed_dict={self.input_tensor: cube})
                par_transformed = self.par_pred_transformed.eval(
                    feed_dict={self.pred: pred}
                )

            else:
                ind_batch = list(range(0, len(cube), batchsize)) + [len(cube)]
                par_transformed = [None] * (len(ind_batch) - 1)

                for i in range(len(ind_batch) - 1):
                    LOGGER.info(f"Predicting on batch {i + 1} / {len(ind_batch) - 1}")
                    cube_current = cube[ind_batch[i] : ind_batch[i + 1]]

                    pred = self.pred.eval(feed_dict={self.input_tensor: cube_current})
                    par_transformed[i] = self.par_pred_transformed.eval(
                        feed_dict={self.pred: pred}
                    )

                par_transformed = np.concatenate(par_transformed)

        return par_transformed

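
# Illustrative usage (assumptions, not from the original module): the path and
# batch size below are hypothetical. CNNPredictor expects the checkpoint files
# at `path_trained_cnn` plus the HDF5 config written during training (located
# via get_path_output_config below):
#
#     predictor = CNNPredictor("/path/to/trained_cnn")
#     stamps = np.zeros((256,) + predictor.input_shape, dtype=np.float32)
#     params = predictor(stamps, batchsize=64)
#     # params has shape (256, len(predictor.param_names))
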
def get_path_output_config(path_cnn):
    return path_cnn + ".h5"

def normalize_stamps(stamps):
    # Shift each stamp so that its minimum pixel value is zero
    stamps_min = tf.reduce_min(stamps, axis=(1, 2), keepdims=True)
    stamps -= stamps_min

    # Scale each stamp so that its maximum pixel value is one
    stamps_max = tf.reduce_max(stamps, axis=(1, 2), keepdims=True)
    stamps /= stamps_max

    return stamps
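
# Minimal self-contained check (added sketch, not part of the original module):
# verifies that normalize_stamps maps every stamp onto the [0, 1] range. Note
# that a perfectly constant stamp would lead to a division by zero above; the
# original code does not guard against that case.
if __name__ == "__main__":
    stamps_np = np.random.rand(4, 8, 8).astype(np.float32)
    stamps_norm = normalize_stamps(tf.constant(stamps_np))
    with tf.Session() as sess:
        out = sess.run(stamps_norm)
    assert np.allclose(out.min(axis=(1, 2)), 0.0)
    assert np.allclose(out.max(axis=(1, 2)), 1.0)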