Coverage for src/edelweiss/clf_diagnostics.py: 100%

386 statements  

coverage.py v7.6.7, created at 2024-11-18 17:09 +0000

# Copyright (C) 2023 ETH Zurich
# Institute for Particle Physics and Astrophysics
# Author: Silvan Fischbacher

import os
import pickle

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from cosmic_toolbox import arraytools as at
from cosmic_toolbox import colors, file_utils
from cosmic_toolbox.logger import get_logger
from sklearn.calibration import CalibratedClassifierCV, calibration_curve
from sklearn.metrics import (accuracy_score, average_precision_score,
                             brier_score_loss, f1_score, log_loss,
                             precision_recall_curve, precision_score,
                             recall_score, roc_auc_score, roc_curve)
from sklearn.model_selection import GridSearchCV

from edelweiss import clf_utils as utils

LOGGER = get_logger(__file__)
COL = colors.get_colors()
colors.set_cycle()


def get_confusion_matrix(y_true, y_pred):
    """
    Get the confusion matrix for the classifier.

    :param y_true: true labels
    :param y_pred: predicted labels
    :return: True Positives, True Negatives, False Positives, False Negatives
    """

    tp = y_pred & y_true
    tn = ~y_pred & ~y_true
    fp = y_pred & ~y_true
    fn = ~y_pred & y_true
    return tp, tn, fp, fn

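# Illustrative usage (added example, not part of the original module): the four return
# values are element-wise boolean masks, not counts, so counts follow by summing.
# The toy arrays below are assumptions made only for this example.
#
#   y_true = np.array([True, False, True, False])
#   y_pred = np.array([True, True, False, False])
#   tp, tn, fp, fn = get_confusion_matrix(y_true, y_pred)
#   # tp.sum() == 1, tn.sum() == 1, fp.sum() == 1, fn.sum() == 1
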

def plot_hist_fp_fn_tp_tn(
    param,
    y_true,
    y_pred,
    output_directory=".",
    clf="classifier",
    final=False,
    save_plot=False,
):
    """
    Plot the stacked histogram of one parameter (e.g. i-band magnitude) for the
    different confusion matrix elements.

    :param param: parameter to plot
    :param y_true: true labels
    :param y_pred: predicted labels
    :param output_directory: directory to save the plot
    :param clf: classifier object or name of the classifier
    :param final: if True, the plot is for the final classifier
    :param save_plot: if True, save the plot
    """
    # TODO: something strange happens in the plotting.
    tp, tn, fp, fn = get_confusion_matrix(y_true, y_pred)
    fig = plt.figure()
    plt.hist(
        [param[fp], param[fn], param[tp], param[tn]],
        bins=100,
        stacked=True,
        label=["FP", "FN", "TP", "TN"],
        color=[COL["r"], COL["orange"], COL["g"], COL["b"]],
        density=True,
    )
    plt.ylim(0, 0.2)
    plt.legend()
    plt.xlabel("i-band magnitude")
    plt.ylabel("Normalized counts (stacked)")
    name = get_name(clf, final=final)
    path = os.path.join(output_directory, "clf/figures/")
    file_utils.robust_makedirs(path)
    path += f"stacked_hist_{name}"
    if save_plot:
        plt.savefig(path + ".pdf", bbox_inches="tight")
        with open(path + ".pkl", "wb") as fh:
            pickle.dump(fig, fh)


def plot_hist_n_gal(
    param,
    y_true,
    y_pred,
    output_directory=".",
    clf="classifier",
    final=False,
    save_plot=False,
    fig=None,
):
    """
    Plot the histogram of detected galaxies for the classifier and the true detected
    galaxies for one parameter (e.g. i-band magnitude).

    :param param: parameter to plot
    :param y_true: true labels
    :param y_pred: predicted labels
    :param output_directory: directory to save the plot
    :param clf: classifier object or name of the classifier
    :param final: if True, the plot is for the final classifier
    :param save_plot: if True, save the plot
    :param fig: figure object, if None, create a new figure
    :return: figure object
    """
    # TODO: something strange happens in the plotting.
    _, bins = np.histogram(param[y_true], bins=100)

    if fig is None:
        fig, ax = plt.subplots()
        ax.set_xlabel("i-band magnitude")
        ax.set_ylabel("galaxy counts")
        ax.set_title("Number of detected galaxies")
        ax.hist(param[y_true], bins=bins, label="true detected galaxies", color="grey")
    else:
        ax = fig.get_axes()[0]
    name = get_name(clf, final=False)
    ax.hist(param[y_pred], bins=bins, histtype="step", label=name)
    ax.legend()
    name = get_name(clf, final=final)
    path = os.path.join(output_directory, "clf/figures/")
    file_utils.robust_makedirs(path)
    path += f"stacked_hist_{name}"
    if save_plot:
        plt.savefig(path + ".pdf", bbox_inches="tight")
        with open(path + ".pkl", "wb") as fh:
            pickle.dump(fig, fh)
    return fig

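# Illustrative usage (added sketch, not part of the original module): the returned
# figure can be passed back in to overlay several classifiers in one panel. `mag`,
# `y_true` and the prediction arrays are assumed to exist already.
#
#   fig = plot_hist_n_gal(mag, y_true, y_pred_a, clf="clf_a")
#   fig = plot_hist_n_gal(mag, y_true, y_pred_b, clf="clf_b", fig=fig)
#   fig = plot_hist_n_gal(mag, y_true, y_pred_c, clf="clf_c", fig=fig, save_plot=True)
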

def plot_calibration_curve(
    y_true,
    y_prob,
    output_directory=".",
    clf="classifier",
    final=False,
    save_plot=False,
    fig=None,
):
    """
    Plot the calibration curve for the classifier.

    :param y_true: true labels
    :param y_prob: predicted probabilities
    :param output_directory: directory to save the plot
    :param clf: classifier object or name of the classifier
    :param final: if True, the plot is for the final classifier
    :param save_plot: if True, save the plot
    :param fig: figure object, if None, create a new figure
    :return: figure object
    """
    if fig is None:
        fig, ax = plt.subplots()
        ax.plot([0, 1], [0, 1], ls="--", color="k", label="perfect calibration")
        ax.set_xlabel("predicted probability")
        ax.set_ylabel("fraction of positives")
        ax.set_title("Calibration curve")
    else:
        ax = fig.get_axes()[0]

    prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=20)
    name = get_name(clf, final=False)
    ax.plot(prob_pred, prob_true, label=name)
    ax.legend()
    name = get_name(clf, final=final)
    path = os.path.join(output_directory, "clf/figures/")
    file_utils.robust_makedirs(path)
    path += f"stacked_hist_{name}"
    if save_plot:
        plt.savefig(path + ".pdf", bbox_inches="tight")
        with open(path + ".pkl", "wb") as fh:
            pickle.dump(fig, fh)
    return fig

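# Illustrative sketch (added, not part of the original module): if the curve shows poor
# calibration, the classifier can be wrapped in sklearn's CalibratedClassifierCV
# (already imported above) and plotted again on the same axes. `base_clf`, `X_train`,
# `y_train`, `X_test` and `y_test` are assumed to exist; the `estimator=` keyword
# assumes scikit-learn >= 1.2.
#
#   calibrated = CalibratedClassifierCV(estimator=base_clf, method="isotonic", cv=5)
#   calibrated.fit(X_train, y_train)
#   y_prob_cal = calibrated.predict_proba(X_test)[:, 1]
#   fig = plot_calibration_curve(y_test, y_prob, clf="raw")
#   fig = plot_calibration_curve(y_test, y_prob_cal, clf="calibrated", fig=fig)
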

def plot_roc_curve(
    y_true,
    y_prob,
    output_directory=".",
    clf="classifier",
    final=False,
    save_plot=False,
    fig=None,
):
    """
    Plot the ROC curve for the classifier.

    :param y_true: true labels
    :param y_prob: predicted probabilities
    :param output_directory: directory to save the plot
    :param clf: classifier object or name of the classifier
    :param final: if True, the plot is for the final classifier
    :param save_plot: if True, save the plot
    :param fig: figure object, if None, create a new figure
    :return: figure object
    """
    fpr, tpr, _ = roc_curve(y_true, y_prob)

    if fig is None:
        fig, ax = plt.subplots()
        ax.set_xlabel("false positive rate")
        ax.set_ylabel("true positive rate")
        ax.set_title("ROC curve")
        ax.plot([0, 1], [0, 1], ls="--", color="k", label="random classifier")
        ax.plot([0, 0, 1], [0, 1, 1], ls=":", color="k", label="perfect classifier")
    else:
        ax = fig.get_axes()[0]
    ax.plot(fpr, tpr, label=get_name(clf, final=False))
    ax.legend()
    name = get_name(clf, final=final)
    path = os.path.join(output_directory, "clf/figures/")
    file_utils.robust_makedirs(path)
    path += f"stacked_hist_{name}"
    if save_plot:
        plt.savefig(path + ".pdf", bbox_inches="tight")
        with open(path + ".pkl", "wb") as fh:
            pickle.dump(fig, fh)
    return fig


def plot_pr_curve(
    y_true,
    y_prob,
    output_directory=".",
    clf="classifier",
    final=False,
    save_plot=False,
    fig=None,
):
    """
    Plot the precision-recall curve for the classifier.

    :param y_true: true labels
    :param y_prob: predicted probabilities
    :param output_directory: directory to save the plot
    :param clf: classifier object or name of the classifier
    :param final: if True, the plot is for the final classifier
    :param save_plot: if True, save the plot
    :param fig: figure object, if None, create a new figure
    :return: figure object
    """

    precision, recall, _ = precision_recall_curve(y_true, y_prob)

    if fig is None:
        fig, ax = plt.subplots()
        ax.plot([0, 1, 1], [1, 1, 0], ls=":", color="k", label="perfect classifier")
        ax.set_xlabel("recall")
        ax.set_ylabel("precision")
        ax.set_title("Precision-Recall curve")
    else:
        ax = fig.get_axes()[0]
    ax.plot(recall, precision, label=get_name(clf, final=False))
    ax.legend()
    name = get_name(clf, final=final)
    path = os.path.join(output_directory, "clf/figures/")
    file_utils.robust_makedirs(path)
    path += f"stacked_hist_{name}"
    if save_plot:
        fig.savefig(path + ".pdf", bbox_inches="tight")
        with open(path + ".pkl", "wb") as fh:
            pickle.dump(fig, fh)
    return fig


def plot_spider_scores(
    y_true,
    y_pred,
    y_prob,
    output_directory=".",
    clf="classifier",
    final=False,
    save_plot=False,
    fig=None,
    ranges=None,
    print_scores=False,
):
    """
    Plot the spider scores for the classifier.

    :param y_true: true labels
    :param y_pred: predicted labels
    :param y_prob: predicted probabilities
    :param output_directory: directory to save the plot
    :param clf: classifier object or name of the classifier
    :param final: if True, the plot is for the final classifier
    :param save_plot: if True, save the plot
    :param fig: figure object, if None, create a new figure
    :param ranges: dictionary of ranges for each score
    :param print_scores: if True, print the scores
    :return: figure object
    """
    ranges = {} if ranges is None else ranges
    test_scores = setup_test()
    get_all_scores(test_scores, y_true, y_pred, y_prob)
    for p in test_scores:
        # remove the list structure
        test_scores[p] = test_scores[p][0]
    test_scores = at.dict2rec(test_scores)
    test_scores = at.add_cols(test_scores, ["n_gal_deviation"])
    test_scores["n_gal_deviation"] = (
        test_scores["n_galaxies_true"] - test_scores["n_galaxies_pred"]
    ) / test_scores["n_galaxies_true"]
    test_scores = at.delete_columns(test_scores, ["n_galaxies_true", "n_galaxies_pred"])
    if print_scores:
        print(clf)
        for score in test_scores.dtype.names:
            print(f"{score}: {test_scores[score]}")
        print("--------------------")
    if fig is None:
        fig, _ = plt.subplots(figsize=(10, 6), subplot_kw={"polar": True})
    fig, ax = _plot_spider(fig, test_scores, clf, ranges=ranges)
    ax.legend()
    name = get_name(clf, final=final)
    path = os.path.join(output_directory, "clf/figures/")
    file_utils.robust_makedirs(path)
    path += f"stacked_hist_{name}"
    if save_plot:
        fig.savefig(path + ".pdf", bbox_inches="tight")
        with open(path + ".pkl", "wb") as fh:
            pickle.dump(fig, fh)
    return fig

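# Illustrative usage (added sketch, not part of the original module): spider panels for
# two classifiers, overlaid via the returned figure and with a custom range for the
# accuracy axis. The label arrays and probabilities are assumed to exist already.
#
#   ranges = {"accuracy": (0.8, 1)}
#   fig = plot_spider_scores(y_true, y_pred_a, y_prob_a, clf="clf_a", ranges=ranges)
#   fig = plot_spider_scores(y_true, y_pred_b, y_prob_b, clf="clf_b", fig=fig,
#                            ranges=ranges, print_scores=True)
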

def _plot_spider(fig, data, label, ranges=None):
    """
    Plot the data in a spider plot.

    :param fig: figure object
    :param data: data to plot
    :param label: label for the data
    :param ranges: ranges for the different features
    :return: figure object
    """

    ranges = {} if ranges is None else ranges

    # Get the default ranges and update them with the given ranges
    r = ranges.copy()
    ranges = get_default_ranges_for_spider()
    ranges.update(r)

    # Prepare the data for the spider plot
    data = scale_data_for_spider(data, ranges)
    values, field_names = at.rec2arr(data, return_names=True)
    values = values.flatten()
    field_names = list(field_names)
    add_range_to_name(field_names, ranges)
    angles = np.linspace(0, 2 * np.pi, len(field_names), endpoint=False)
    values = np.concatenate((values, [values[0]]))  # Close the plot
    angles = np.concatenate((angles, [angles[0]]))  # Close the plot

    # Plot the data
    ax = fig.get_axes()[0]
    ax.plot(angles, values, label=label)

    # Add labels for each variable
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(field_names)
    ax.set_rticks(np.linspace(0, 1, 5))
    ax.set_rlim(0, 1)
    ax.set_yticklabels([])
    return fig, ax


def scale_data_for_spider(data, ranges=None):
    """
    Scale the data for the spider plot such that the chosen range corresponds to the
    0-1 range of the spider plot.

    If the lower value of the range is higher than the upper value, the data is
    inverted.

    :param data: data to scale
    :param ranges: dictionary with the ranges for each variable, if a parameter is not
        in the dictionary, the default range is (0, 1)
    :return: scaled data
    """
    ranges = {} if ranges is None else ranges
    for par in data.dtype.names:
        try:
            low, high = ranges[par]
        except Exception:
            low, high = 0, 1

        if low > high:
            data[par] = 1 - (data[par] - high) / (low - high)

        else:
            data[par] = (data[par] - low) / (high - low)
        data[par] = np.clip(data[par], 0, 1)
    return data

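# Worked example (added for clarity, not part of the original module): with the
# inverted default range (0.5, 0) for "log_loss_score" defined below, a raw loss of
# 0.1 maps to 1 - (0.1 - 0) / (0.5 - 0) = 0.8, so smaller losses end up closer to the
# outer edge of the spider plot; a loss of 0.7 maps to -0.4 and is clipped to 0.
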

def get_default_ranges_for_spider():
    """
    Get the default ranges for the spider plot.

    :return: dictionary with the ranges for each variable
    """
    ranges = {
        "accuracy": (0.5, 1),
        "precision": (0.5, 1),
        "recall": (0.5, 1),
        "f1": (0.5, 1),
        "n_gal_deviation": (-0.1, 0.1),
        "auc_roc_score": (0.5, 1),
        "log_loss_score": (0.5, 0),
        "brier_score": (0.1, 0),
        "auc_pr_score": (0.5, 1),
    }
    return ranges


def add_range_to_name(field_names, ranges):
    """
    Add the range to the name of the variable such that the range is visible in the
    spider plot.

    :param field_names: list with the names of the variables
    :param ranges: dictionary with the ranges for each variable
    """
    for i, f in enumerate(field_names):
        field_names[i] = f + f": {ranges[f]}"


def get_name(clf, final=False):
    """
    Get the name used to label the classifier in plots and file names.

    :param clf: classifier object (from sklearn) or name of the classifier
    :param final: if True, the classifier was tested on the test data.
    :return: name
    """
    from edelweiss.classifier import Classifier, MultiClassifier

    name = ""
    if isinstance(clf, str):
        return clf
    elif isinstance(clf, (Classifier, MultiClassifier)):
        return get_name(clf.pipe, final=final)
    elif isinstance(clf, CalibratedClassifierCV):
        clf_names = ["CalibratedClassifier"]
        if isinstance(clf.estimator, GridSearchCV):
            clf_names.extend(list(clf.estimator.estimator.named_steps.values()))
        else:
            clf_names.extend(list(clf.estimator.named_steps.values()))
    else:
        if isinstance(clf, GridSearchCV):
            clf_names = list(clf.estimator.named_steps.values())
        else:
            try:
                clf_names = list(clf.named_steps.values())
            except Exception:  # pragma: no cover
                clf_names = ["clf"]

    for n in clf_names:
        name += str(n)[:7]
        name += "_"
    if final:
        name = "final"
    return name

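# Illustrative note (added, not part of the original module): the name concatenates the
# first seven characters of each pipeline step, so a pipeline of StandardScaler and
# XGBClassifier would yield roughly "Standar_XGBClas_"; with final=True the name is
# simply "final".
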

def plot_diagnostics(
    clf,
    X_test,
    y_test,
    output_directory=".",
    final=False,
    save_plot=False,
    special_param="mag_i",
):
    """
    Plot the diagnostics for the classifier.

    :param clf: classifier object
    :param X_test: test data
    :param y_test: true labels
    :param output_directory: directory to save the plots
    :param final: if True, the classifier was tested on the test data.
    :param save_plot: if True, save the plots
    :param special_param: param to plot the histogram for
    """
    # Get the predictions
    y_prob = clf.predict_proba(X_test)
    y_pred = clf.predict(X_test)

    # Make sure the labels are boolean
    if not isinstance(y_pred[0], bool):
        y_pred = y_pred.astype(bool)
    if not isinstance(y_test[0], bool):
        y_test = y_test.astype(bool)

    # plot the diagnostics
    param = X_test[special_param]
    plot_hist_fp_fn_tp_tn(
        param, y_test, y_pred, output_directory, clf, final=final, save_plot=save_plot
    )
    plot_hist_n_gal(
        param, y_test, y_pred, output_directory, clf, final=final, save_plot=save_plot
    )
    plot_calibration_curve(
        y_test, y_prob, output_directory, clf, final=final, save_plot=save_plot
    )
    plot_roc_curve(
        y_test, y_prob, output_directory, clf, final=final, save_plot=save_plot
    )
    plot_pr_curve(
        y_test, y_prob, output_directory, clf, final=final, save_plot=save_plot
    )

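# Illustrative usage (added sketch, not part of the original module): `clf` is assumed
# to be a trained edelweiss classifier and X_test a structured array that contains the
# column "mag_i" used for the histogram panels; "runs/run1" is a placeholder path.
#
#   plot_diagnostics(clf, X_test, y_test, output_directory="runs/run1",
#                    final=True, save_plot=True)
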

def plot_all_scores(scores, path_labels=None):
    """
    Plot all scores for the classifiers. Input can either be directly a recarray with
    the scores or the path to the scores or a list of paths to the scores. If a list
    is given, the scores of the different paths are combined and plotted with different
    colors.

    :param scores: recarray with the scores or path to the scores or list of paths
    :param path_labels: list of labels for the different paths
    """

    # Load scores if path is given
    if isinstance(scores, str):
        # assuming path to main folder
        scores = np.load(os.path.join(scores, "clf/test_scores.npy"))
        colors = None

    elif isinstance(scores, list):
        # assuming list of paths to scores
        scores = [np.load(os.path.join(path, "clf/test_scores.npy")) for path in scores]
        colors = []
        default_colors = list(COL.values())
        for i, s in enumerate(scores):
            # prepare colors for the different classifiers
            colors += len(s) * [default_colors[i]]
        scores = np.concatenate(scores)

    else:
        # assuming recarray
        colors = None

    # Setup the recarray
    try:
        names = scores["clf"]
    except Exception:
        names = len(scores) * ["classifier"]
    names = np.array(names)
    n_gal_true = scores["n_galaxies_true"][0]
    scores = at.delete_columns(scores, ["clf", "n_galaxies_true"])

    # Plot all scores
    for param in scores.dtype.names:
        data = scores[param]

        # Sort classifiers
        indices = np.argsort(data)
        data = data[indices]
        current_names = names[indices]
        x = np.arange(len(current_names))
        col = None if colors is None else np.array(colors)[indices]
        fig_width = int(len(current_names) / 3)
        plt.figure(figsize=(fig_width, 2))
        plt.title(param)

        if param == "n_galaxies_pred":
            # Plot the difference to the true number of galaxies
            y = data - n_gal_true
            sign_switch_index = next(
                (i for i, (y1, y2) in enumerate(zip(y, y[1:])) if y1 * y2 <= 0), None
            )
            if sign_switch_index is None:
                sign_switch_index = len(y) if y[0] < 0 else -1

            plt.bar(x, y, color=col)
            plt.xticks(x, current_names)
            plt.axvline(x=sign_switch_index + 0.5, color="k", ls="--")

            # Update the ticks to actual galaxies
            ticks, _ = plt.yticks()
            new_ticks = [int(tick + n_gal_true) for tick in ticks]
            plt.yticks(ticks, new_ticks)

        else:
            # Plot the scores
            plt.bar(x, data, color=col)
            plt.xticks(x, current_names)
            # Set ylim for the scores to relevant parts
            if (param != "log_loss_score") & (param != "brier_score"):
                plt.ylim(0.5, 1)

        plt.grid(axis="y")
        plt.xticks(rotation=90)

        if path_labels is not None:
            # create legend with colors of the different classifiers
            patches = []
            for i, label in enumerate(path_labels):
                patches.append(mpl.patches.Patch(color=default_colors[i], label=label))
            plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left")
        plt.xlim(-0.5, len(current_names) - 0.5)

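# Illustrative usage (added sketch, not part of the original module): comparing the
# scores of two training runs; the paths are placeholders and are assumed to contain
# clf/test_scores.npy.
#
#   plot_all_scores(["runs/run1", "runs/run2"], path_labels=["run 1", "run 2"])
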

def plot_classifier_comparison(
    clfs,
    conf,
    path,
    spider_ranges=None,
    labels=None,
    print_scores=False,
    special_param="mag_i",
):
    """
    Plot the diagnostics for chosen classifiers. If the classifiers are not all from
    the same path, the conf and path parameters should be lists of the same length as
    clfs.

    :param clfs: list of classifier names
    :param conf: configuration dictionary or list of dictionaries
    :param path: path to the data or list of paths
    :param spider_ranges: dictionary with the ranges for the spider plot
    :param labels: list of labels for the different paths
    :param print_scores: if True, print the scores for the different classifiers
    :param special_param: param to plot the histogram for
    """
    spider_ranges = {} if spider_ranges is None else spider_ranges
    figs = [None, None, None, None, None]
    if isinstance(path, list):
        if not isinstance(conf, list):
            conf = [conf] * len(path)
        for i, p in enumerate(path):
            label = labels[i] if labels is not None else None
            _plot_classifier_comparison(
                clfs[i],
                conf[i],
                p,
                figs,
                spider_ranges,
                label,
                print_scores,
                special_param,
            )
    else:
        _plot_classifier_comparison(
            clfs, conf, path, figs, spider_ranges, labels, print_scores, special_param
        )

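# Illustrative usage (added sketch, not part of the original module): classifier names
# follow the "<classifier>_<scaler>" convention built in _plot_classifier_comparison;
# the specific names below and `conf` (a dict with "classifier" and "scaler" lists)
# are placeholders.
#
#   plot_classifier_comparison(
#       ["XGB_standard", "RandomForest_standard"],
#       conf,
#       "runs/run1",
#       spider_ranges={"accuracy": (0.8, 1)},
#   )
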

def _plot_classifier_comparison(
    clfs,
    conf,
    path,
    figs,
    spider_ranges=None,
    label=None,
    print_scores=False,
    special_param="mag_i",
):
    """
    Plot the diagnostics for chosen classifiers.

    :param clfs: list of classifier names
    :param conf: configuration dictionary
    :param path: path to the data
    :param figs: list of figures
    :param spider_ranges: dictionary with the ranges for the spider plot
    :param label: label for this path, used instead of the classifier name if given
    :param print_scores: if True, print the scores for the different classifiers
    :param special_param: param to plot the histogram for
    """
    spider_ranges = {} if spider_ranges is None else spider_ranges
    n_clfs = len(conf["classifier"])
    n_scalers = len(conf["scaler"])
    for index in range(n_clfs * n_scalers):
        i_clf, i_scaler = np.unravel_index(index, (n_clfs, n_scalers))
        clf_name = f"{conf['classifier'][i_clf]}_{conf['scaler'][i_scaler]}"
        if clf_name in clfs:
            name = label if label is not None else clf_name
            _add_plot(
                figs, index, name, path, spider_ranges, print_scores, special_param
            )


def plot_feature_importances(clf, clf_name="classifier", summed=False):
    """
    Plots the feature importances for the classifier.

    :param clf: classifier object
    :param clf_name: name of the classifier
    :param summed: if True, the summed feature importances are plotted
    """
    if clf.feature_importances is None:
        LOGGER.warning("No feature importances found")
        return
    if summed:
        feat_imp, par = at.rec2arr(clf.summed_feature_importances, return_names=True)
    else:
        feat_imp, par = at.rec2arr(clf.feature_importances, return_names=True)
    par = np.array(list(par))
    feat_imp = feat_imp.flatten()

    plt.figure(figsize=(10, 5))
    plt.title(f"Feature importance for {clf_name}")
    plt.bar(par, feat_imp)
    plt.xticks(rotation=90)


def _add_plot(
    figs,
    index,
    clf_name,
    path=".",
    spider_ranges=None,
    print_scores=False,
    special_param="mag_i",
):
    """
    Add the plots for the classifier to the figure objects.

    :param figs: list of figure objects
    :param index: index of the classifier
    :param clf_name: name of the classifier
    :param path: path to the data
    :param spider_ranges: dictionary with the ranges for the spider plot
    :param print_scores: if True, print the scores for the different classifiers
    :param special_param: param to plot the histogram for
    :return: list of updated figure objects
    """
    spider_ranges = {} if spider_ranges is None else spider_ranges
    # Load the data
    from edelweiss.emulator import load_classifier

    clf = load_classifier(path, utils.get_clf_name(index))
    X_test = np.load(os.path.join(path, f"clf_cv/clf_test_data{index}.npy"))
    y_true = np.load(os.path.join(path, f"clf_cv/clf_test_labels{index}.npy"))

    # Get the predictions
    y_prob = clf.predict_proba(X_test)
    y_pred = clf.predict(X_test)

    # Add the plots
    figs[0] = plot_pr_curve(y_true, y_prob, clf=clf_name, fig=figs[0])
    figs[1] = plot_roc_curve(y_true, y_prob, clf=clf_name, fig=figs[1])
    figs[2] = plot_calibration_curve(y_true, y_prob, clf=clf_name, fig=figs[2])
    figs[3] = plot_hist_n_gal(
        X_test[special_param], y_true, y_pred, clf=clf_name, fig=figs[3]
    )
    figs[4] = plot_spider_scores(
        y_true,
        y_pred,
        y_prob,
        clf=clf_name,
        fig=figs[4],
        ranges=spider_ranges,
        print_scores=print_scores,
    )
    plot_feature_importances(clf, clf_name)

    return figs


def setup_test(multi_class=False):
    """
    Returns a dict where the test scores will be saved.

    :param multi_class: if True, only the scores defined for multiclass
        classification (accuracy, precision, recall, f1) are set up
    :return: dict with an empty list for each score
    """

    test = {}
    test["accuracy"] = []
    test["precision"] = []
    test["recall"] = []
    test["f1"] = []
    if multi_class:
        return test
    test["n_galaxies_true"] = []
    test["n_galaxies_pred"] = []
    test["auc_roc_score"] = []
    test["log_loss_score"] = []
    test["brier_score"] = []
    test["auc_pr_score"] = []

    return test

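# Illustrative workflow (added sketch, not part of the original module): setup_test and
# get_all_scores are meant to be combined as below, mirroring plot_spider_scores above;
# y_test, y_pred and y_prob are assumed to exist already.
#
#   test_scores = setup_test()
#   get_all_scores(test_scores, y_test, y_pred, y_prob)
#   scores_rec = at.dict2rec({k: v[0] for k, v in test_scores.items()})
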

def get_all_scores(test_arr, y_test, y_pred, y_prob):
    """
    Calculates all the scores and append them to the test_arr dict

    :param test_arr: dict where the test scores will be saved
    :param y_test: test labels
    :param y_pred: predicted labels
    :param y_prob: probability of being detected
    """
    LOGGER.info("Test scores:")
    LOGGER.info("------------")
    # calculate accuracy score
    accuracy = accuracy_score(y_test, y_pred)
    test_arr["accuracy"].append(accuracy)
    LOGGER.info(f"Accuracy: {accuracy}")

    # calculate precision score
    precision = precision_score(y_test, y_pred)
    test_arr["precision"].append(precision)
    LOGGER.info(f"Precision: {precision}")

    # calculate recall score
    recall = recall_score(y_test, y_pred)
    test_arr["recall"].append(recall)
    LOGGER.info(f"Recall: {recall}")

    # calculate f1 score
    f1 = f1_score(y_test, y_pred)
    test_arr["f1"].append(f1)
    LOGGER.info(f"F1 score: {f1}")

    # calculate number of galaxies
    n_galaxies_true = np.sum(y_test)
    test_arr["n_galaxies_true"].append(n_galaxies_true)
    n_galaxies_pred = np.sum(y_pred)
    test_arr["n_galaxies_pred"].append(n_galaxies_pred)
    LOGGER.info(f"Number of positives: {n_galaxies_pred} / {n_galaxies_true}")

    # calculate roc auc score of probabilities
    auc_roc_score = roc_auc_score(y_test, y_prob)
    test_arr["auc_roc_score"].append(auc_roc_score)
    LOGGER.info(f"ROC AUC score: {auc_roc_score}")

    # calculate log loss
    log_loss_score = log_loss(y_test, y_prob)
    test_arr["log_loss_score"].append(log_loss_score)
    LOGGER.info(f"Log loss score: {log_loss_score}")

    # calculate brier score
    brier_score = brier_score_loss(y_test, y_prob)
    test_arr["brier_score"].append(brier_score)
    LOGGER.info(f"Brier score: {brier_score}")

    # calculate average precision score (AUC-PR)
    auc_pr_score = average_precision_score(y_test, y_prob)
    test_arr["auc_pr_score"].append(auc_pr_score)
    LOGGER.info(f"AUC-PR score: {auc_pr_score}")

    LOGGER.info("------------")


def get_all_scores_multiclass(test_arr, y_test, y_pred, y_prob):
    """
    Calculates all the scores and append them to the test_arr dict for a multiclass
    classifier.

    :param test_arr: dict where the test scores will be saved
    :param y_test: test labels
    :param y_pred: predicted labels
    :param y_prob: probability of being detected
    """
    LOGGER.info("Test scores:")
    LOGGER.info("------------")
    # calculate accuracy score
    accuracy = accuracy_score(y_test, y_pred)
    test_arr["accuracy"].append(accuracy)
    LOGGER.info(f"Accuracy: {accuracy}")

    # calculate precision score
    precision = precision_score(y_test, y_pred, average="weighted")
    test_arr["precision"].append(precision)
    LOGGER.info(f"Precision: {precision}")

    # calculate recall score
    recall = recall_score(y_test, y_pred, average="weighted")
    test_arr["recall"].append(recall)
    LOGGER.info(f"Recall: {recall}")

    # calculate f1 score
    f1 = f1_score(y_test, y_pred, average="weighted")
    test_arr["f1"].append(f1)
    LOGGER.info(f"F1 score: {f1}")

    LOGGER.info("------------")