Coverage for src/edelweiss/classifier.py: 100% (306 statements)

# Copyright (C) 2023 ETH Zurich
# Institute for Particle Physics and Astrophysics
# Author: Silvan Fischbacher

import contextlib
import os
import pickle
import warnings

import joblib
import numpy as np
import sklearn
from cosmic_toolbox import arraytools as at
from cosmic_toolbox import file_utils, logger
from packaging import version
from sklearn.calibration import CalibratedClassifierCV
from sklearn.model_selection import GridSearchCV

from edelweiss import clf_diagnostics, clf_utils
from edelweiss.compatibility_utils import fix_calibrated_classifier_compatibility

LOGGER = logger.get_logger(__file__)

def load_classifier(path, subfolder=None):
    """
    Load a classifier from a given path.

    :param path: path to the folder containing the emulator
    :param subfolder: subfolder of the emulator folder where the classifier is stored
    :return: the loaded classifier
    """
    if subfolder is None:
        subfolder = "clf"
    output_directory = os.path.join(path, subfolder)

    # Load classifier state
    with open(os.path.join(output_directory, "clf.pkl"), "rb") as f:
        clf = pickle.load(f)

    # Load pipeline with scikit-learn version compatibility handling
    with warnings.catch_warnings():
        # Suppress InconsistentVersionWarning for known compatibility issues
        warnings.filterwarnings(
            "ignore", category=sklearn.exceptions.InconsistentVersionWarning
        )
        clf.pipe = joblib.load(os.path.join(output_directory, "model.pkl"))

    # Fix calibrated classifier compatibility issues for sklearn >= 1.6
    if hasattr(clf.pipe, "calibrated_classifiers_") and version.parse(
        sklearn.__version__
    ) >= version.parse("1.6"):
        clf.pipe = fix_calibrated_classifier_compatibility(clf.pipe)

    # Load TF models if needed
    if clf.clf == "NeuralNetwork":
        import tensorflow as tf

        if clf.calibrate:
            # Load calibrated models
            for i, calibrated_clf in enumerate(clf.pipe.calibrated_classifiers_):
                try:
                    # First try loading as h5 (Keras 2 format)
                    model_path = os.path.join(
                        output_directory, f"tf_calibrated_model_{i}.h5"
                    )
                    if os.path.exists(model_path):
                        model = tf.keras.models.load_model(model_path)
                    else:  # pragma: no cover
                        # Try Keras 3 format
                        model = tf.keras.models.load_model(
                            os.path.join(
                                output_directory, f"tf_calibrated_model_{i}.keras"
                            )
                        )
                except Exception as e:  # pragma: no cover
                    LOGGER.error(f"Failed to load model: {e}")
                    raise

                if clf.cv > 1:
                    calibrated_clf.estimator.best_estimator_.named_steps[
                        "clf"
                    ].model = model
                else:
                    calibrated_clf.estimator.named_steps["clf"].model = model
        else:
            try:
                # First try loading as h5 (Keras 2 format)
                model_path = os.path.join(output_directory, "tf_model.h5")
                if os.path.exists(model_path):
                    model = tf.keras.models.load_model(model_path)
                else:  # pragma: no cover
                    # Try Keras 3 format
                    model = tf.keras.models.load_model(
                        os.path.join(output_directory, "tf_model.keras")
                    )
            except Exception as e:  # pragma: no cover
                LOGGER.error(f"Failed to load model: {e}")
                raise

            if clf.cv > 1:
                clf.pipe.best_estimator_.named_steps["clf"].model = model
            else:
                clf.pipe.named_steps["clf"].model = model

    return clf
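
# Example: loading a saved classifier (a minimal sketch; "path/to/emulator" is a
# placeholder and X_new stands for features matching the trained parameters):
#
#     clf = load_classifier("path/to/emulator")  # reads clf/clf.pkl and clf/model.pkl
#     y_prob = clf.predict_proba(X_new)          # calibrated detection probabilities
#     y_det = clf(X_new)                         # stochastic labels via __call__
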
class Classifier:
    """
    The detection classifier class that wraps a sklearn classifier.

    :param scaler: the scaler to use for the classifier, options: standard, minmax,
        maxabs, robust, quantile
    :param clf: the classifier to use, options are: XGB, MLP, RandomForest,
        NeuralNetwork, LogisticRegression, LinearSVC, DecisionTree, AdaBoost,
        GaussianNB, QDA, KNN
    :param calibrate: whether to calibrate the probabilities
    :param cv: number of cross validation folds, if 0 no cross validation is performed
    :param cv_scoring: the scoring method to use for cross validation
    :param params: the names of the parameters
    :param clf_kwargs: additional keyword arguments for the classifier
    """

    def __init__(
        self,
        scaler="standard",
        clf="XGB",
        calibrate=True,
        cv=0,
        cv_scoring="f1",
        params=None,
        **clf_kwargs,
    ):
        """
        Initialize the classifier.
        """
        self.scaler = scaler
        self.clf = clf
        sc = clf_utils.get_scaler(scaler)
        self.pipe = clf_utils.get_classifier(clf, sc, **clf_kwargs)
        self.calibrate = calibrate
        self.cv = cv
        self.cv_scoring = cv_scoring
        self.params = params
        self.test_scores = None

    def train(self, X, y, param_grid=None, **args):
        """
        Train the classifier.

        :param X: the features to train on (array or recarray)
        :param y: the labels to train on
        :param param_grid: the hyperparameter grid to search over
        :param args: additional arguments for the classifier
        """
        X = self._check_if_recarray(X)

        if self.params is None:
            self.params = np.arange(X.shape[1])
            LOGGER.warning("No parameter names provided, numbers are used instead")
        else:
            assert len(self.params) == X.shape[1], (
                "Number of parameters in training data does not match number"
                " of parameters provided before"
            )

        LOGGER.info("Training this model:")
        if self.calibrate:
            LOGGER.info("CalibratedClassifierCV")
        clf_names = self.pipe.named_steps.items()
        for name, estimator in clf_names:
            LOGGER.info(f"{name}:")
            LOGGER.info(estimator)
        LOGGER.info(f"number of samples: {X.shape[0]}")
        LOGGER.info("-------------------")

        # tune hyperparameters with grid search
        if self.cv > 1:
            LOGGER.info("Start cross validation")
            if param_grid is None:
                param_grid = clf_utils.load_hyperparams(self.pipe["clf"])
            scorer = clf_utils.get_scorer(self.cv_scoring)

            # Set up the grid search
            if "SLURM_CPUS_PER_TASK" in os.environ:  # pragma: no cover
                n_jobs = max(int(os.environ["SLURM_CPUS_PER_TASK"]) // 4, 1)
            else:
                n_jobs = 1
            LOGGER.info(f"Running the Grid search on {n_jobs} jobs")
            self.pipe = GridSearchCV(
                estimator=self.pipe,
                param_grid=param_grid,
                scoring=scorer,
                cv=self.cv,
                n_jobs=n_jobs,
            )

            if self.calibrate:
                self.pipe = CalibratedClassifierCV(
                    self.pipe, cv=self.cv, method="isotonic"
                )

            # Run the grid search
            self.pipe.fit(X, y, **args)

            if self.calibrate:
                best_params = self.pipe.calibrated_classifiers_[
                    0
                ].estimator.best_params_
            else:
                best_params = self.pipe.best_params_
            LOGGER.info("Best parameters found by grid search: %s", best_params)
        else:
            if self.calibrate:
                self.pipe = CalibratedClassifierCV(self.pipe, cv=2, method="isotonic")
            self.pipe.fit(X, y, **args)
        self._get_feature_importance()
        self._get_summed_feature_importance()
        LOGGER.info("Training completed")

    fit = train

    def predict(self, X, prob_multiplier=1.0):
        """
        Predict the labels for a given set of features.

        :param X: the features to predict on (array or recarray)
        :param prob_multiplier: factor applied to the predicted probabilities
            before sampling the labels
        :return: the predicted labels
        """
        X = self._check_if_recarray(X)
        y_prob = self.pipe.predict_proba(X)[:, 1] * prob_multiplier
        y_prob = np.clip(y_prob, 0, 1)
        y_pred = y_prob > np.random.rand(len(y_prob))
        return y_pred

    def predict_proba(self, X):
        """
        Predict the probabilities for a given set of features.

        :param X: the features to predict on (array or recarray)
        :return: the predicted probabilities
        """
        X = self._check_if_recarray(X)
        y_prob = self.pipe.predict_proba(X)[:, 1]
        return y_prob

    def predict_non_proba(self, X):
        """
        Predict the labels non-probabilistically for a given set of features.

        :param X: the features to predict on (array or recarray)
        :return: the predicted labels
        """
        X = self._check_if_recarray(X)
        y_pred = self.pipe.predict(X)
        return y_pred.astype(bool)

    __call__ = predict
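
    # Illustration of the three prediction modes (made-up numbers; for a binary
    # classifier, sklearn's predict() amounts to a 0.5 threshold on predict_proba):
    #   predict_proba(X)     -> [0.9, 0.2]     calibrated P(detected) per object
    #   predict(X)           -> Bernoulli draw per object (True with p=0.9, p=0.2)
    #   predict_non_proba(X) -> [True, False]  deterministic threshold
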
    def save(self, path, subfolder=None):
        """Save the classifier to a given path."""
        if subfolder is None:
            subfolder = "clf"
        output_directory = os.path.join(path, subfolder)
        file_utils.robust_makedirs(output_directory)

        # Create clean pipeline copy without TF models
        clean_pipe = self._create_clean_pipeline()

        # Save pipeline
        joblib.dump(clean_pipe, os.path.join(output_directory, "model.pkl"))

        # Save TF models if needed
        if self.clf == "NeuralNetwork":
            if self.calibrate:
                self._save_calibrated_tf_models(output_directory)
            else:
                # For GridSearchCV, save the best model
                if self.cv > 1:
                    base_model = self.pipe.best_estimator_.named_steps["clf"].model
                else:
                    base_model = self.pipe.named_steps["clf"].model

                # Save in both formats for compatibility
                base_model.save(os.path.join(output_directory, "tf_model.h5"))
                with contextlib.suppress(Exception):
                    base_model.save(
                        os.path.join(output_directory, "tf_model.keras"),
                        save_format="keras_v3",
                    )

        # Save classifier state
        self.pipe = None
        with open(os.path.join(output_directory, "clf.pkl"), "wb") as f:
            pickle.dump(self, f)
        LOGGER.info(f"Classifier saved to {output_directory}")

    def _create_clean_pipeline(self):
        """Create copy of pipeline without TF models."""
        import copy

        # Store model references
        if self.clf == "NeuralNetwork":
            # Temporarily remove models
            if self.calibrate:
                models = []
                for clf in self.pipe.calibrated_classifiers_:
                    if self.cv > 1:
                        model = clf.estimator.best_estimator_.named_steps["clf"].model
                        clf.estimator.best_estimator_.named_steps["clf"].model = None
                    else:
                        model = clf.estimator.named_steps["clf"].model
                        clf.estimator.named_steps["clf"].model = None
                    models.append(model)
            else:
                if self.cv > 1:
                    model = self.pipe.best_estimator_.named_steps["clf"].model
                    self.pipe.best_estimator_.named_steps["clf"].model = None
                else:
                    model = self.pipe.named_steps["clf"].model
                    self.pipe.named_steps["clf"].model = None

        # deepcopy without tf models
        clean_pipe = copy.deepcopy(self.pipe)

        # Restore original models
        if self.clf == "NeuralNetwork":
            if self.calibrate:
                for i, clf in enumerate(self.pipe.calibrated_classifiers_):
                    if self.cv > 1:
                        clf.estimator.best_estimator_.named_steps["clf"].model = models[
                            i
                        ]
                    else:
                        clf.estimator.named_steps["clf"].model = models[i]
            else:
                if self.cv > 1:
                    self.pipe.best_estimator_.named_steps["clf"].model = model
                else:
                    self.pipe.named_steps["clf"].model = model

        return clean_pipe

    def _save_calibrated_tf_models(self, output_directory):
        """Save calibrated TF models separately."""
        for i, clf in enumerate(self.pipe.calibrated_classifiers_):
            if self.cv > 1:
                model = clf.estimator.best_estimator_.named_steps["clf"].model
            else:
                model = clf.estimator.named_steps["clf"].model

            # Save in both formats for compatibility
            model.save(os.path.join(output_directory, f"tf_calibrated_model_{i}.h5"))
            with contextlib.suppress(Exception):
                model.save(
                    os.path.join(output_directory, f"tf_calibrated_model_{i}.keras"),
                    save_format="keras_v3",
                )

    def test(self, X_test, y_test, non_proba=False):
        """
        Tests the classifier on the test data and stores the scores in
        self.test_scores.

        :param X_test: test data
        :param y_test: test labels
        :param non_proba: whether to use non-probabilistic predictions
        """
        # get probability of being detected
        y_prob = self.predict_proba(X_test)
        y_pred = self.predict_non_proba(X_test) if non_proba else self.predict(X_test)
        test_arr = clf_diagnostics.setup_test()
        clf_diagnostics.get_all_scores(test_arr, y_test, y_pred, y_prob)
        test_arr = at.dict2rec(test_arr)
        self.test_scores = test_arr

    def _check_if_recarray(self, X):
        try:
            X, names = at.rec2arr(X, return_names=True)
            if self.params is None:
                self.params = names
            else:
                assert np.all(
                    names == self.params
                ), "Input parameters do not match the trained parameters"
            return X
        except Exception:
            return X

    def _get_feature_importance(self):
        with contextlib.suppress(Exception):
            # Try to get the feature importances if clf is GridSearchCV
            importances = self.pipe.best_estimator_["clf"].feature_importances_
            self.feature_importances = at.arr2rec(importances, self.params)
            return
        try:
            # Try to get the feature importances if clf is CalibratedClassifierCV
            importances = self.pipe.calibrated_classifiers_[
                0
            ].estimator._final_estimator.feature_importances_
            self.feature_importances = at.arr2rec(importances, self.params)
            return
        except Exception:
            with contextlib.suppress(Exception):
                # Try to get the feature importances if clf is CalibratedClassifierCV
                # and GridSearchCV
                importances = (
                    self.pipe.calibrated_classifiers_[0]
                    .estimator.best_estimator_.named_steps["clf"]
                    .feature_importances_
                )
                self.feature_importances = at.arr2rec(importances, self.params)
                return
        try:
            # Try to get the feature importances if clf is not GridSearchCV
            importances = self.pipe["clf"].feature_importances_
            self.feature_importances = at.arr2rec(importances, self.params)
            return
        except Exception:
            self.feature_importances = None

    def _get_summed_feature_importance(self):
        """
        Sum the feature importances of the same parameter across different bands.
        Should be run after _get_feature_importance.
        """
        feature_importances = self.feature_importances
        if feature_importances is None:
            self.summed_feature_importances = None
            return

        # Create a dictionary to store the summed values based on modified prefixes
        summed_features = {}

        # Iterate through the feature importances; strip the suffix when it is a
        # single character (e.g. a band name, so that "mag_r" and "mag_i" both
        # contribute to "mag")
        for key in feature_importances.dtype.names:
            parts = key.split("_")
            prefix = (
                "_".join(parts[:-1]) if len(parts[-1]) == 1 and len(parts) > 1 else key
            )

            if prefix not in summed_features:
                summed_features[prefix] = 0.0

            summed_features[prefix] += feature_importances[key]

        self.summed_feature_importances = at.dict2rec(summed_features)
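
# Example workflow for Classifier (a minimal sketch with synthetic data; the
# feature names "mag" and "size" are hypothetical):
#
#     import numpy as np
#     rng = np.random.default_rng(0)
#     X = rng.normal(size=(1000, 2))
#     y = X[:, 0] + 0.1 * rng.normal(size=1000) > 0
#
#     clf = Classifier(scaler="standard", clf="XGB", params=["mag", "size"])
#     clf.fit(X, y)
#     clf.test(X, y)                # fills clf.test_scores
#     clf.save("path/to/emulator")  # note: save() sets clf.pipe to None
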
def load_multiclassifier(path, subfolder=None):
    """
    Load a multiclassifier from a given path.

    :param path: path to the folder containing the emulator
    :param subfolder: subfolder of the emulator folder where the classifier is stored
    :return: the loaded classifier
    """
    if subfolder is None:
        subfolder = "clf"
    output_directory = os.path.join(path, subfolder)
    with open(os.path.join(output_directory, "clf.pkl"), "rb") as f:
        clf = pickle.load(f)
    clf.pipe = []
    for label in clf.labels:
        clf.pipe.append(load_classifier(path, subfolder=f"{subfolder}_{label}"))
    LOGGER.debug(f"Classifier loaded from {output_directory}")
    return clf
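
# A minimal sketch of the multiclassifier round trip (the path is a placeholder):
#
#     mclf = load_multiclassifier("path/to/emulator")
#     # mclf.pipe is rebuilt as one Classifier per label, each loaded from the
#     # subfolder "clf_<label>" written by MultiClassifier.save
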
class MultiClassifier:
    """
    A classifier class that trains multiple classifiers for a specific label. This
    label could e.g. be the galaxy type (star, red galaxy, blue galaxy).

    :param split_label: the label used to split the data into different classifiers
    :param labels: the different labels of the split label
    :param scaler: the scaler to use for the classifier
    :param clf: the classifier to use
    :param calibrate: whether to calibrate the probabilities
    :param cv: number of cross validation folds, if 0 no cross validation is performed
    :param cv_scoring: the scoring method to use for cross validation
    :param params: the names of the parameters
    :param clf_kwargs: additional keyword arguments for the classifier
    """

    def __init__(
        self,
        split_label="galaxy_type",
        labels=None,
        scaler="standard",
        clf="XGB",
        calibrate=True,
        cv=0,
        cv_scoring="f1",
        params=None,
        **clf_kwargs,
    ):
        """
        Initialize the classifier.
        """
        if labels is None:
            labels = [-1, 0, 1]
        self.split_label = split_label
        self.labels = labels
        self.scaler = scaler
        self.clf = clf
        self.pipe = [
            Classifier(
                scaler=scaler,
                clf=clf,
                calibrate=calibrate,
                cv=cv,
                cv_scoring=cv_scoring,
                params=params,
                **clf_kwargs,
            )
            for _ in self.labels
        ]

    def train(self, X, y):
        """
        Train one classifier per label of the split label.

        :param X: the features to train on (recarray containing the split label)
        :param y: the labels to train on
        """
        # TODO: dirty hack, fix this
        self.params = X.dtype.names
        for i, label in enumerate(self.labels):
            idx = X[self.split_label] == label
            X_ = at.delete_cols(X[idx], self.split_label)
            self.pipe[i].train(X_, y[idx])

    fit = train

    def predict(self, X):
        """
        Predict the labels for a given set of features.
        """
        y_pred = np.zeros(len(X), dtype=bool)
        for i, label in enumerate(self.labels):
            idx = X[self.split_label] == label
            if np.sum(idx) == 0:
                continue
            X_ = at.delete_cols(X[idx], self.split_label)
            y_pred[idx] = self.pipe[i].predict(X_)
        return y_pred

    def predict_proba(self, X):
        """
        Predict the probabilities for a given set of features.
        """
        y_prob = np.zeros(len(X), dtype=float)
        for i, label in enumerate(self.labels):
            idx = X[self.split_label] == label
            if np.sum(idx) == 0:
                continue
            # drop the split label before passing to the sub-classifier,
            # consistent with predict and predict_non_proba
            X_ = at.delete_cols(X[idx], self.split_label)
            y_prob[idx] = self.pipe[i].predict_proba(X_)
        return y_prob

    def predict_non_proba(self, X):
        """
        Predict the labels non-probabilistically for a given set of features.
        """
        y_pred = np.zeros(len(X), dtype=bool)
        for i, label in enumerate(self.labels):
            idx = X[self.split_label] == label
            if np.sum(idx) == 0:
                continue
            X_ = at.delete_cols(X[idx], self.split_label)
            y_pred[idx] = self.pipe[i].predict_non_proba(X_)
        return y_pred

    __call__ = predict

    def save(self, path, subfolder=None):
        """
        Save the classifier to a given path.

        :param path: path to the folder where the emulator is saved
        :param subfolder: subfolder of the emulator folder where the classifier is
            stored
        """
        if subfolder is None:
            subfolder = "clf"
        output_directory = os.path.join(path, subfolder)
        file_utils.robust_makedirs(output_directory)
        for i, label in enumerate(self.labels):
            self.pipe[i].save(path, subfolder=f"{subfolder}_{label}")
        self.pipe = None
        with open(os.path.join(output_directory, "clf.pkl"), "wb") as f:
            pickle.dump(self, f)
        LOGGER.info(f"MultiClassifier saved to {output_directory}")

    def test(self, X_test, y_test, non_proba=False):
        """
        Tests the classifier on the test data and stores the scores in
        self.test_scores.

        :param X_test: test data
        :param y_test: test labels
        :param non_proba: whether to use non-probabilistic predictions
        """
        # get probability of being detected
        y_prob = self.predict_proba(X_test)
        y_pred = self.predict_non_proba(X_test) if non_proba else self.predict(X_test)
        test_arr = clf_diagnostics.setup_test()
        clf_diagnostics.get_all_scores(test_arr, y_test, y_pred, y_prob)
        test_arr = at.dict2rec(test_arr)
        self.test_scores = test_arr
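
# Example: training a MultiClassifier (a minimal sketch; the recarray fields,
# including the split label "galaxy_type", are hypothetical):
#
#     import numpy as np
#     X = np.zeros(6, dtype=[("mag", float), ("size", float), ("galaxy_type", int)])
#     X["galaxy_type"] = [-1, 0, 1, -1, 0, 1]
#     y = np.array([True, False, True, False, True, False])
#
#     mclf = MultiClassifier(split_label="galaxy_type", labels=[-1, 0, 1])
#     mclf.fit(X, y)    # trains one Classifier per galaxy type
#     y_pred = mclf(X)  # routes each row to the classifier of its galaxy type
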
class MultiClassClassifier(Classifier):
    """
    The detection classifier class that wraps a sklearn classifier for multiple
    classes.

    :param scaler: the scaler to use for the classifier, options: standard, minmax,
        maxabs, robust, quantile
    :param clf: the classifier to use, options are: XGB, MLP, RandomForest,
        NeuralNetwork, LogisticRegression, LinearSVC, DecisionTree, AdaBoost,
        GaussianNB, QDA, KNN
    :param calibrate: whether to calibrate the probabilities
    :param cv: number of cross validation folds, if 0 no cross validation is performed
    :param cv_scoring: the scoring method to use for cross validation
    :param params: the names of the parameters
    :param clf_kwargs: additional keyword arguments for the classifier
    """

    def predict(self, X):
        """
        Predict the labels for a given set of features by sampling from the
        predicted class probabilities.

        :param X: the features to predict on (array or recarray)
        :return: the predicted labels
        """
        X = self._check_if_recarray(X)
        y_prob = self.pipe.predict_proba(X)
        y_pred = np.array([np.random.choice(len(prob), p=prob) for prob in y_prob])

        return y_pred

    def predict_proba(self, X):
        """
        Predict the probabilities for a given set of features.

        :param X: the features to predict on (array or recarray)
        :return: the predicted probabilities
        """
        X = self._check_if_recarray(X)
        y_prob = self.pipe.predict_proba(X)
        return y_prob

    def predict_non_proba(self, X):
        """
        Predict the class non-probabilistically for a given set of features.

        :param X: the features to predict on (array or recarray)
        :return: the predicted classes
        """
        X = self._check_if_recarray(X)
        y_pred = self.pipe.predict(X)
        return y_pred

    def test(self, X_test, y_test, non_proba=False):
        """
        Tests the classifier on the test data and stores the scores in
        self.test_scores.

        :param X_test: test data
        :param y_test: test labels
        :param non_proba: whether to use non-probabilistic predictions
        """
        y_pred = self.predict_non_proba(X_test) if non_proba else self.predict(X_test)
        y_prob = self.predict_proba(X_test)

        test_arr = clf_diagnostics.setup_test(multi_class=True)
        clf_diagnostics.get_all_scores_multiclass(test_arr, y_test, y_pred, y_prob)
        test_arr = at.dict2rec(test_arr)
        self.test_scores = test_arr
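
# Example: multi-class prediction modes (a minimal sketch; assumes mcc is a
# trained MultiClassClassifier and X_new matches the trained parameters):
#
#     y_prob = mcc.predict_proba(X_new)      # shape (n_samples, n_classes)
#     y_pred = mcc.predict(X_new)            # class index sampled per row with p=y_prob
#     y_hard = mcc.predict_non_proba(X_new)  # most probable class via sklearn predict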