Coverage for src/edelweiss/clf_utils.py: 100%

129 statements  


# Copyright (C) 2022 ETH Zurich
# Institute for Particle Physics and Astrophysics
# Author: Silvan Fischbacher

import numpy as np
from catboost import CatBoostClassifier
from cosmic_toolbox.logger import get_logger
from lightgbm import LGBMClassifier
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
from sklearn.ensemble import (AdaBoostClassifier, GradientBoostingClassifier,
                              RandomForestClassifier)
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import make_scorer, roc_auc_score
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import (MaxAbsScaler, MinMaxScaler,
                                   QuantileTransformer, RobustScaler,
                                   StandardScaler)
from sklearn.svm import LinearSVC
from sklearn.tree import DecisionTreeClassifier
from xgboost import XGBClassifier

LOGGER = get_logger(__file__)


def get_clf_name(index=None):
    """
    Returns the name of the classifier file.

    :param index: index of the classifier
    :return: name of the classifier file
    """
    if index is None:
        return None
    return f"clf_cv/clf_{index}"


def get_classifier(classifier, scaler=None, **kwargs):
    """
    Returns the classifier object

    :param classifier: name of the classifier
    :param scaler: scaler object
    :param kwargs: additional arguments for the classifier
    :return: classifier object (sklearn pipeline)
    :raises: ValueError if classifier is not known
    """
    if scaler is None:
        scaler = RobustScaler()
    if classifier == "RandomForest":
        clf = RandomForestClassifier(**kwargs)
    elif classifier == "XGB":
        clf = XGBClassifier(**kwargs)
    elif classifier == "MLP":
        clf = MLPClassifier(**kwargs)
    elif classifier == "LogisticRegression":
        clf = LogisticRegression(**kwargs)
    elif classifier == "LinearSVC":
        clf = LinearSVC(**kwargs)
    elif classifier == "DecisionTree":
        clf = DecisionTreeClassifier(**kwargs)
    elif classifier == "AdaBoost":
        clf = AdaBoostClassifier(**kwargs)
    elif classifier == "KNN":
        clf = KNeighborsClassifier(**kwargs)
    elif classifier == "QDA":
        clf = QuadraticDiscriminantAnalysis(**kwargs)
    elif classifier == "GaussianNB":
        clf = GaussianNB(**kwargs)
    elif classifier == "NeuralNetwork":
        # Import here to avoid tensorflow warnings when not using the classifier
        from .custom_clfs import NeuralNetworkClassifier

        clf = NeuralNetworkClassifier(**kwargs)
    elif classifier == "GradientBoosting":
        clf = GradientBoostingClassifier(**kwargs)
    elif classifier == "CatBoost":
        clf = CatBoostClassifier(**kwargs)
    elif classifier == "LightGBM":
        clf = LGBMClassifier(**kwargs)
    else:
        raise ValueError(f"{classifier} not known")
    return Pipeline([("scaler", scaler), ("clf", clf)])
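
For orientation, a minimal usage sketch (not part of the module): it assumes the module is importable as edelweiss.clf_utils and uses placeholder data; the classifier name and keyword arguments are illustrative only.

# Illustrative sketch: build and fit a scaled random-forest pipeline.
import numpy as np
from edelweiss.clf_utils import get_classifier, get_scaler

X = np.random.rand(100, 5)          # placeholder features
y = np.random.randint(0, 2, 100)    # placeholder detection labels

pipe = get_classifier("RandomForest", scaler=get_scaler("standard"), n_estimators=50)
pipe.fit(X, y)
proba = pipe.predict_proba(X)       # the pipeline forwards predict_proba to the classifier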


def get_scaler(scaler):
    """
    Returns the scaler object

    :param scaler: name of the scaler
    :return: scaler object
    :raises: ValueError if scaler is not known
    """

    if scaler == "standard":
        return StandardScaler()
    elif scaler == "minmax":
        return MinMaxScaler()
    elif scaler == "maxabs":
        return MaxAbsScaler()
    elif scaler == "robust":
        return RobustScaler()
    elif scaler == "quantile":
        return QuantileTransformer()
    else:
        raise ValueError(f"{scaler} not known")


def get_detection_label(clf, bands, n_detected_bands=None):
    """
    Get the detection label for the classifier.

    :param clf: classification data (rec array)
    :param bands: which bands the data has
    :param n_detected_bands: how many bands have to be detected such that the event is
        classified as detected; if None, the detection label is already given in clf
    :return: detection label (bool array) and the names of the detection labels
    """
    det_labels = []

    if n_detected_bands is None:
        y = clf["detected"]
        det_labels.append("detected")
        return y, det_labels

    y = np.zeros(len(clf))
    for band in bands:
        y += clf[f"detected_{band}"]
        det_labels.append(f"detected_{band}")
    return y >= n_detected_bands, det_labels
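
A short sketch of the multi-band code path, using a hypothetical rec array whose columns follow the detected_{band} naming assumed above:

# Illustrative sketch with made-up per-band detection flags.
import numpy as np
from edelweiss.clf_utils import get_detection_label

clf_data = np.rec.fromarrays(
    [np.array([1, 0, 1]), np.array([1, 1, 0]), np.array([0, 1, 0])],
    names="detected_g,detected_r,detected_i",
)
y, labels = get_detection_label(clf_data, bands=["g", "r", "i"], n_detected_bands=2)
# y -> array([ True,  True, False]), labels -> ['detected_g', 'detected_r', 'detected_i']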


def get_scorer(score, **kwargs):
    """
    Returns the scorer object given an input string.
    If the string is not one of the custom scorers defined here, the input string is
    returned unchanged, assuming it is the name of a sklearn scorer.

    :param score: name of the scorer
    :param kwargs: additional arguments for the scorer
    :return: scorer object
    """
    if score == "ngal":
        return make_scorer(ngal_scorer, greater_is_better=False)
    elif score == "roc_auc":
        return make_scorer(custom_roc_auc_score, needs_proba=True)
    else:
        return score
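
A sketch of how the custom "ngal" scorer could be passed to scikit-learn's cross-validation; X and y are placeholders:

# Illustrative sketch: score a pipeline by the squared mismatch in the number
# of predicted detections (greater_is_better=False flips the sign internally).
import numpy as np
from sklearn.model_selection import cross_val_score
from edelweiss.clf_utils import get_classifier, get_scorer

X = np.random.rand(200, 4)
y = np.random.randint(0, 2, 200)

scores = cross_val_score(get_classifier("LogisticRegression"), X, y,
                         scoring=get_scorer("ngal"), cv=3)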


def load_hyperparams(clf):
    """
    Loads the hyperparameters for the classifier for the CV search.

    :param clf: classifier object
    :return: hyperparameter grid
    """

    if isinstance(clf, (LogisticRegression, LinearSVC)):
        param_grid = {"clf__C": [0.1, 1, 10, 100]}

    elif isinstance(clf, KNeighborsClassifier):
        param_grid = {
            "clf__n_neighbors": [5, 10, 100, 250, 500, 750],
            "clf__weights": ["uniform", "distance"],
        }
    elif isinstance(clf, DecisionTreeClassifier):
        param_grid = {
            "clf__max_depth": [3, 5, 7, 9, 11],
            "clf__min_samples_split": [2, 4, 6, 8, 10],
        }
    elif isinstance(clf, RandomForestClassifier):
        param_grid = {
            "clf__n_estimators": [20, 50, 100],
            "clf__max_depth": [None, 10, 20, 30],
            "clf__min_samples_split": [4, 6, 8, 10, 12],
        }
    elif isinstance(clf, XGBClassifier):
        param_grid = {
            "clf__learning_rate": [0.01, 0.1, 0.5, 1],
            "clf__max_depth": [3, 5, 7, 9],
            "clf__n_estimators": [5, 10, 50, 100],
        }
    elif isinstance(clf, MLPClassifier):
        param_grid = {
            "clf__hidden_layer_sizes": [
                (10,),
                (100,),
                (250,),
                (500,),
                (750,),
            ],
            "clf__alpha": [0.001, 0.01, 0.1],
        }
    elif isinstance(clf, AdaBoostClassifier):
        param_grid = {
            "clf__n_estimators": [1000, 5000],
            "clf__learning_rate": [0.01, 0.1],
        }
    elif isinstance(clf, QuadraticDiscriminantAnalysis):
        param_grid = {"clf__reg_param": [0.0, 0.01, 0.1, 1, 10]}
    elif isinstance(clf, GaussianNB):
        param_grid = {}
    elif isinstance(clf, GradientBoostingClassifier):
        param_grid = {
            "clf__n_estimators": [100, 500],
            "clf__learning_rate": [0.01, 0.1],
            "clf__max_depth": [3, 5, 7],
        }
    elif isinstance(clf, CatBoostClassifier):
        param_grid = {
            "clf__learning_rate": [0.03, 0.06],
            "clf__depth": [3, 6, 9],
            "clf__l2_leaf_reg": [2, 3, 4],
            "clf__boosting_type": ["Ordered", "Plain"],
        }
    elif isinstance(clf, LGBMClassifier):
        param_grid = {
            "clf__num_leaves": [5, 20, 31],
            "clf__learning_rate": [0.05, 0.1, 0.2],
            "clf__n_estimators": [50, 100, 150],
        }
    else:
        from .custom_clfs import NeuralNetworkClassifier

        if isinstance(clf, NeuralNetworkClassifier):
            param_grid = {
                "clf__hidden_units": [(32, 64, 32), (512, 512, 512), (10, 10)],
                "clf__learning_rate": [0.0001, 0.001],
                "clf__epochs": [1000],
                "clf__batch_size": [10000],
            }
        else:
            LOGGER.warning(f"Classifier {clf} not known.")
            param_grid = {}

    return param_grid
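
The "clf__" prefixes in these grids address the "clf" step of the pipeline built by get_classifier, so the two functions can be combined in a grid search. A hedged sketch with placeholder data:

# Illustrative sketch: tune a DecisionTree pipeline with the grid above.
import numpy as np
from sklearn.model_selection import GridSearchCV
from edelweiss.clf_utils import get_classifier, load_hyperparams

X = np.random.rand(300, 6)
y = np.random.randint(0, 2, 300)

pipe = get_classifier("DecisionTree")
param_grid = load_hyperparams(pipe.named_steps["clf"])
search = GridSearchCV(pipe, param_grid, scoring="accuracy", cv=3)
search.fit(X, y)
best_pipe = search.best_estimator_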


def ngal_scorer(y_true, y_pred):
    """
    Scorer accounting for the number of galaxies in the sample.
    score = (N_pred - N_true)**2

    :param y_true: true labels (detected or not)
    :param y_pred: predicted labels (detected or not)
    :return: score
    """
    return (sum(y_pred) - sum(y_true)) ** 2


def custom_roc_auc_score(y_true, y_prob):
    """
    Scorer for the ROC AUC score using y_prob

    :param y_true: true labels (detected or not)
    :param y_prob: predicted probabilities (2D array)
    :return: score
    """
    return roc_auc_score(y_true, y_prob[:, 1])
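
A tiny numeric sketch of why the second column is used: predict_proba returns one column per class, and the score is computed on the positive-class probabilities:

import numpy as np
from edelweiss.clf_utils import custom_roc_auc_score

y_true = np.array([0, 1, 1, 0])
y_prob = np.array([[0.9, 0.1], [0.2, 0.8], [0.4, 0.6], [0.7, 0.3]])
custom_roc_auc_score(y_true, y_prob)  # -> 1.0, the positive class is ranked perfectly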


def ngal_hist_scorer(y_true, y_pred, mag, bins=100, range=(15, 30)):
    """
    Scorer accounting for the number of galaxies in the sample on a histogram level.
    score = (N_pred - N_true)**2 per magnitude bin

    :param y_true: true labels (detected or not)
    :param y_pred: predicted labels (detected or not)
    :param mag: magnitude of the galaxies
    :param bins: number of magnitude bins
    :param range: magnitude range of the histogram
    :return: score per bin
    """
    hist_true = np.histogram(mag[y_true], bins=bins, range=range)[0]
    hist_pred = np.histogram(mag[y_pred], bins=bins, range=range)[0]
    return (hist_pred - hist_true) ** 2
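
A small numeric sketch; boolean masks are assumed for y_true and y_pred so that they can index the magnitude array directly:

import numpy as np
from edelweiss.clf_utils import ngal_hist_scorer

mag = np.array([16.0, 18.5, 21.0, 24.5, 27.0])
y_true = np.array([True, True, True, False, False])
y_pred = np.array([True, True, False, True, False])
score = ngal_hist_scorer(y_true, y_pred, mag, bins=3, range=(15, 30))
# score.sum() == 0 here: both selections put two objects in [15, 20) and one
# in [20, 25), so the per-bin histograms agree.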


def get_classifier_args(clf, conf):
    """
    Returns the arguments for the classifier defined in the config file

    :param clf: classifier name
    :param conf: config file
    :return: arguments for the classifier
    """
    try:
        return conf["classifier_args"][clf]
    except KeyError:
        LOGGER.warning(f"Classifier {clf} not found in config file.")
        return {}
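
A sketch of the config structure this helper expects; the "classifier_args" key comes from the function itself, while the nested values shown here are illustrative, not taken from an actual edelweiss config:

from edelweiss.clf_utils import get_classifier_args

conf = {"classifier_args": {"XGB": {"max_depth": 5, "n_estimators": 100}}}

get_classifier_args("XGB", conf)           # -> {'max_depth': 5, 'n_estimators': 100}
get_classifier_args("RandomForest", conf)  # -> {} (and a warning is logged)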