Coverage for src/edelweiss/nflow_utils.py: 100%
79 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-18 17:09 +0000
1# Copyright (C) 2023 ETH Zurich
2# Institute for Particle Physics and Astrophysics
3# Author: Silvan Fischbacher
5import numpy as np
6from cosmic_toolbox import arraytools as at
7from cosmic_toolbox import file_utils
8from cosmic_toolbox.logger import get_logger
9from sklearn.preprocessing import (MaxAbsScaler, MinMaxScaler,
10 PowerTransformer, QuantileTransformer,
11 RobustScaler, StandardScaler)
# Module-level logger for this file; get_logger is provided by cosmic_toolbox.
LOGGER = get_logger(__file__)
class ModelNotConvergedError(Exception):
    """
    Exception raised when a model has not converged.
    """

    def __init__(self, model_name, reason=None):
        """
        Build the error message and initialize the exception.

        :param model_name: name of the model that did not converge
        :param reason: reason why the model did not converge
        """
        msg = f"The {model_name} model did not converge."
        if reason is not None:
            msg = msg + f" Reason: {reason}"
        super().__init__(msg)
def check_convergence(losses, min_loss=5):
    """
    Check if the model has converged.

    :param losses: list of losses (one per epoch); may contain NaN entries
    :param min_loss: minimum loss, if the loss is higher than this,
        the model has not converged
    :raises ModelNotConvergedError: if the model has not converged
        (best loss too high, or final loss NaN/inf/-inf)
    """
    # Hoist the reduction: the original computed np.nanmin(losses) twice
    # (once for the comparison, once for the error message).
    best_loss = np.nanmin(losses)
    if best_loss > min_loss:
        raise ModelNotConvergedError(
            "normalizing flow", reason=f"loss too high: {best_loss}"
        )
    # Convergence is judged on the final loss value; keep the three cases
    # separate because each produces a distinct error message.
    if np.isnan(losses[-1]):
        raise ModelNotConvergedError("normalizing flow", reason="loss is NaN")
    if losses[-1] == np.inf:
        raise ModelNotConvergedError("normalizing flow", reason="loss is inf")
    if losses[-1] == -np.inf:
        raise ModelNotConvergedError("normalizing flow", reason="loss is -inf")
def prepare_columns(args, bands=None):
    """
    Prepare the columns for the training of the normalizing flow.

    :param args: argparse arguments (must provide config_path)
    :param bands: list of bands to use, if None, no bands are used
    :return: tuple (input columns, output columns) as lists of str
    """
    conf = file_utils.read_yaml(args.config_path)
    # Renamed from `input`/`output` to avoid shadowing the builtin `input`.
    input_cols = []
    output_cols = []
    # Band-dependent input parameters get one column per band (e.g. "mag_r").
    for par in conf["input_band_dep"]:
        if bands is not None:
            for band in bands:
                input_cols.append(f"{par}_{band}")
        else:
            input_cols.append(par)
    # Band-independent inputs are used as-is.
    for par in conf["input_band_indep"]:
        input_cols.append(par)
    # Output parameters are expanded per band the same way as band-dependent inputs.
    for par in conf["output"]:
        if bands is not None:
            for band in bands:
                output_cols.append(f"{par}_{band}")
        else:
            output_cols.append(par)
    return input_cols, output_cols
def prepare_data(args, X):
    """
    Prepare the data for the training of the normalizing flow by combining the different
    bands to one array.

    :param args: argparse arguments (must provide config_path and bands)
    :param X: dictionary with the data (keys are the bands); the per-band
        catalogs may be modified in place (a "flux_frac" column can be added)
    :return: rec array with the data
    """
    # Prefer the project implementation; fall back to a local re-implementation
    # if legacy_abc is not installed.
    try:
        from legacy_abc.analysis.mmd_scalings.mmd_scalings_utils import \
            add_fluxfrac_col
    except Exception:
        LOGGER.warning(
            "Could not import add_fluxfrac_col from legacy_abc."
            " Check if the function used here is correct."
        )

        def add_fluxfrac_col(catalogs, list_bands):
            # Fallback: compute each band's fraction of the total FLUX_APER
            # across all bands and store it as a "flux_frac" column.
            list_flux = []
            for band in list_bands:
                list_flux.append(catalogs[band]["FLUX_APER"])
            flux_total = np.array(list_flux).sum(axis=0)

            for band in list_bands:
                # Add the column only if it does not exist yet, then (re)fill it.
                if "flux_frac" not in catalogs[band].dtype.names:
                    catalogs[band] = at.add_cols(catalogs[band], names=["flux_frac"])
                catalogs[band]["flux_frac"] = catalogs[band]["FLUX_APER"] / flux_total

    conf = file_utils.read_yaml(args.config_path)

    # Only compute flux fractions when the config actually asks for them.
    if "flux_frac" in conf["output"]:
        add_fluxfrac_col(X, args.bands)
    data = {}
    # Band-independent inputs: taken from the first band's catalog
    # (assumes these columns are identical across bands — TODO confirm).
    for par in conf["input_band_indep"]:
        data[par] = X[list(X.keys())[0]][par]
    # Band-dependent inputs and outputs: one column per (parameter, band) pair.
    for par in conf["input_band_dep"]:
        for band in args.bands:
            data[f"{par}_{band}"] = X[band][par]
    for par in conf["output"]:
        for band in args.bands:
            data[f"{par}_{band}"] = X[band][par]
    return at.dict2rec(data)
def get_scalers(scaler):
    """
    Get the scalers from the name.

    :param scaler: name of the scaler (str)
    :return: a pair of freshly constructed scaler instances
    :raises ValueError: if the scaler is not implemented
    """
    # Dispatch table instead of an if/elif chain; values are the classes,
    # instantiated twice below so the two returned scalers are independent.
    scaler_classes = {
        "robust": RobustScaler,
        "power": PowerTransformer,
        "standard": StandardScaler,
        "minmax": MinMaxScaler,
        "maxabs": MaxAbsScaler,
        "quantile": QuantileTransformer,
    }
    if scaler not in scaler_classes:
        raise ValueError(f"Scaler {scaler} not implemented yet.")
    scaler_cls = scaler_classes[scaler]
    return scaler_cls(), scaler_cls()