Coverage for src/edelweiss/regressor.py: 100%
112 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-18 17:09 +0000
1# Copyright (C) 2024 ETH Zurich
2# Institute for Particle Physics and Astrophysics
3# Author: Silvan Fischbacher
4# created: Sat Jan 27 2024
7import os
8import pickle
10import joblib
11import numpy as np
12from cosmic_toolbox import arraytools as at
13from cosmic_toolbox import file_utils, logger
14from sklearn.utils.class_weight import compute_sample_weight
16from edelweiss import clf_utils, reg_utils
18LOGGER = logger.get_logger(__file__)
def load_regressor(path, name="regressor"):
    """
    Load a regressor from a given path.

    The sklearn pipeline is stored separately (model.pkl, via joblib) from
    the rest of the pickled regressor object; both parts are combined here.

    :param path: path to the folder containing the regressor
    :param name: the name of the regressor
    :return: the loaded regressor
    """
    pipeline = joblib.load(os.path.join(path, "model.pkl"))
    regressor_file = os.path.join(path, name + ".pkl")
    with open(regressor_file, "rb") as f:
        regressor = pickle.load(f)
    regressor.pipe = pipeline
    LOGGER.info(f"Regressor loaded from {path}")
    return regressor
class Regressor:
    """
    Wrapper class for several regression models.

    :param scaler: the scaler to use for the regressor
    :param reg: the regressor to use
    :param cv: number of cross validation folds, if 0 no cross validation
        is performed
    :param cv_scoring: the scoring method to use for cross validation
    :param input_params: the names of the input parameters
    :param output_params: the names of the output parameters
    :param reg_kwargs: additional keyword arguments for the regressor
    """

    def __init__(
        self,
        scaler="standard",
        reg="linear",
        cv=0,
        cv_scoring="neg_mean_squared_error",
        input_params=None,
        output_params=None,
        **reg_kwargs,
    ):
        """
        Initialize the regressor.
        """
        self.scaler = scaler
        sc = clf_utils.get_scaler(scaler)
        # Separate scaler instance for the targets; the input scaler is
        # part of the pipeline.
        self.y_scaler = clf_utils.get_scaler(scaler)
        self.reg = reg
        self.pipe = reg_utils.get_regressor(reg, sc, **reg_kwargs)
        self.cv = cv
        self.cv_scoring = cv_scoring
        self.reg_kwargs = reg_kwargs
        self.input_params = input_params
        self.output_params = output_params
        self._regressor = None
        self._scaler = None
        # Absolute test metrics, filled in by test()
        self.mad = None
        self.mse = None
        self.max_error = None
        # Relative test metrics, also filled in by test(); initialized here
        # so that the attributes always exist.
        self.relative_mad = None
        self.relative_mse = None
        self.relative_max_error = None

    def train(self, X, y, flat_param=None, **args):
        """
        Train the regressor.

        :param X: the training data
        :param y: the training labels
        :param flat_param: name of an input or output parameter whose
            distribution is flattened using balanced sample weights
            (optional)
        :param args: additional keyword arguments passed to the pipeline's
            fit method
        :raises ValueError: if flat_param matches neither an input nor an
            output parameter
        """
        X, y = self._check_if_recarray(X, y)

        # Accept 1D targets (single-output regression): sklearn scalers and
        # the column indexing below require a 2D array.
        y = np.asarray(y)
        if y.ndim == 1:
            y = y.reshape(-1, 1)

        if self.input_params is None:
            self.input_params = [f"param_{i}" for i in range(X.shape[1])]
        if self.output_params is None:
            self.output_params = [f"param_{i}" for i in range(y.shape[1])]

        LOGGER.info("Training regressor")
        LOGGER.info(f"Input parameters: {self.input_params}")
        LOGGER.info(f"Output parameters: {self.output_params}")
        LOGGER.info(f"Number of training samples: {X.shape[0]}")

        if self.cv > 1:
            # Cross validation is configured but not supported; training
            # continues without it.
            LOGGER.error("Cross validation not implemented yet")

        y = self.y_scaler.fit_transform(y)

        sample_weight = None
        if flat_param is not None:
            flat_param_index = np.where(np.array(self.input_params) == flat_param)[0]
            if len(flat_param_index) == 0:
                # flat param is an output param
                flat_param_index = np.where(np.array(self.output_params) == flat_param)[
                    0
                ]
                if len(flat_param_index) == 0:
                    # Previously an unknown name slipped through and produced
                    # a confusing downstream sklearn error; fail fast instead.
                    raise ValueError(
                        f"flat_param {flat_param} is neither an input nor an "
                        "output parameter"
                    )
                sample_weight = compute_sample_weight(
                    "balanced", y[:, flat_param_index[0]]
                )
            else:
                # flat param is an input param
                flat_param_index = flat_param_index[0]
                sample_weight = compute_sample_weight(
                    "balanced", X[:, flat_param_index]
                )
            self.pipe.fit(X, y, reg__sample_weight=sample_weight, **args)
        else:
            self.pipe.fit(X, y, **args)

    # Alias so the regressor follows the sklearn fit() convention.
    fit = train

    def _predict(self, X):
        """
        Predict the output from the input.

        :param X: the input data
        :return: the predicted output as an array
        """
        X, _ = self._check_if_recarray(X, None)
        # atleast_2d so a single-output prediction can go through the
        # inverse transform of the 2D target scaler.
        y = np.atleast_2d(self.pipe.predict(X))
        return self.y_scaler.inverse_transform(y)

    def predict(self, X):
        """
        Predict the output from the input.

        :param X: the input data
        :return: the predicted output as a recarray
        """
        y = self._predict(X)
        return at.arr2rec(y, names=self.output_params)

    # Make the regressor directly callable.
    __call__ = predict

    def test(self, X, y):
        """
        Test the regressor and store the metrics.

        Sets mad, mse, max_error and their relative counterparts
        (normalized by the scale of the true values) as recarrays keyed by
        output parameter.

        :param X: the test data
        :param y: the test labels
        """
        X, y = self._check_if_recarray(X, y)
        LOGGER.info("Testing regressor")
        LOGGER.info(f"Number of test samples: {X.shape[0]}")
        y_pred = self._predict(X)
        mad = np.mean(np.abs(y - y_pred), axis=0)
        mse = np.mean((y - y_pred) ** 2, axis=0)
        max_error = np.max(np.abs(y - y_pred), axis=0)
        self.mad = at.arr2rec(mad, names=self.output_params)
        self.mse = at.arr2rec(mse, names=self.output_params)
        self.max_error = at.arr2rec(max_error, names=self.output_params)
        LOGGER.info(f"max MAD: {max(mad)}")
        LOGGER.info(f"max MSE: {max(mse)}")
        LOGGER.info(f"max Max error: {max(max_error)}")

        relative_mad = mad / np.mean(np.abs(y), axis=0)
        relative_mse = mse / np.mean(y**2, axis=0)
        relative_max_error = max_error / np.max(np.abs(y), axis=0)
        self.relative_mad = at.arr2rec(relative_mad, names=self.output_params)
        self.relative_mse = at.arr2rec(relative_mse, names=self.output_params)
        self.relative_max_error = at.arr2rec(
            relative_max_error, names=self.output_params
        )
        LOGGER.info(f"max relative MAD: {max(relative_mad)}")
        LOGGER.info(f"max relative MSE: {max(relative_mse)}")
        LOGGER.info(f"max relative Max error: {max(relative_max_error)}")

    def save(self, path, name="regressor"):
        """
        Save the regressor to a given path.

        The sklearn pipeline is stored separately with joblib (model.pkl);
        the rest of the object is pickled. load_regressor reverses this.

        :param path: the path where to save the regressor
        :param name: the name of the regressor
        """
        file_utils.robust_makedirs(path)
        joblib.dump(self.pipe, os.path.join(path, "model.pkl"))
        with open(os.path.join(path, name + ".pkl"), "wb") as f:
            pickle.dump(self, f)
        LOGGER.info(f"Regressor saved to {path}")

    def _check_if_recarray(self, X, y=None):
        """
        Convert recarray input/labels to plain arrays and check the names.

        If X (and y) are not recarrays, they are returned unchanged.

        :param X: the input data (recarray or plain array)
        :param y: the labels (recarray, plain array or None)
        :return: (X, y) as plain arrays
        :raises AssertionError: if the recarray field names do not match
            the stored input/output parameters
        """
        try:
            X, x_names = at.rec2arr(X, return_names=True)
            if self.input_params is None:
                self.input_params = x_names
            if y is not None:
                y, y_names = at.rec2arr(y, return_names=True)
                if self.output_params is None:
                    self.output_params = y_names
        except Exception:
            # is already a normal array
            return X, y
        assert self.input_params == x_names, "Input parameters do not match"
        if y is not None:
            assert self.output_params == y_names, "Output parameters do not match"
        return X, y