Coverage for src/edelweiss/nflow.py: 100%

101 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-18 17:09 +0000

1# Copyright (C) 2023 ETH Zurich 

2# Institute for Particle Physics and Astrophysics 

3# Author: Silvan Fischbacher 

4 

5import os 

6import pickle 

7 

8import joblib 

9import numpy as np 

10import pandas as pd 

11from cosmic_toolbox import file_utils, logger 

12from pzflow import Flow 

13 

14from edelweiss import nflow_utils 

15 

16LOGGER = logger.get_logger(__file__) 

17 

18 

def load_nflow(path, band=None, subfolder=None):
    """
    Load a normalizing flow from a given path.

    :param path: path to the folder containing the emulator
    :param band: the band to load (if None, assumes that there is only one nflow)
    :param subfolder: subfolder of the emulator folder where the normalizing flow is
        stored
    :return: the loaded normalizing flow
    """
    folder = "nflow" if subfolder is None else subfolder
    if band is not None:
        folder = folder + "_" + band
    directory = os.path.join(path, folder)

    # Restore the pickled wrapper first, then re-attach the pieces that were
    # persisted separately by Nflow.save: the pzflow model and the two scalers.
    with open(os.path.join(directory, "nflow.pkl"), "rb") as f:
        nflow = pickle.load(f)
    nflow.flow = Flow(file=os.path.join(directory, "model.pkl"))
    scaler_dir = os.path.join(directory, "scalers")
    nflow.scaler_input = joblib.load(os.path.join(scaler_dir, "flow_scaler_input.pkl"))
    nflow.scaler_output = joblib.load(
        os.path.join(scaler_dir, "flow_scaler_output.pkl")
    )
    return nflow

44 

45 

class Nflow:
    """
    The normalizing flow class that wraps a pzflow normalizing flow.

    :param output: the names of the output parameters
    :param input: the names of the input parameters (=conditional parameters)
    :param scaler: the scaler to use for the normalizing flow
    """

    def __init__(self, output=None, input=None, scaler="standard"):
        """
        Initialize the normalizing flow.
        """
        # Tuples are converted to arrays so downstream pandas/numpy column
        # indexing behaves consistently.
        if isinstance(input, tuple):
            input = np.array(input)
        if isinstance(output, tuple):
            output = np.array(output)
        self.input = input
        self.output = output
        self.all_params = np.concatenate(
            [
                input if input is not None else [],
                output if output is not None else [],
            ]
        )
        self.scaler = scaler
        self.scaler_input, self.scaler_output = nflow_utils.get_scalers(scaler)

    def train(
        self,
        X,
        epochs=100,
        batch_size=1024,
        progress_bar=True,
        verbose=True,
        min_loss=5,
    ):
        """
        Train the normalizing flow.

        :param X: the features to train on (recarray)
        :param epochs: number of epochs
        :param batch_size: batch size
        :param progress_bar: whether to show a progress bar
        :param verbose: whether to print the losses
        :param min_loss: minimum loss that is allowed for convergence
        """
        self.epochs = epochs
        self.batch_size = batch_size
        LOGGER.info("==============================")
        LOGGER.info("Training normalizing flow with")
        LOGGER.info(f"{len(X)} samples and")
        LOGGER.info(f"conditional parameters: {self.input}")
        LOGGER.info(f"other parameters: {self.output}")
        LOGGER.info("==============================")

        X = pd.DataFrame(X)
        # Fit the scalers on the training data; `transform` is reused later
        # when sampling.
        if self.input is not None:
            X[self.input] = self.scaler_input.fit_transform(X[self.input])
        if self.output is None:
            # No explicit output parameters: model every column of X.
            self.output = X.columns
            self.all_params = self.output
        X[self.output] = self.scaler_output.fit_transform(X[self.output])

        self.flow = Flow(data_columns=self.output, conditional_columns=self.input)

        self.losses = self.flow.train(
            X,
            epochs=epochs,
            batch_size=batch_size,
            progress_bar=progress_bar,
            verbose=verbose,
        )
        nflow_utils.check_convergence(self.losses, min_loss=min_loss)
        best_epoch = np.argmin(self.losses)
        LOGGER.info(
            "Training completed with best loss at"
            f" epoch {best_epoch}/{self.epochs}"
            f" with loss {np.min(self.losses):.2f}"
        )

    # Alias so the object follows the familiar scikit-learn naming.
    fit = train

    def sample(self, X=None, n_samples=1):
        """
        Sample from the normalizing flow.

        :param X: the features to sample from (recarray or None for non-conditional
            sampling)
        :param n_samples: number of samples to draw, number of total samples is
            n_samples * len(X)
        :return: the sampled features (including the conditional parameters)
        """
        if X is not None:
            assert np.all(
                list(X.dtype.names) == self.input
            ), "Input parameters do not match the trained parameters"

            X = pd.DataFrame(X)
            X[self.input] = self.scaler_input.transform(X[self.input])

        samples = self.flow.sample(n_samples, conditions=X)
        samples = samples.reindex(columns=self.all_params)

        # Turn infinities into NaNs, remember where they were, and fill with
        # column means so the inverse transform below stays finite.
        samples.replace([np.inf, -np.inf], np.nan, inplace=True)
        nan_inf_mask = samples.isna()
        n_nans = samples.isna().sum().sum()
        if n_nans > 0:
            LOGGER.warning(f"Found {n_nans} NaNs or infs in the sampled data")
        samples.fillna(samples.mean(), inplace=True)

        # Undo the scaling applied during training.
        if self.input is not None:
            samples[self.input] = self.scaler_input.inverse_transform(
                samples[self.input]
            )
        samples[self.output] = self.scaler_output.inverse_transform(
            samples[self.output]
        )

        # Restore NaNs where the raw samples were invalid.
        samples[nan_inf_mask] = np.nan
        return samples.to_records(index=False)

    # Calling the instance directly draws samples.
    __call__ = sample

    def save(self, path, band=None, subfolder=None):
        """
        Save the normalizing flow to a given path.

        :param path: path to the folder where the emulator is saved
        :param subfolder: subfolder of the emulator folder where the normalizing flow is
            stored
        """
        folder = "nflow" if subfolder is None else subfolder
        if band is not None:
            folder = folder + "_" + band
        output_directory = os.path.join(path, folder)
        file_utils.robust_makedirs(output_directory)

        # The pzflow model and the fitted scalers are written to separate
        # files and detached from the wrapper before the wrapper itself is
        # pickled; load_nflow re-attaches them.
        flow_path = os.path.join(output_directory, "model.pkl")
        self.flow.save(flow_path)
        self.flow = None
        LOGGER.debug(f"Flow saved to {flow_path}")

        scaler_path = os.path.join(output_directory, "scalers")
        file_utils.robust_makedirs(scaler_path)
        joblib.dump(self.scaler_input, os.path.join(scaler_path, "flow_scaler_input.pkl"))
        joblib.dump(self.scaler_output, os.path.join(scaler_path, "flow_scaler_output.pkl"))
        self.scaler_input = None
        self.scaler_output = None
        LOGGER.debug(f"Scalers saved to {scaler_path}")

        with open(os.path.join(output_directory, "nflow.pkl"), "wb") as f:
            pickle.dump(self, f)
        LOGGER.info(f"Normalizing flow saved to {output_directory}")