Coverage for src/edelweiss/nflow_utils.py: 100%

79 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-18 17:09 +0000

1# Copyright (C) 2023 ETH Zurich 

2# Institute for Particle Physics and Astrophysics 

3# Author: Silvan Fischbacher 

4 

5import numpy as np 

6from cosmic_toolbox import arraytools as at 

7from cosmic_toolbox import file_utils 

8from cosmic_toolbox.logger import get_logger 

9from sklearn.preprocessing import (MaxAbsScaler, MinMaxScaler, 

10 PowerTransformer, QuantileTransformer, 

11 RobustScaler, StandardScaler) 

12 

# Module-level logger; note that __file__ (not __name__) is passed, so log
# records are tagged with the file path.
LOGGER = get_logger(__file__)

14 

15 

class ModelNotConvergedError(Exception):
    """
    Raised when a model has not converged.
    """

    def __init__(self, model_name, reason=None):
        """
        Build the error message for a non-converged model.

        :param model_name: name of the model that did not converge
        :param reason: optional explanation of why the model did not converge
        """
        parts = [f"The {model_name} model did not converge."]
        if reason is not None:
            parts.append(f"Reason: {reason}")
        super().__init__(" ".join(parts))

32 

33 

def check_convergence(losses, min_loss=5):
    """
    Check if the model has converged.

    :param losses: list/array of losses recorded during training
    :param min_loss: minimum loss, if the lowest loss is higher than this,
        the model has not converged
    :raises ModelNotConvergedError: if the model has not converged
    """
    # Best loss seen during training; nanmin ignores NaN entries so a few
    # diverged epochs do not mask an otherwise converged run.
    best_loss = np.nanmin(losses)
    if best_loss > min_loss:
        raise ModelNotConvergedError(
            "normalizing flow", reason=f"loss too high: {best_loss}"
        )
    # The final loss must additionally be a finite number.
    final_loss = losses[-1]
    if np.isnan(final_loss):
        raise ModelNotConvergedError("normalizing flow", reason="loss is NaN")
    if final_loss == np.inf:
        raise ModelNotConvergedError("normalizing flow", reason="loss is inf")
    if final_loss == -np.inf:
        raise ModelNotConvergedError("normalizing flow", reason="loss is -inf")

54 

55 

def prepare_columns(args, bands=None):
    """
    Prepare the columns for the training of the normalizing flow.

    :param args: argparse arguments (must provide ``config_path``)
    :param bands: list of bands to use, if None, no bands are used
    :return: tuple of (input columns, output columns)
    """
    conf = file_utils.read_yaml(args.config_path)

    def _expand(par):
        # Band-dependent parameters get one column per band ("par_band");
        # without bands the parameter name is used as-is.
        if bands is None:
            return [par]
        return [f"{par}_{band}" for band in bands]

    # Renamed from "input"/"output" to avoid shadowing the builtin input().
    input_cols = []
    for par in conf["input_band_dep"]:
        input_cols.extend(_expand(par))
    # Band-independent input parameters are never expanded per band.
    for par in conf["input_band_indep"]:
        input_cols.append(par)

    output_cols = []
    for par in conf["output"]:
        output_cols.extend(_expand(par))
    return input_cols, output_cols

85 

86 

def prepare_data(args, X):
    """
    Prepare the data for the training of the normalizing flow by combining the different
    bands to one array.

    :param args: argparse arguments (must provide ``config_path`` and ``bands``)
    :param X: dictionary with the data (keys are the bands)
    :return: rec array with the data
    """
    try:
        # Preferred implementation from the legacy_abc package, if installed.
        from legacy_abc.analysis.mmd_scalings.mmd_scalings_utils import \
            add_fluxfrac_col
    except Exception:
        LOGGER.warning(
            "Could not import add_fluxfrac_col from legacy_abc."
            " Check if the function used here is correct."
        )

        # Fallback re-implementation used when legacy_abc is unavailable.
        # Mutates `catalogs` in place, adding a "flux_frac" column per band.
        def add_fluxfrac_col(catalogs, list_bands):
            # Total FLUX_APER per object, summed over all bands.
            list_flux = []
            for band in list_bands:
                list_flux.append(catalogs[band]["FLUX_APER"])
            flux_total = np.array(list_flux).sum(axis=0)

            # flux_frac = fraction of the total flux contributed by each band.
            for band in list_bands:
                if "flux_frac" not in catalogs[band].dtype.names:
                    catalogs[band] = at.add_cols(catalogs[band], names=["flux_frac"])
                catalogs[band]["flux_frac"] = catalogs[band]["FLUX_APER"] / flux_total

    conf = file_utils.read_yaml(args.config_path)

    # flux_frac is a derived column; compute it (in place on X) only when the
    # config requests it as an output.
    if "flux_frac" in conf["output"]:
        add_fluxfrac_col(X, args.bands)
    data = {}
    # Band-independent inputs are taken from an arbitrary (first) band's
    # catalog — presumably identical across bands; verify against the config.
    for par in conf["input_band_indep"]:
        data[par] = X[list(X.keys())[0]][par]
    # Band-dependent inputs and outputs get one column per band, suffixed with
    # the band name (same naming scheme as prepare_columns).
    for par in conf["input_band_dep"]:
        for band in args.bands:
            data[f"{par}_{band}"] = X[band][par]
    for par in conf["output"]:
        for band in args.bands:
            data[f"{par}_{band}"] = X[band][par]
    return at.dict2rec(data)

130 

131 

def get_scalers(scaler):
    """
    Get the scalers from the name.

    Returns two fresh, independent instances of the selected scaler class.

    :param scaler: name of the scaler (str)
    :return: tuple of two scaler instances
    :raises ValueError: if the scaler is not implemented
    """
    # Select the class first, instantiate at the end; the class names are
    # only evaluated on the branch that is actually taken.
    if scaler == "robust":
        scaler_cls = RobustScaler
    elif scaler == "power":
        scaler_cls = PowerTransformer
    elif scaler == "standard":
        scaler_cls = StandardScaler
    elif scaler == "minmax":
        scaler_cls = MinMaxScaler
    elif scaler == "maxabs":
        scaler_cls = MaxAbsScaler
    elif scaler == "quantile":
        scaler_cls = QuantileTransformer
    else:
        raise ValueError(f"Scaler {scaler} not implemented yet.")
    return scaler_cls(), scaler_cls()