Coverage for src/edelweiss/classifier.py: 99%
237 statements
coverage.py v7.6.7, created at 2024-11-18 17:09 +0000
# Copyright (C) 2023 ETH Zurich
# Institute for Particle Physics and Astrophysics
# Author: Silvan Fischbacher

import os
import pickle

import joblib
import numpy as np
from cosmic_toolbox import arraytools as at
from cosmic_toolbox import file_utils, logger
from sklearn.calibration import CalibratedClassifierCV
from sklearn.model_selection import GridSearchCV

from edelweiss import clf_diagnostics, clf_utils

LOGGER = logger.get_logger(__file__)
def load_classifier(path, subfolder=None):
    """
    Load a classifier from a given path.

    :param path: path to the folder containing the emulator
    :param subfolder: subfolder of the emulator folder where the classifier is stored
    :return: the loaded classifier
    """
    if subfolder is None:
        subfolder = "clf"
    output_directory = os.path.join(path, subfolder)
    with open(os.path.join(output_directory, "clf.pkl"), "rb") as f:
        clf = pickle.load(f)
    clf.pipe = joblib.load(os.path.join(output_directory, "model.pkl"))
    LOGGER.debug(f"Classifier loaded from {output_directory}")
    return clf
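# Illustrative usage sketch (paths and subfolder names are placeholders): load a
# classifier that was previously stored with Classifier.save, which writes
# clf.pkl (the object metadata) and model.pkl (the sklearn pipeline) into the
# subfolder.
#
#   clf = load_classifier("path/to/emulator")                      # subfolder "clf"
#   clf = load_classifier("path/to/emulator", subfolder="clf_star")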
class Classifier:
    """
    The detection classifier class that wraps a sklearn classifier.

    :param scaler: the scaler to use for the classifier, options: standard, minmax,
        maxabs, robust, quantile
    :param clf: the classifier to use, options are: XGB, MLP, RandomForest,
        NeuralNetwork, LogisticRegression, LinearSVC, DecisionTree, AdaBoost,
        GaussianNB, QDA, KNN
    :param calibrate: whether to calibrate the probabilities
    :param cv: number of cross validation folds, if 0 no cross validation is performed
    :param cv_scoring: the scoring method to use for cross validation
    :param params: the names of the parameters
    :param clf_kwargs: additional keyword arguments for the classifier
    """

    def __init__(
        self,
        scaler="standard",
        clf="XGB",
        calibrate=True,
        cv=0,
        cv_scoring="f1",
        params=None,
        **clf_kwargs,
    ):
        """
        Initialize the classifier.
        """
        self.scaler = scaler
        self.clf = clf
        sc = clf_utils.get_scaler(scaler)
        self.pipe = clf_utils.get_classifier(clf, sc, **clf_kwargs)
        self.calibrate = calibrate
        self.cv = cv
        self.cv_scoring = cv_scoring
        self.params = params
        self.test_scores = None
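    # Illustrative sketch (values are placeholders): a calibrated XGBoost
    # classifier with 5-fold grid-search cross validation; per the docstring,
    # extra keyword arguments such as max_depth are forwarded to the
    # underlying estimator.
    #
    #   clf = Classifier(scaler="robust", clf="XGB", calibrate=True,
    #                    cv=5, cv_scoring="f1", max_depth=3)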
    def train(self, X, y, **args):
        """
        Train the classifier.

        :param X: the features to train on (array or recarray)
        :param y: the labels to train on
        :param args: additional arguments for the classifier
        """
        X = self._check_if_recarray(X)

        if self.params is None:
            self.params = np.arange(X.shape[1])
            LOGGER.warning("No parameter names provided, numbers are used instead")
        else:
            assert len(self.params) == X.shape[1], (
                "Number of parameters in training data does not match number"
                " of parameters provided before"
            )

        LOGGER.info("Training this model:")
        if self.calibrate:
            LOGGER.info("CalibratedClassifierCV")
        clf_names = self.pipe.named_steps.items()
        for name, estimator in clf_names:
            LOGGER.info(f"{name}:")
            LOGGER.info(estimator)
        LOGGER.info(f"number of samples: {X.shape[0]}")
        LOGGER.info("-------------------")

        # tune hyperparameters with grid search
        if self.cv > 1:
            LOGGER.info("Start cross validation")
            param_grid = clf_utils.load_hyperparams(self.pipe["clf"])
            scorer = clf_utils.get_scorer(self.cv_scoring)

            # Set up the grid search
            if "SLURM_CPUS_PER_TASK" in os.environ:  # pragma: no cover
                n_jobs = max(int(os.environ["SLURM_CPUS_PER_TASK"]) // 4, 1)
            else:
                n_jobs = 1
            LOGGER.info(f"Running the Grid search on {n_jobs} jobs")
            self.pipe = GridSearchCV(
                estimator=self.pipe,
                param_grid=param_grid,
                scoring=scorer,
                cv=self.cv,
                n_jobs=n_jobs,
            )

            if self.calibrate:
                self.pipe = CalibratedClassifierCV(
                    self.pipe, cv=self.cv, method="isotonic"
                )

            # Run the grid search
            self.pipe.fit(X, y, **args)

            if self.calibrate:
                best_params = self.pipe.calibrated_classifiers_[
                    0
                ].estimator.best_params_
            else:
                best_params = self.pipe.best_params_
            LOGGER.info("Best parameters found by grid search: %s", best_params)
        else:
            if self.calibrate:
                self.pipe = CalibratedClassifierCV(self.pipe, cv=2, method="isotonic")
            self.pipe.fit(X, y, **args)
        self._get_feature_importance()
        self._get_summed_feature_importance()
        LOGGER.info("Training completed")

    fit = train
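    # Illustrative sketch (random placeholder data): training on a recarray
    # keeps the column names as parameter names, while a plain 2d array falls
    # back to numeric parameter names.
    #
    #   X = np.rec.fromarrays(
    #       [np.random.rand(100), np.random.rand(100)], names=["mag_r", "mag_i"]
    #   )
    #   y = np.random.rand(100) > 0.5
    #   clf = Classifier()
    #   clf.fit(X, y)  # alias for clf.train(X, y)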
    def predict(self, X, prob_multiplier=1.0):
        """
        Predict the labels for a given set of features.

        :param X: the features to predict on (array or recarray)
        :param prob_multiplier: factor applied to the predicted probabilities
            before sampling the detection decision
        :return: the predicted labels
        """
        X = self._check_if_recarray(X)
        y_prob = self.pipe.predict_proba(X)[:, 1] * prob_multiplier
        y_prob = np.clip(y_prob, 0, 1)
        y_pred = y_prob > np.random.rand(len(y_prob))
        return y_pred
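    # Note: predict samples detections stochastically, i.e. each object is
    # detected with its predicted probability (one Bernoulli draw per object)
    # rather than by thresholding at 0.5; prob_multiplier rescales the
    # probabilities before sampling. Illustrative sketch:
    #
    #   y_pred = clf.predict(X)                       # stochastic detection
    #   y_pred = clf.predict(X, prob_multiplier=0.9)  # globally lower detection rate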
    def predict_proba(self, X):
        """
        Predict the probabilities for a given set of features.

        :param X: the features to predict on (array or recarray)
        :return: the predicted probabilities
        """
        X = self._check_if_recarray(X)
        y_prob = self.pipe.predict_proba(X)[:, 1]
        return y_prob
    def predict_non_proba(self, X):
        """
        Predict the labels non-probabilistically for a given set of features.

        :param X: the features to predict on (array or recarray)
        :return: the predicted labels
        """
        X = self._check_if_recarray(X)
        y_pred = self.pipe.predict(X)
        return y_pred.astype(bool)

    __call__ = predict
    def save(self, path, subfolder=None):
        """
        Save the classifier to a given path.

        :param path: path to the folder where the emulator is saved
        :param subfolder: subfolder of the emulator folder where the classifier is
            stored
        """
        if subfolder is None:
            subfolder = "clf"
        output_directory = os.path.join(path, subfolder)
        file_utils.robust_makedirs(output_directory)
        joblib.dump(self.pipe, os.path.join(output_directory, "model.pkl"))
        self.pipe = None
        with open(os.path.join(output_directory, "clf.pkl"), "wb") as f:
            pickle.dump(self, f)
        LOGGER.info(f"Classifier saved to {output_directory}")
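    # Illustrative round trip ("path/to/emulator" is a placeholder): save splits
    # the object into model.pkl (the sklearn pipeline, via joblib) and clf.pkl
    # (the rest of the object, via pickle); load_classifier reassembles them.
    # Note that save sets self.pipe to None on the saved instance.
    #
    #   clf.save("path/to/emulator")
    #   clf = load_classifier("path/to/emulator")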
    def test(self, X_test, y_test, non_proba=False):
        """
        Tests the classifier on the test data and stores the scores in
        self.test_scores.

        :param X_test: test data
        :param y_test: test labels
        :param non_proba: whether to use non-probabilistic predictions
        """
        # get probability of being detected
        y_prob = self.predict_proba(X_test)
        y_pred = self.predict_non_proba(X_test) if non_proba else self.predict(X_test)
        test_arr = clf_diagnostics.setup_test()
        clf_diagnostics.get_all_scores(test_arr, y_test, y_pred, y_prob)
        test_arr = at.dict2rec(test_arr)
        self.test_scores = test_arr
    def _check_if_recarray(self, X):
        try:
            X, names = at.rec2arr(X, return_names=True)
            if self.params is None:
                self.params = names
            else:
                assert np.all(
                    names == self.params
                ), "Input parameters do not match the trained parameters"
            return X
        except Exception:
            return X
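    # Note: _check_if_recarray converts a recarray to a plain array and records
    # or validates the column names against self.params; any input that
    # at.rec2arr cannot convert (presumably e.g. a plain ndarray without named
    # fields) raises inside the try block and is returned unchanged.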
    def _get_feature_importance(self):
        try:
            # Try to get the feature importances if clf is GridSearchCV
            importances = self.pipe.best_estimator_["clf"].feature_importances_
            self.feature_importances = at.arr2rec(importances, self.params)
            return
        except Exception:
            pass
        try:
            # Try to get the feature importances if clf is CalibratedClassifierCV
            importances = self.pipe.calibrated_classifiers_[
                0
            ].estimator._final_estimator.feature_importances_
            self.feature_importances = at.arr2rec(importances, self.params)
            return
        except Exception:
            try:
                # Try to get the feature importances if clf is CalibratedClassifierCV
                # and GridSearchCV
                importances = (
                    self.pipe.calibrated_classifiers_[0]
                    .estimator.best_estimator_.named_steps["clf"]
                    .feature_importances_
                )
                self.feature_importances = at.arr2rec(importances, self.params)
                return
            except Exception:
                pass
        try:
            # Try to get the feature importances if clf is not GridSearchCV
            importances = self.pipe["clf"].feature_importances_
            self.feature_importances = at.arr2rec(importances, self.params)
            return
        except Exception:
            self.feature_importances = None
    def _get_summed_feature_importance(self):
        """
        Sum the feature importances of the same parameter across different bands.
        Should be run after _get_feature_importance.
        """
        feature_importances = self.feature_importances
        if feature_importances is None:
            self.summed_feature_importances = None
            return

        # Create a dictionary to store the summed values based on modified prefixes
        summed_features = {}

        # Iterate through the feature importances and drop the suffix (e.g. a
        # band name) when it is a single character
        for key in feature_importances.dtype.names:
            parts = key.split("_")
            prefix = (
                "_".join(parts[:-1]) if len(parts[-1]) == 1 and len(parts) > 1 else key
            )

            if prefix not in summed_features:
                summed_features[prefix] = 0.0

            summed_features[prefix] += feature_importances[key]

        self.summed_feature_importances = at.dict2rec(summed_features)
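    # Illustrative example of the prefix summing above (parameter names are
    # placeholders): single-character suffixes are treated as bands and their
    # importances are summed per base parameter, e.g.
    #
    #   {"mag_r": 0.2, "mag_i": 0.3, "size": 0.5}  ->  {"mag": 0.5, "size": 0.5}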
def load_multiclassifier(path, subfolder=None):
    """
    Load a multiclassifier from a given path.

    :param path: path to the folder containing the emulator
    :param subfolder: subfolder of the emulator folder where the classifier is stored
    :return: the loaded classifier
    """
    if subfolder is None:
        subfolder = "clf"
    output_directory = os.path.join(path, subfolder)
    with open(os.path.join(output_directory, "clf.pkl"), "rb") as f:
        clf = pickle.load(f)
    clf.pipe = []
    for label in clf.labels:
        clf.pipe.append(load_classifier(path, subfolder=f"{subfolder}_{label}"))
    LOGGER.debug(f"Classifier loaded from {output_directory}")
    return clf
class MultiClassifier:
    """
    A classifier class that trains multiple classifiers for a specific label. This
    label could e.g. be the galaxy type (star, red galaxy, blue galaxy).

    :param split_label: the label used to split the data between the different
        classifiers
    :param labels: the different labels of the split label
    :param scaler: the scaler to use for the classifier
    :param clf: the classifier to use
    :param calibrate: whether to calibrate the probabilities
    :param cv: number of cross validation folds, if 0 no cross validation is performed
    :param cv_scoring: the scoring method to use for cross validation
    :param params: the names of the parameters
    :param clf_kwargs: additional keyword arguments for the classifier
    """

    def __init__(
        self,
        split_label="galaxy_type",
        labels=None,
        scaler="standard",
        clf="XGB",
        calibrate=True,
        cv=0,
        cv_scoring="f1",
        params=None,
        **clf_kwargs,
    ):
        """
        Initialize the classifier.
        """
        if labels is None:
            labels = [-1, 0, 1]
        self.split_label = split_label
        self.labels = labels
        self.scaler = scaler
        self.clf = clf
        self.pipe = [
            Classifier(
                scaler=scaler,
                clf=clf,
                calibrate=calibrate,
                cv=cv,
                cv_scoring=cv_scoring,
                params=params,
                **clf_kwargs,
            )
            for _ in self.labels
        ]
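    # Illustrative sketch (labels are placeholders): one sub-classifier per
    # galaxy type, all configured identically.
    #
    #   mclf = MultiClassifier(split_label="galaxy_type", labels=[-1, 0, 1],
    #                          clf="XGB", calibrate=True)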
    def train(self, X, y):
        """
        Train the classifier.

        :param X: the features to train on (recarray, must contain split_label)
        :param y: the labels to train on
        """
        # TODO: dirty hack, fix this
        self.params = X.dtype.names
        for i, label in enumerate(self.labels):
            idx = X[self.split_label] == label
            X_ = at.delete_cols(X[idx], self.split_label)
            self.pipe[i].train(X_, y[idx])

    fit = train
    def predict(self, X):
        """
        Predict the labels for a given set of features.
        """
        y_pred = np.zeros(len(X), dtype=bool)
        for i, label in enumerate(self.labels):
            idx = X[self.split_label] == label
            if np.sum(idx) == 0:
                continue
            X_ = at.delete_cols(X[idx], self.split_label)
            y_pred[idx] = self.pipe[i].predict(X_)
        return y_pred
    def predict_proba(self, X):
        """
        Predict the probabilities for a given set of features.
        """
        y_prob = np.zeros(len(X), dtype=float)
        for i, label in enumerate(self.labels):
            idx = X[self.split_label] == label
            if np.sum(idx) == 0:
                continue
            # drop the split label, consistent with predict and predict_non_proba
            X_ = at.delete_cols(X[idx], self.split_label)
            y_prob[idx] = self.pipe[i].predict_proba(X_)
        return y_prob
    def predict_non_proba(self, X):
        """
        Predict the labels non-probabilistically for a given set of features.
        """
        y_pred = np.zeros(len(X), dtype=bool)
        for i, label in enumerate(self.labels):
            idx = X[self.split_label] == label
            if np.sum(idx) == 0:
                continue
            X_ = at.delete_cols(X[idx], self.split_label)
            y_pred[idx] = self.pipe[i].predict_non_proba(X_)
        return y_pred

    __call__ = predict
    def save(self, path, subfolder=None):
        """
        Save the classifier to a given path.

        :param path: path to the folder where the emulator is saved
        :param subfolder: subfolder of the emulator folder where the classifier is
            stored
        """
        if subfolder is None:
            subfolder = "clf"
        output_directory = os.path.join(path, subfolder)
        file_utils.robust_makedirs(output_directory)
        for i, label in enumerate(self.labels):
            self.pipe[i].save(path, subfolder=f"{subfolder}_{label}")
        self.pipe = None
        with open(os.path.join(output_directory, "clf.pkl"), "wb") as f:
            pickle.dump(self, f)
        LOGGER.info(f"MultiClassifier saved to {output_directory}")
    def test(self, X_test, y_test, non_proba=False):
        """
        Tests the classifier on the test data and stores the scores in
        self.test_scores.

        :param X_test: test data
        :param y_test: test labels
        :param non_proba: whether to use non-probabilistic predictions
        """
        # get probability of being detected
        y_prob = self.predict_proba(X_test)
        y_pred = self.predict_non_proba(X_test) if non_proba else self.predict(X_test)
        test_arr = clf_diagnostics.setup_test()
        clf_diagnostics.get_all_scores(test_arr, y_test, y_pred, y_prob)
        test_arr = at.dict2rec(test_arr)
        self.test_scores = test_arr
class MultiClassClassifier(Classifier):
    """
    The detection classifier class that wraps a sklearn classifier for multiple
    classes.

    :param scaler: the scaler to use for the classifier, options: standard, minmax,
        maxabs, robust, quantile
    :param clf: the classifier to use, options are: XGB, MLP, RandomForest,
        NeuralNetwork, LogisticRegression, LinearSVC, DecisionTree, AdaBoost,
        GaussianNB, QDA, KNN
    :param calibrate: whether to calibrate the probabilities
    :param cv: number of cross validation folds, if 0 no cross validation is performed
    :param cv_scoring: the scoring method to use for cross validation
    :param params: the names of the parameters
    :param clf_kwargs: additional keyword arguments for the classifier
    """

    def predict(self, X):
        """
        Predict the labels for a given set of features.

        :param X: the features to predict on (array or recarray)
        :return: the predicted labels
        """
        X = self._check_if_recarray(X)
        y_prob = self.pipe.predict_proba(X)
        y_pred = np.array([np.random.choice(len(prob), p=prob) for prob in y_prob])

        return y_pred
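    # Note: for each object, predict samples the class label from the predicted
    # class distribution instead of taking the argmax. Illustrative sketch:
    #
    #   prob = np.array([0.1, 0.7, 0.2])
    #   label = np.random.choice(len(prob), p=prob)  # returns 1 with 70% probability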
    def predict_proba(self, X):
        """
        Predict the probabilities for a given set of features.

        :param X: the features to predict on (array or recarray)
        :return: the predicted probabilities
        """
        X = self._check_if_recarray(X)
        y_prob = self.pipe.predict_proba(X)
        return y_prob

    def predict_non_proba(self, X):
        """
        Predict the class non-probabilistically for a given set of features.

        :param X: the features to predict on (array or recarray)
        :return: the predicted classes
        """
        X = self._check_if_recarray(X)
        y_pred = self.pipe.predict(X)
        return y_pred
    def test(self, X_test, y_test, non_proba=False):
        """
        Tests the classifier on the test data and stores the scores in
        self.test_scores.

        :param X_test: test data
        :param y_test: test labels
        :param non_proba: whether to use non-probabilistic predictions
        """
        y_pred = self.predict_non_proba(X_test) if non_proba else self.predict(X_test)
        y_prob = self.predict_proba(X_test)

        test_arr = clf_diagnostics.setup_test(multi_class=True)
        clf_diagnostics.get_all_scores_multiclass(test_arr, y_test, y_pred, y_prob)
        test_arr = at.dict2rec(test_arr)
        self.test_scores = test_arr
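# Illustrative end-to-end sketch (data and paths are placeholders): train, test,
# persist, and reload a detection classifier.
#
#   clf = Classifier(clf="XGB", calibrate=True, cv=5)
#   clf.fit(X_train, y_train)
#   clf.test(X_test, y_test)        # fills clf.test_scores
#   clf.save("path/to/emulator")
#   clf = load_classifier("path/to/emulator")
#   detected = clf(X_new)           # __call__ = predict, stochastic decisions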