Coverage for src/edelweiss/clf_utils.py: 100%
129 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-18 17:09 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-18 17:09 +0000
1# Copyright (C) 2022 ETH Zurich
2# Institute for Particle Physics and Astrophysics
3# Author: Silvan Fischbacher
5import numpy as np
6from catboost import CatBoostClassifier
7from cosmic_toolbox.logger import get_logger
8from lightgbm import LGBMClassifier
9from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
10from sklearn.ensemble import (AdaBoostClassifier, GradientBoostingClassifier,
11 RandomForestClassifier)
12from sklearn.linear_model import LogisticRegression
13from sklearn.metrics import make_scorer, roc_auc_score
14from sklearn.naive_bayes import GaussianNB
15from sklearn.neighbors import KNeighborsClassifier
16from sklearn.neural_network import MLPClassifier
17from sklearn.pipeline import Pipeline
18from sklearn.preprocessing import (MaxAbsScaler, MinMaxScaler,
19 QuantileTransformer, RobustScaler,
20 StandardScaler)
21from sklearn.svm import LinearSVC
22from sklearn.tree import DecisionTreeClassifier
23from xgboost import XGBClassifier
# Module-level logger; used below to warn about unknown classifiers and
# missing classifier entries in the config.
LOGGER = get_logger(__file__)
def get_clf_name(index=None):
    """
    Build the relative name of a cross-validation classifier file.

    :param index: index of the classifier, or None
    :return: ``"clf_cv/clf_<index>"``, or None when no index is given
    """
    return None if index is None else f"clf_cv/clf_{index}"
def get_classifier(classifier, scaler=None, **kwargs):
    """
    Build a pipeline made of a scaler followed by the requested classifier.

    :param classifier: name of the classifier, one of "RandomForest", "XGB",
        "MLP", "LogisticRegression", "LinearSVC", "DecisionTree", "AdaBoost",
        "KNN", "QDA", "GaussianNB", "NeuralNetwork", "GradientBoosting",
        "CatBoost", "LightGBM"
    :param scaler: scaler object used as the first pipeline step; a new
        RobustScaler is created when None
    :param kwargs: keyword arguments forwarded to the classifier constructor
    :return: classifier object (sklearn Pipeline with steps "scaler" and "clf")
    :raises: ValueError if classifier is not known
    """
    pipeline_scaler = RobustScaler() if scaler is None else scaler

    if classifier == "NeuralNetwork":
        # Import here to avoid tensorflow warnings when not using the classifier
        from .custom_clfs import NeuralNetworkClassifier

        estimator_cls = NeuralNetworkClassifier
    else:
        registry = (
            ("RandomForest", RandomForestClassifier),
            ("XGB", XGBClassifier),
            ("MLP", MLPClassifier),
            ("LogisticRegression", LogisticRegression),
            ("LinearSVC", LinearSVC),
            ("DecisionTree", DecisionTreeClassifier),
            ("AdaBoost", AdaBoostClassifier),
            ("KNN", KNeighborsClassifier),
            ("QDA", QuadraticDiscriminantAnalysis),
            ("GaussianNB", GaussianNB),
            ("GradientBoosting", GradientBoostingClassifier),
            ("CatBoost", CatBoostClassifier),
            ("LightGBM", LGBMClassifier),
        )
        estimator_cls = None
        # Linear scan with == (rather than a dict lookup) keeps the original
        # equality-based matching for any input type.
        for name, candidate in registry:
            if classifier == name:
                estimator_cls = candidate
                break
        if estimator_cls is None:
            raise ValueError(f"{classifier} not known")

    estimator = estimator_cls(**kwargs)
    return Pipeline([("scaler", pipeline_scaler), ("clf", estimator)])
def get_scaler(scaler):
    """
    Create a fresh scaler instance from its short name.

    :param scaler: name of the scaler, one of "standard", "minmax", "maxabs",
        "robust", "quantile"
    :return: scaler object
    :raises: ValueError if scaler is not known
    """
    known_scalers = (
        ("standard", StandardScaler),
        ("minmax", MinMaxScaler),
        ("maxabs", MaxAbsScaler),
        ("robust", RobustScaler),
        ("quantile", QuantileTransformer),
    )
    for name, scaler_cls in known_scalers:
        if scaler == name:
            return scaler_cls()
    raise ValueError(f"{scaler} not known")
def get_detection_label(clf, bands, n_detected_bands=None):
    """
    Get the detection label for the classifier.

    :param clf: classification data (rec array)
    :param bands: which bands the data has
    :param n_detected_bands: minimum number of bands in which an object must be
        detected to count as detected; if None, the precomputed "detected"
        column of clf is used as is
    :return: detection label and the list of column names it was derived from
        (bool array when n_detected_bands is given, otherwise clf["detected"])
    """
    if n_detected_bands is None:
        return clf["detected"], ["detected"]

    # Count, per object, in how many bands it was detected.
    n_detections = np.zeros(len(clf))
    det_labels = [f"detected_{band}" for band in bands]
    for label in det_labels:
        n_detections += clf[label]
    return n_detections >= n_detected_bands, det_labels
def get_scorer(score, **kwargs):
    """
    Returns the scorer object given input string.
    If not one of the known self defined scorers, returns the input string assuming
    it is a sklearn scorer.

    :param score: name of the scorer ("ngal", "roc_auc", or any scorer name
        understood by sklearn)
    :param kwargs: additional arguments for the scorer (currently unused)
    :return: scorer object for "ngal"/"roc_auc", otherwise ``score`` unchanged
    """
    if score == "ngal":
        # Squared error on the galaxy count: smaller is better.
        return make_scorer(ngal_scorer, greater_is_better=False)
    if score == "roc_auc":
        import inspect

        # scikit-learn 1.4 introduced ``response_method`` and deprecated
        # ``needs_proba``, which was removed in 1.6. Older versions do not know
        # ``response_method`` (it would silently be forwarded to the score
        # function), so pick the keyword the installed version actually accepts.
        if "response_method" in inspect.signature(make_scorer).parameters:
            return make_scorer(custom_roc_auc_score, response_method="predict_proba")
        return make_scorer(custom_roc_auc_score, needs_proba=True)
    return score
def load_hyperparams(clf):
    """
    Look up the hyperparameter grid used for the cross-validation search.

    Grid keys carry the ``clf__`` prefix, i.e. they address the "clf" step of
    the pipeline returned by get_classifier.

    :param clf: classifier object (the bare estimator; it is matched with
        isinstance against the estimator classes)
    :return: hyperparameter grid; an empty dict for GaussianNB and for
        unknown classifiers (the latter also logs a warning)
    """
    # Checked in order; the first matching entry wins. Built per call so each
    # caller receives fresh, independent grid objects.
    search_spaces = (
        ((LogisticRegression, LinearSVC), {"clf__C": [0.1, 1, 10, 100]}),
        (
            KNeighborsClassifier,
            {
                "clf__n_neighbors": [5, 10, 100, 250, 500, 750],
                "clf__weights": ["uniform", "distance"],
            },
        ),
        (
            DecisionTreeClassifier,
            {
                "clf__max_depth": [3, 5, 7, 9, 11],
                "clf__min_samples_split": [2, 4, 6, 8, 10],
            },
        ),
        (
            RandomForestClassifier,
            {
                "clf__n_estimators": [20, 50, 100],
                "clf__max_depth": [None, 10, 20, 30],
                "clf__min_samples_split": [4, 6, 8, 10, 12],
            },
        ),
        (
            XGBClassifier,
            {
                "clf__learning_rate": [0.01, 0.1, 0.5, 1],
                "clf__max_depth": [3, 5, 7, 9],
                "clf__n_estimators": [5, 10, 50, 100],
            },
        ),
        (
            MLPClassifier,
            {
                "clf__hidden_layer_sizes": [(10,), (100,), (250,), (500,), (750,)],
                "clf__alpha": [0.001, 0.01, 0.1],
            },
        ),
        (
            AdaBoostClassifier,
            {
                "clf__n_estimators": [1000, 5000],
                "clf__learning_rate": [0.01, 0.1],
            },
        ),
        (
            QuadraticDiscriminantAnalysis,
            {"clf__reg_param": [0.0, 0.01, 0.1, 1, 10]},
        ),
        (GaussianNB, {}),
        (
            GradientBoostingClassifier,
            {
                "clf__n_estimators": [100, 500],
                "clf__learning_rate": [0.01, 0.1],
                "clf__max_depth": [3, 5, 7],
            },
        ),
        (
            CatBoostClassifier,
            {
                "clf__learning_rate": [0.03, 0.06],
                "clf__depth": [3, 6, 9],
                "clf__l2_leaf_reg": [2, 3, 4],
                "clf__boosting_type": ["Ordered", "Plain"],
            },
        ),
        (
            LGBMClassifier,
            {
                "clf__num_leaves": [5, 20, 31],
                "clf__learning_rate": [0.05, 0.1, 0.2],
                "clf__n_estimators": [50, 100, 150],
            },
        ),
    )
    for clf_types, grid in search_spaces:
        if isinstance(clf, clf_types):
            return grid

    # The custom classifier module is only imported when none of the standard
    # estimators matched (it is not imported at module level).
    from .custom_clfs import NeuralNetworkClassifier

    if isinstance(clf, NeuralNetworkClassifier):
        return {
            "clf__hidden_units": [(32, 64, 32), (512, 512, 512), (10, 10)],
            "clf__learning_rate": [0.0001, 0.001],
            "clf__epochs": [1000],
            "clf__batch_size": [10000],
        }

    LOGGER.warning(f"Classifier {clf} not known.")
    return {}
def ngal_scorer(y_true, y_pred):
    """
    Scorer accounting for the number of galaxies in the sample.
    score = (N_pred - N_true)**2

    :param y_true: true labels (detected or not), array-like of 0/1 or bool
    :param y_pred: predicted labels (detected or not), array-like of 0/1 or bool
    :return: score (squared difference of the number of detected objects)
    """
    # np.sum counts in C; the builtin sum would iterate the arrays element by
    # element in Python, which is slow for large catalogs during CV scoring.
    return (np.sum(y_pred) - np.sum(y_true)) ** 2
def custom_roc_auc_score(y_true, y_prob):
    """
    Scorer for the ROC AUC score using predicted probabilities.

    :param y_true: true labels (detected or not)
    :param y_prob: predicted probabilities, either the 2D output of
        ``predict_proba`` (one column per class) or an already reduced 1D
        array of positive-class probabilities
    :return: score
    """
    y_prob = np.asarray(y_prob)
    if y_prob.ndim == 2:
        # Column 1 is the second entry of the estimator's classes_, i.e. the
        # positive ("detected") class for 0/1 or False/True labels.
        y_prob = y_prob[:, 1]
    # A 1D input (as sklearn's probability scorers may pass for binary
    # targets) is used as is instead of failing on the [:, 1] slice.
    return roc_auc_score(y_true, y_prob)
def ngal_hist_scorer(y_true, y_pred, mag, bins=100, range=(15, 30)):
    """
    Scorer accounting for the number of galaxies in the sample on a histogram level.
    score = (N_pred - N_true)**2, evaluated per magnitude bin

    :param y_true: true labels (detected or not), bool or 0/1 array-like
    :param y_pred: predicted labels (detected or not), bool or 0/1 array-like
    :param mag: magnitude of the galaxies
    :param bins: number of magnitude bins (passed to np.histogram)
    :param range: (min, max) magnitude range of the histogram (passed to
        np.histogram; the name shadows the builtin but is kept for callers)
    :return: array with the squared count difference for every bin
    """
    mag = np.asarray(mag)
    # Cast labels to bool masks: 0/1 integer arrays would otherwise be used as
    # integer indices, selecting mag[0] and mag[1] instead of masking.
    true_mask = np.asarray(y_true, dtype=bool)
    pred_mask = np.asarray(y_pred, dtype=bool)
    hist_true = np.histogram(mag[true_mask], bins=bins, range=range)[0]
    hist_pred = np.histogram(mag[pred_mask], bins=bins, range=range)[0]
    return (hist_pred - hist_true) ** 2
def get_classifier_args(clf, conf):
    """
    Look up the constructor arguments configured for a classifier.

    :param clf: classifier name
    :param conf: config dictionary
    :return: ``conf["classifier_args"][clf]``, or an empty dict (with a logged
        warning) if either key is missing
    """
    try:
        per_classifier_args = conf["classifier_args"]
        clf_args = per_classifier_args[clf]
    except KeyError:
        LOGGER.warning(f"Classifier {clf} not found in config file.")
        clf_args = {}
    return clf_args