Coverage for src/edelweiss/classifier.py: 99%

237 statements  

coverage.py v7.6.7, created at 2024-11-18 17:09 +0000

# Copyright (C) 2023 ETH Zurich
# Institute for Particle Physics and Astrophysics
# Author: Silvan Fischbacher

import os
import pickle

import joblib
import numpy as np
from cosmic_toolbox import arraytools as at
from cosmic_toolbox import file_utils, logger
from sklearn.calibration import CalibratedClassifierCV
from sklearn.model_selection import GridSearchCV

from edelweiss import clf_diagnostics, clf_utils

LOGGER = logger.get_logger(__file__)


def load_classifier(path, subfolder=None):
    """
    Load a classifier from a given path.

    :param path: path to the folder containing the emulator
    :param subfolder: subfolder of the emulator folder where the classifier is stored
    :return: the loaded classifier
    """
    if subfolder is None:
        subfolder = "clf"
    output_directory = os.path.join(path, subfolder)
    with open(os.path.join(output_directory, "clf.pkl"), "rb") as f:
        clf = pickle.load(f)
    clf.pipe = joblib.load(os.path.join(output_directory, "model.pkl"))
    LOGGER.debug(f"Classifier loaded from {output_directory}")
    return clf

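# A minimal usage sketch for load_classifier (the path below is hypothetical
# and assumes a classifier was previously stored with Classifier.save):
#
#     clf = load_classifier("/path/to/emulator")  # reads clf/clf.pkl and clf/model.pkl
#     y_pred = clf.predict(X_new)  # X_new: plain array or recarray of features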

class Classifier:
    """
    The detection classifier class that wraps a sklearn classifier.

    :param scaler: the scaler to use for the classifier, options: standard, minmax,
        maxabs, robust, quantile
    :param clf: the classifier to use, options are: XGB, MLP, RandomForest,
        NeuralNetwork, LogisticRegression, LinearSVC, DecisionTree, AdaBoost,
        GaussianNB, QDA, KNN
    :param calibrate: whether to calibrate the probabilities
    :param cv: number of cross validation folds, if 0 no cross validation is performed
    :param cv_scoring: the scoring method to use for cross validation
    :param params: the names of the parameters
    :param clf_kwargs: additional keyword arguments for the classifier
    """

    def __init__(
        self,
        scaler="standard",
        clf="XGB",
        calibrate=True,
        cv=0,
        cv_scoring="f1",
        params=None,
        **clf_kwargs,
    ):
        """
        Initialize the classifier.
        """
        self.scaler = scaler
        self.clf = clf
        sc = clf_utils.get_scaler(scaler)
        self.pipe = clf_utils.get_classifier(clf, sc, **clf_kwargs)
        self.calibrate = calibrate
        self.cv = cv
        self.cv_scoring = cv_scoring
        self.params = params
        self.test_scores = None

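    # Construction sketch (argument values are illustrative, not recommended
    # settings for any particular analysis):
    #
    #     clf = Classifier(scaler="robust", clf="RandomForest", cv=5, cv_scoring="f1")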

    def train(self, X, y, **args):
        """
        Train the classifier.

        :param X: the features to train on (array or recarray)
        :param y: the labels to train on
        :param args: additional arguments for the classifier
        """
        X = self._check_if_recarray(X)

        if self.params is None:
            self.params = np.arange(X.shape[1])
            LOGGER.warning("No parameter names provided, numbers are used instead")
        else:
            assert len(self.params) == X.shape[1], (
                "Number of parameters in training data does not match number"
                " of parameters provided before"
            )

        LOGGER.info("Training this model:")
        if self.calibrate:
            LOGGER.info("CalibratedClassifierCV")
        clf_names = self.pipe.named_steps.items()
        for name, estimator in clf_names:
            LOGGER.info(f"{name}:")
            LOGGER.info(estimator)
        LOGGER.info(f"number of samples: {X.shape[0]}")
        LOGGER.info("-------------------")

        # tune hyperparameters with grid search
        if self.cv > 1:
            LOGGER.info("Start cross validation")
            param_grid = clf_utils.load_hyperparams(self.pipe["clf"])
            scorer = clf_utils.get_scorer(self.cv_scoring)

            # Set up the grid search
            if "SLURM_CPUS_PER_TASK" in os.environ:  # pragma: no cover
                n_jobs = max(int(os.environ["SLURM_CPUS_PER_TASK"]) // 4, 1)
            else:
                n_jobs = 1
            LOGGER.info(f"Running the grid search on {n_jobs} jobs")
            self.pipe = GridSearchCV(
                estimator=self.pipe,
                param_grid=param_grid,
                scoring=scorer,
                cv=self.cv,
                n_jobs=n_jobs,
            )

            if self.calibrate:
                self.pipe = CalibratedClassifierCV(
                    self.pipe, cv=self.cv, method="isotonic"
                )

            # Run the grid search
            self.pipe.fit(X, y, **args)

            if self.calibrate:
                best_params = self.pipe.calibrated_classifiers_[
                    0
                ].estimator.best_params_
            else:
                best_params = self.pipe.best_params_
            LOGGER.info("Best parameters found by grid search: %s", best_params)
        else:
            if self.calibrate:
                self.pipe = CalibratedClassifierCV(self.pipe, cv=2, method="isotonic")
            self.pipe.fit(X, y, **args)
        self._get_feature_importance()
        self._get_summed_feature_importance()
        LOGGER.info("Training completed")

    fit = train

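    # Sketch of training on a recarray (hypothetical field names); the column
    # names are picked up automatically via _check_if_recarray:
    #
    #     X = np.zeros(1000, dtype=[("mag_r", float), ("mag_i", float)])
    #     y = np.zeros(1000, dtype=bool)
    #     clf.fit(X, y)  # equivalent to clf.train(X, y)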

    def predict(self, X, prob_multiplier=1.0):
        """
        Predict the labels for a given set of features.

        :param X: the features to predict on (array or recarray)
        :param prob_multiplier: factor applied to the predicted probabilities
            before the labels are sampled
        :return: the predicted labels
        """
        X = self._check_if_recarray(X)
        y_prob = self.pipe.predict_proba(X)[:, 1] * prob_multiplier
        y_prob = np.clip(y_prob, 0, 1)
        # sample each label stochastically: True with probability y_prob
        y_pred = y_prob > np.random.rand(len(y_prob))
        return y_pred

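    # Behavior sketch of the stochastic decision rule above (illustrative
    # numbers): an object with y_prob = 0.9 is predicted as detected in ~90%
    # of calls, so repeated calls to predict on the same X differ in general:
    #
    #     y_prob = clf.predict_proba(X)  # e.g. array([0.9, 0.1])
    #     y_pred = clf.predict(X)        # e.g. array([ True, False])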

    def predict_proba(self, X):
        """
        Predict the probabilities for a given set of features.

        :param X: the features to predict on (array or recarray)
        :return: the predicted probabilities
        """
        X = self._check_if_recarray(X)
        y_prob = self.pipe.predict_proba(X)[:, 1]
        return y_prob

    def predict_non_proba(self, X):
        """
        Predict the labels non-probabilistically for a given set of features.

        :param X: the features to predict on (array or recarray)
        :return: the predicted labels
        """
        X = self._check_if_recarray(X)
        y_pred = self.pipe.predict(X)
        return y_pred.astype(bool)

    __call__ = predict

    def save(self, path, subfolder=None):
        """
        Save the classifier to a given path.

        :param path: path to the folder where the emulator is saved
        :param subfolder: subfolder of the emulator folder where the classifier is
            stored
        """
        if subfolder is None:
            subfolder = "clf"
        output_directory = os.path.join(path, subfolder)
        file_utils.robust_makedirs(output_directory)
        # the sklearn pipeline is dumped separately with joblib; it is removed
        # from the instance before pickling and restored by load_classifier
        joblib.dump(self.pipe, os.path.join(output_directory, "model.pkl"))
        self.pipe = None
        with open(os.path.join(output_directory, "clf.pkl"), "wb") as f:
            pickle.dump(self, f)
        LOGGER.info(f"Classifier saved to {output_directory}")

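    # Round-trip sketch (hypothetical path): save() sets self.pipe to None,
    # so reload the classifier before using it again:
    #
    #     clf.save("/path/to/emulator")  # writes clf/model.pkl and clf/clf.pkl
    #     clf = load_classifier("/path/to/emulator")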

    def test(self, X_test, y_test, non_proba=False):
        """
        Tests the classifier on the test data and stores the scores in
        self.test_scores.

        :param X_test: test data
        :param y_test: test labels
        :param non_proba: whether to use non-probabilistic predictions
        """
        # get probability of being detected
        y_prob = self.predict_proba(X_test)
        y_pred = self.predict_non_proba(X_test) if non_proba else self.predict(X_test)
        test_arr = clf_diagnostics.setup_test()
        clf_diagnostics.get_all_scores(test_arr, y_test, y_pred, y_prob)
        test_arr = at.dict2rec(test_arr)
        self.test_scores = test_arr

    def _check_if_recarray(self, X):
        # Convert a recarray to a plain array and check that its field names
        # match the trained parameters; inputs that are not recarrays are
        # returned unchanged.
        try:
            X, names = at.rec2arr(X, return_names=True)
            if self.params is None:
                self.params = names
            else:
                assert np.all(
                    names == self.params
                ), "Input parameters do not match the trained parameters"
            return X
        except Exception:
            return X

    def _get_feature_importance(self):
        try:
            # Try to get the feature importances if clf is GridSearchCV
            importances = self.pipe.best_estimator_["clf"].feature_importances_
            self.feature_importances = at.arr2rec(importances, self.params)
            return
        except Exception:
            pass
        try:
            # Try to get the feature importances if clf is CalibratedClassifierCV
            importances = self.pipe.calibrated_classifiers_[
                0
            ].estimator._final_estimator.feature_importances_
            self.feature_importances = at.arr2rec(importances, self.params)
            return
        except Exception:
            try:
                # Try to get the feature importances if clf is CalibratedClassifierCV
                # and GridSearchCV
                importances = (
                    self.pipe.calibrated_classifiers_[0]
                    .estimator.best_estimator_.named_steps["clf"]
                    .feature_importances_
                )
                self.feature_importances = at.arr2rec(importances, self.params)
                return
            except Exception:
                pass
        try:
            # Try to get the feature importances if clf is not GridSearchCV
            importances = self.pipe["clf"].feature_importances_
            self.feature_importances = at.arr2rec(importances, self.params)
            return
        except Exception:
            self.feature_importances = None

    def _get_summed_feature_importance(self):
        """
        Sum the feature importances of the same parameter across different bands.
        Should be run after _get_feature_importance.
        """
        feature_importances = self.feature_importances
        if feature_importances is None:
            self.summed_feature_importances = None
            return

        # Create a dictionary to store the summed values based on modified prefixes
        summed_features = {}

        # Iterate through the feature importances; strip the last underscore-
        # separated part of the name if it is a single character (a band suffix)
        for key in feature_importances.dtype.names:
            parts = key.split("_")
            prefix = (
                "_".join(parts[:-1]) if len(parts[-1]) == 1 and len(parts) > 1 else key
            )

            if prefix not in summed_features:
                summed_features[prefix] = 0.0

            summed_features[prefix] += feature_importances[key]

        self.summed_feature_importances = at.dict2rec(summed_features)

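    # Worked sketch of the prefix summing (hypothetical feature names): the
    # importances of "mag_r" and "mag_i" are summed under "mag", while "size"
    # is kept as-is because its last part is longer than one character:
    #
    #     {"mag_r": 0.2, "mag_i": 0.3, "size": 0.5} -> {"mag": 0.5, "size": 0.5}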

def load_multiclassifier(path, subfolder=None):
    """
    Load a multiclassifier from a given path.

    :param path: path to the folder containing the emulator
    :param subfolder: subfolder of the emulator folder where the classifier is stored
    :return: the loaded classifier
    """
    if subfolder is None:
        subfolder = "clf"
    output_directory = os.path.join(path, subfolder)
    with open(os.path.join(output_directory, "clf.pkl"), "rb") as f:
        clf = pickle.load(f)
    clf.pipe = []
    for label in clf.labels:
        clf.pipe.append(load_classifier(path, subfolder=f"{subfolder}_{label}"))
    LOGGER.debug(f"Classifier loaded from {output_directory}")
    return clf

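# Usage sketch (hypothetical path; assumes a MultiClassifier was saved there):
#
#     mclf = load_multiclassifier("/path/to/emulator")  # one sub-classifier per label
#     y_pred = mclf.predict(X_new)  # X_new must contain the mclf.split_label field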

class MultiClassifier:
    """
    A classifier class that trains multiple classifiers, one for each value of a
    specific label. This label could e.g. be the galaxy type (star, red galaxy,
    blue galaxy).

    :param split_label: the label used to split the data into different classifiers
    :param labels: the different values of the split label
    :param scaler: the scaler to use for the classifier
    :param clf: the classifier to use
    :param calibrate: whether to calibrate the probabilities
    :param cv: number of cross validation folds, if 0 no cross validation is performed
    :param cv_scoring: the scoring method to use for cross validation
    :param params: the names of the parameters
    :param clf_kwargs: additional keyword arguments for the classifier
    """

    def __init__(
        self,
        split_label="galaxy_type",
        labels=None,
        scaler="standard",
        clf="XGB",
        calibrate=True,
        cv=0,
        cv_scoring="f1",
        params=None,
        **clf_kwargs,
    ):
        """
        Initialize the classifier.
        """
        if labels is None:
            labels = [-1, 0, 1]
        self.split_label = split_label
        self.labels = labels
        self.scaler = scaler
        self.clf = clf
        self.pipe = [
            Classifier(
                scaler=scaler,
                clf=clf,
                calibrate=calibrate,
                cv=cv,
                cv_scoring=cv_scoring,
                params=params,
                **clf_kwargs,
            )
            for _ in self.labels
        ]

    def train(self, X, y):
        """
        Train one classifier per label on the corresponding subset of the data.

        :param X: the features to train on (recarray containing the split label)
        :param y: the labels to train on
        """
        # TODO: dirty hack, fix this
        self.params = X.dtype.names
        for i, label in enumerate(self.labels):
            idx = X[self.split_label] == label
            X_ = at.delete_cols(X[idx], self.split_label)
            self.pipe[i].train(X_, y[idx])

    fit = train

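    # Construction and training sketch (hypothetical labels and field names):
    #
    #     mclf = MultiClassifier(split_label="galaxy_type", labels=[-1, 0, 1])
    #     mclf.fit(X_train, y_train)  # X_train: recarray with a "galaxy_type" field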

    def predict(self, X):
        """
        Predict the labels for a given set of features.
        """
        y_pred = np.zeros(len(X), dtype=bool)
        for i, label in enumerate(self.labels):
            idx = X[self.split_label] == label
            if np.sum(idx) == 0:
                continue
            X_ = at.delete_cols(X[idx], self.split_label)
            y_pred[idx] = self.pipe[i].predict(X_)
        return y_pred

    def predict_proba(self, X):
        """
        Predict the probabilities for a given set of features.
        """
        y_prob = np.zeros(len(X), dtype=float)
        for i, label in enumerate(self.labels):
            idx = X[self.split_label] == label
            if np.sum(idx) == 0:
                continue
            # drop the split label before passing the data to the sub-classifier,
            # consistent with predict and predict_non_proba
            X_ = at.delete_cols(X[idx], self.split_label)
            y_prob[idx] = self.pipe[i].predict_proba(X_)
        return y_prob

    def predict_non_proba(self, X):
        """
        Predict the labels non-probabilistically for a given set of features.
        """
        y_pred = np.zeros(len(X), dtype=bool)
        for i, label in enumerate(self.labels):
            idx = X[self.split_label] == label
            if np.sum(idx) == 0:
                continue
            X_ = at.delete_cols(X[idx], self.split_label)
            y_pred[idx] = self.pipe[i].predict_non_proba(X_)
        return y_pred

    __call__ = predict

    def save(self, path, subfolder=None):
        """
        Save the classifier to a given path.

        :param path: path to the folder where the emulator is saved
        :param subfolder: subfolder of the emulator folder where the classifier is
            stored
        """
        if subfolder is None:
            subfolder = "clf"
        output_directory = os.path.join(path, subfolder)
        file_utils.robust_makedirs(output_directory)
        for i, label in enumerate(self.labels):
            self.pipe[i].save(path, subfolder=f"{subfolder}_{label}")
        self.pipe = None
        with open(os.path.join(path, subfolder, "clf.pkl"), "wb") as f:
            pickle.dump(self, f)
        LOGGER.info(f"MultiClassifier saved to {os.path.join(path, subfolder)}")

    def test(self, X_test, y_test, non_proba=False):
        """
        Tests the classifier on the test data and stores the scores in
        self.test_scores.

        :param X_test: test data
        :param y_test: test labels
        :param non_proba: whether to use non-probabilistic predictions
        """
        # get probability of being detected
        y_prob = self.predict_proba(X_test)
        y_pred = self.predict_non_proba(X_test) if non_proba else self.predict(X_test)
        test_arr = clf_diagnostics.setup_test()
        clf_diagnostics.get_all_scores(test_arr, y_test, y_pred, y_prob)
        test_arr = at.dict2rec(test_arr)
        self.test_scores = test_arr


class MultiClassClassifier(Classifier):
    """
    The detection classifier class that wraps a sklearn classifier for multiple
    classes.

    :param scaler: the scaler to use for the classifier, options: standard, minmax,
        maxabs, robust, quantile
    :param clf: the classifier to use, options are: XGB, MLP, RandomForest,
        NeuralNetwork, LogisticRegression, LinearSVC, DecisionTree, AdaBoost,
        GaussianNB, QDA, KNN
    :param calibrate: whether to calibrate the probabilities
    :param cv: number of cross validation folds, if 0 no cross validation is performed
    :param cv_scoring: the scoring method to use for cross validation
    :param params: the names of the parameters
    :param clf_kwargs: additional keyword arguments for the classifier
    """

    def predict(self, X):
        """
        Predict the labels for a given set of features.

        :param X: the features to predict on (array or recarray)
        :return: the predicted labels
        """
        X = self._check_if_recarray(X)
        y_prob = self.pipe.predict_proba(X)
        # sample each label from its predicted class distribution
        y_pred = np.array([np.random.choice(len(prob), p=prob) for prob in y_prob])

        return y_pred

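    # Behavior sketch of the stochastic multi-class prediction (illustrative
    # probabilities):
    #
    #     y_prob = clf.predict_proba(X)  # e.g. array([[0.7, 0.2, 0.1]])
    #     y_pred = clf.predict(X)        # class 0 with prob 0.7, 1 with 0.2, 2 with 0.1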

    def predict_proba(self, X):
        """
        Predict the probabilities for a given set of features.

        :param X: the features to predict on (array or recarray)
        :return: the predicted probabilities
        """
        X = self._check_if_recarray(X)
        y_prob = self.pipe.predict_proba(X)
        return y_prob

    def predict_non_proba(self, X):
        """
        Predict the class non-probabilistically for a given set of features.

        :param X: the features to predict on (array or recarray)
        :return: the predicted classes
        """
        X = self._check_if_recarray(X)
        y_pred = self.pipe.predict(X)
        return y_pred

    def test(self, X_test, y_test, non_proba=False):
        """
        Tests the classifier on the test data and stores the scores in
        self.test_scores.

        :param X_test: test data
        :param y_test: test labels
        :param non_proba: whether to use non-probabilistic predictions
        """
        y_pred = self.predict_non_proba(X_test) if non_proba else self.predict(X_test)
        y_prob = self.predict_proba(X_test)

        test_arr = clf_diagnostics.setup_test(multi_class=True)
        clf_diagnostics.get_all_scores_multiclass(test_arr, y_test, y_pred, y_prob)
        test_arr = at.dict2rec(test_arr)
        self.test_scores = test_arr