# Coverage report artifact: src/edelweiss/nflow.py — 100% coverage (101 statements),
# generated by coverage.py v7.6.7 on 2024-11-18 17:09 +0000.
1# Copyright (C) 2023 ETH Zurich
2# Institute for Particle Physics and Astrophysics
3# Author: Silvan Fischbacher
5import os
6import pickle
8import joblib
9import numpy as np
10import pandas as pd
11from cosmic_toolbox import file_utils, logger
12from pzflow import Flow
14from edelweiss import nflow_utils
16LOGGER = logger.get_logger(__file__)
def load_nflow(path, band=None, subfolder=None):
    """
    Load a normalizing flow from a given path.

    :param path: path to the folder containing the emulator
    :param band: the band to load (if None, assumes that there is only one nflow)
    :param subfolder: subfolder of the emulator folder where the normalizing flow is
        stored
    :return: the loaded normalizing flow
    """
    folder = "nflow" if subfolder is None else subfolder
    if band is not None:
        folder = f"{folder}_{band}"
    output_directory = os.path.join(path, folder)

    # Restore the pickled wrapper first, then reattach the components that were
    # stripped out before pickling (the pzflow model and the two scalers).
    with open(os.path.join(output_directory, "nflow.pkl"), "rb") as fh:
        nflow = pickle.load(fh)
    nflow.flow = Flow(file=os.path.join(output_directory, "model.pkl"))
    scaler_dir = os.path.join(output_directory, "scalers")
    nflow.scaler_input = joblib.load(os.path.join(scaler_dir, "flow_scaler_input.pkl"))
    nflow.scaler_output = joblib.load(
        os.path.join(scaler_dir, "flow_scaler_output.pkl")
    )
    return nflow
class Nflow:
    """
    The normalizing flow class that wraps a pzflow normalizing flow.

    :param output: the names of the output parameters
    :param input: the names of the input parameters (=conditional parameters)
    :param scaler: the scaler to use for the normalizing flow
    """

    def __init__(self, output=None, input=None, scaler="standard"):
        """
        Initialize the normalizing flow.

        :param output: names of the output parameters (tuple, list or array);
            if None, it is inferred from the training data in :meth:`train`
        :param input: names of the conditional parameters; None means the flow
            is unconditional
        :param scaler: name of the scaler passed to nflow_utils.get_scalers
        """
        if isinstance(input, tuple):
            input = np.array(input)
        if isinstance(output, tuple):
            output = np.array(output)
        # Keep None (not []) on the instance: later code checks `is not None`
        # to distinguish conditional from unconditional flows.
        self.input = input
        self.output = output
        if input is None:
            input = []
        if output is None:
            output = []
        self.all_params = np.concatenate([input, output])
        self.scaler = scaler
        self.scaler_input, self.scaler_output = nflow_utils.get_scalers(scaler)

    def train(
        self,
        X,
        epochs=100,
        batch_size=1024,
        progress_bar=True,
        verbose=True,
        min_loss=5,
    ):
        """
        Train the normalizing flow.

        :param X: the features to train on (recarray)
        :param epochs: number of epochs
        :param batch_size: batch size
        :param progress_bar: whether to show a progress bar
        :param verbose: whether to print the losses
        :param min_loss: minimum loss that is allowed for convergence
        """
        self.epochs = epochs
        self.batch_size = batch_size
        LOGGER.info("==============================")
        LOGGER.info("Training normalizing flow with")
        LOGGER.info(f"{len(X)} samples and")
        LOGGER.info(f"conditional parameters: {self.input}")
        LOGGER.info(f"other parameters: {self.output}")
        LOGGER.info("==============================")

        X = pd.DataFrame(X)
        if self.input is not None:
            X[self.input] = self.scaler_input.fit_transform(X[self.input])
        if self.output is None:
            # No output names were given at construction time: treat every
            # column of the training data as an output parameter.
            self.output = X.columns
            self.all_params = self.output
        X[self.output] = self.scaler_output.fit_transform(X[self.output])

        self.flow = Flow(data_columns=self.output, conditional_columns=self.input)

        self.losses = self.flow.train(
            X,
            epochs=epochs,
            batch_size=batch_size,
            progress_bar=progress_bar,
            verbose=verbose,
        )
        nflow_utils.check_convergence(self.losses, min_loss=min_loss)
        LOGGER.info(
            "Training completed with best loss at"
            f" epoch {np.argmin(self.losses)}/{self.epochs}"
            f" with loss {np.min(self.losses):.2f}"
        )

    fit = train

    def sample(self, X=None, n_samples=1):
        """
        Sample from the normalizing flow.

        :param X: the features to sample from (recarray or None for non-conditional
            sampling)
        :param n_samples: number of samples to draw, number of total samples is
            n_samples * len(X)
        :return: the sampled features (including the conditional parameters)
        """

        if X is not None:
            params = X.dtype.names
            assert np.all(
                list(params) == self.input
            ), "Input parameters do not match the trained parameters"

            X = pd.DataFrame(X)
            X[self.input] = self.scaler_input.transform(X[self.input])

        f = self.flow.sample(n_samples, conditions=X)
        f = f.reindex(columns=self.all_params)

        # Find NaNs/infs and temporarily replace them with the column mean so
        # the inverse transforms below do not propagate bad values.
        f.replace([np.inf, -np.inf], np.nan, inplace=True)
        nan_inf_mask = f.isna()
        # Reuse the mask instead of recomputing f.isna() a second time.
        n_nans = nan_inf_mask.sum().sum()
        if n_nans > 0:
            LOGGER.warning(f"Found {n_nans} NaNs or infs in the sampled data")
            column_means = f.mean()
            f.fillna(column_means, inplace=True)

        # Inverse transform
        if self.input is not None:
            f[self.input] = self.scaler_input.inverse_transform(f[self.input])
        f[self.output] = self.scaler_output.inverse_transform(f[self.output])

        # Reintroduce NaNs so callers can detect the invalid samples.
        f[nan_inf_mask] = np.nan
        return f.to_records(index=False)

    __call__ = sample

    def save(self, path, band=None, subfolder=None):
        """
        Save the normalizing flow to a given path.

        The flow model and the scalers are written to separate files and
        temporarily detached from the instance while the wrapper itself is
        pickled; they are restored afterwards so the instance stays usable.

        :param path: path to the folder where the emulator is saved
        :param band: the band to save (appended to the subfolder name)
        :param subfolder: subfolder of the emulator folder where the normalizing flow is
            stored
        """
        if subfolder is None:
            subfolder = "nflow"
        if band is not None:
            subfolder = subfolder + "_" + band
        output_directory = os.path.join(path, subfolder)
        file_utils.robust_makedirs(output_directory)

        flow_path = os.path.join(output_directory, "model.pkl")
        self.flow.save(flow_path)
        # Detach the (unpicklable/large) components before pickling the wrapper.
        flow = self.flow
        self.flow = None
        LOGGER.debug(f"Flow saved to {flow_path}")

        scaler_path = os.path.join(output_directory, "scalers")
        file_utils.robust_makedirs(scaler_path)
        scaler_input_path = os.path.join(scaler_path, "flow_scaler_input.pkl")
        scaler_output_path = os.path.join(scaler_path, "flow_scaler_output.pkl")
        joblib.dump(self.scaler_input, scaler_input_path)
        joblib.dump(self.scaler_output, scaler_output_path)
        scaler_input = self.scaler_input
        scaler_output = self.scaler_output
        self.scaler_input = None
        self.scaler_output = None
        LOGGER.debug(f"Scalers saved to {scaler_path}")
        with open(os.path.join(output_directory, "nflow.pkl"), "wb") as f:
            pickle.dump(self, f)
        # Restore the detached components: previously they were left as None,
        # which made the instance unusable after saving.
        self.flow = flow
        self.scaler_input = scaler_input
        self.scaler_output = scaler_output
        LOGGER.info(f"Normalizing flow saved to {output_directory}")