Coverage for src/edelweiss/clf_diagnostics.py: 100% (386 statements)

# Copyright (C) 2023 ETH Zurich
# Institute for Particle Physics and Astrophysics
# Author: Silvan Fischbacher

import os
import pickle

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from cosmic_toolbox import arraytools as at
from cosmic_toolbox import colors, file_utils
from cosmic_toolbox.logger import get_logger
from sklearn.calibration import CalibratedClassifierCV, calibration_curve
from sklearn.metrics import (accuracy_score, average_precision_score,
                             brier_score_loss, f1_score, log_loss,
                             precision_recall_curve, precision_score,
                             recall_score, roc_auc_score, roc_curve)
from sklearn.model_selection import GridSearchCV

from edelweiss import clf_utils as utils

LOGGER = get_logger(__file__)
COL = colors.get_colors()
colors.set_cycle()


def get_confusion_matrix(y_true, y_pred):
    """
    Get the confusion matrix for the classifier.

    :param y_true: true labels
    :param y_pred: predicted labels
    :return: True Positives, True Negatives, False Positives, False Negatives
    """
    tp = y_pred & y_true
    tn = ~y_pred & ~y_true
    fp = y_pred & ~y_true
    fn = ~y_pred & y_true
    return tp, tn, fp, fn
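
# Example (hedged sketch, not part of the module): how get_confusion_matrix can
# be used with boolean label arrays; the arrays below are illustrative only.
#
#   y_true = np.array([True, False, True, False])
#   y_pred = np.array([True, True, False, False])
#   tp, tn, fp, fn = get_confusion_matrix(y_true, y_pred)
#   # each returned array is a boolean mask; here tp.sum() == tn.sum() == 1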


def plot_hist_fp_fn_tp_tn(
    param,
    y_true,
    y_pred,
    output_directory=".",
    clf="classifier",
    final=False,
    save_plot=False,
):
    """
    Plot the stacked histogram of one parameter (e.g. i-band magnitude) for the
    different confusion matrix elements.

    :param param: parameter to plot
    :param y_true: true labels
    :param y_pred: predicted labels
    :param output_directory: directory to save the plot
    :param clf: classifier object or name of the classifier
    :param final: if True, the plot is for the final classifier
    :param save_plot: if True, save the plot
    """
    # TODO: something strange happens in the plotting.
    tp, tn, fp, fn = get_confusion_matrix(y_true, y_pred)
    fig = plt.figure()
    plt.hist(
        [param[fp], param[fn], param[tp], param[tn]],
        bins=100,
        stacked=True,
        label=["FP", "FN", "TP", "TN"],
        color=[COL["r"], COL["orange"], COL["g"], COL["b"]],
        density=True,
    )
    plt.ylim(0, 0.2)
    plt.legend()
    plt.xlabel("i-band magnitude")
    plt.ylabel("Normalized counts (stacked)")
    name = get_name(clf, final=final)
    path = os.path.join(output_directory, "clf/figures/")
    file_utils.robust_makedirs(path)
    path += f"stacked_hist_{name}"
    if save_plot:
        plt.savefig(path + ".pdf", bbox_inches="tight")
        with open(path + ".pkl", "wb") as fh:
            pickle.dump(fig, fh)


def plot_hist_n_gal(
    param,
    y_true,
    y_pred,
    output_directory=".",
    clf="classifier",
    final=False,
    save_plot=False,
    fig=None,
):
    """
    Plot the histogram of detected galaxies for the classifier and the true detected
    galaxies for one parameter (e.g. i-band magnitude).

    :param param: parameter to plot
    :param y_true: true labels
    :param y_pred: predicted labels
    :param output_directory: directory to save the plot
    :param clf: classifier object or name of the classifier
    :param final: if True, the plot is for the final classifier
    :param save_plot: if True, save the plot
    :param fig: figure object, if None, create a new figure
    :return: figure object
    """
    # TODO: something strange happens in the plotting.
    _, bins = np.histogram(param[y_true], bins=100)

    if fig is None:
        fig, ax = plt.subplots()
        ax.set_xlabel("i-band magnitude")
        ax.set_ylabel("galaxy counts")
        ax.set_title("Number of detected galaxies")
        ax.hist(param[y_true], bins=bins, label="true detected galaxies", color="grey")
    else:
        ax = fig.get_axes()[0]
    name = get_name(clf, final=False)
    ax.hist(param[y_pred], bins=bins, histtype="step", label=name)
    ax.legend()
    name = get_name(clf, final=final)
    path = os.path.join(output_directory, "clf/figures/")
    file_utils.robust_makedirs(path)
    path += f"stacked_hist_{name}"
    if save_plot:
        plt.savefig(path + ".pdf", bbox_inches="tight")
        with open(path + ".pkl", "wb") as fh:
            pickle.dump(fig, fh)
    return fig
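
# Example (hedged sketch, not part of the module): overlaying several classifiers
# on the same counts histogram by passing the returned figure back in; `param`,
# `y_true` and the prediction arrays are illustrative placeholders.
#
#   fig = plot_hist_n_gal(param, y_true, y_pred_clf1, clf="clf1")
#   fig = plot_hist_n_gal(param, y_true, y_pred_clf2, clf="clf2", fig=fig)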


def plot_calibration_curve(
    y_true,
    y_prob,
    output_directory=".",
    clf="classifier",
    final=False,
    save_plot=False,
    fig=None,
):
    """
    Plot the calibration curve for the classifier.

    :param y_true: true labels
    :param y_prob: predicted probabilities
    :param output_directory: directory to save the plot
    :param clf: classifier object or name of the classifier
    :param final: if True, the plot is for the final classifier
    :param save_plot: if True, save the plot
    :param fig: figure object, if None, create a new figure
    :return: figure object
    """
    if fig is None:
        fig, ax = plt.subplots()
        ax.plot([0, 1], [0, 1], ls="--", color="k", label="perfect calibration")
        ax.set_xlabel("predicted probability")
        ax.set_ylabel("fraction of positives")
        ax.set_title("Calibration curve")
    else:
        ax = fig.get_axes()[0]

    prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=20)
    name = get_name(clf, final=False)
    ax.plot(prob_pred, prob_true, label=name)
    ax.legend()
    name = get_name(clf, final=final)
    path = os.path.join(output_directory, "clf/figures/")
    file_utils.robust_makedirs(path)
    path += f"stacked_hist_{name}"
    if save_plot:
        plt.savefig(path + ".pdf", bbox_inches="tight")
        with open(path + ".pkl", "wb") as fh:
            pickle.dump(fig, fh)
    return fig


def plot_roc_curve(
    y_true,
    y_prob,
    output_directory=".",
    clf="classifier",
    final=False,
    save_plot=False,
    fig=None,
):
    """
    Plot the ROC curve for the classifier.

    :param y_true: true labels
    :param y_prob: predicted probabilities
    :param output_directory: directory to save the plot
    :param clf: classifier object or name of the classifier
    :param final: if True, the plot is for the final classifier
    :param save_plot: if True, save the plot
    :param fig: figure object, if None, create a new figure
    :return: figure object
    """
    fpr, tpr, _ = roc_curve(y_true, y_prob)

    if fig is None:
        fig, ax = plt.subplots()
        ax.set_xlabel("false positive rate")
        ax.set_ylabel("true positive rate")
        ax.set_title("ROC curve")
        ax.plot([0, 1], [0, 1], ls="--", color="k", label="random classifier")
        ax.plot([0, 0, 1], [0, 1, 1], ls=":", color="k", label="perfect classifier")
    else:
        ax = fig.get_axes()[0]
    ax.plot(fpr, tpr, label=get_name(clf, final=False))
    ax.legend()
    name = get_name(clf, final=final)
    path = os.path.join(output_directory, "clf/figures/")
    file_utils.robust_makedirs(path)
    path += f"stacked_hist_{name}"
    if save_plot:
        plt.savefig(path + ".pdf", bbox_inches="tight")
        with open(path + ".pkl", "wb") as fh:
            pickle.dump(fig, fh)
    return fig


def plot_pr_curve(
    y_true,
    y_prob,
    output_directory=".",
    clf="classifier",
    final=False,
    save_plot=False,
    fig=None,
):
    """
    Plot the precision-recall curve for the classifier.

    :param y_true: true labels
    :param y_prob: predicted probabilities
    :param output_directory: directory to save the plot
    :param clf: classifier object or name of the classifier
    :param final: if True, the plot is for the final classifier
    :param save_plot: if True, save the plot
    :param fig: figure object, if None, create a new figure
    :return: figure object
    """
    precision, recall, _ = precision_recall_curve(y_true, y_prob)

    if fig is None:
        fig, ax = plt.subplots()
        ax.plot([0, 1, 1], [1, 1, 0], ls=":", color="k", label="perfect classifier")
        ax.set_xlabel("recall")
        ax.set_ylabel("precision")
        ax.set_title("Precision-Recall curve")
    else:
        ax = fig.get_axes()[0]
    ax.plot(recall, precision, label=get_name(clf, final=False))
    ax.legend()
    name = get_name(clf, final=final)
    path = os.path.join(output_directory, "clf/figures/")
    file_utils.robust_makedirs(path)
    path += f"stacked_hist_{name}"
    if save_plot:
        fig.savefig(path + ".pdf", bbox_inches="tight")
        with open(path + ".pkl", "wb") as fh:
            pickle.dump(fig, fh)
    return fig


def plot_spider_scores(
    y_true,
    y_pred,
    y_prob,
    output_directory=".",
    clf="classifier",
    final=False,
    save_plot=False,
    fig=None,
    ranges=None,
    print_scores=False,
):
    """
    Plot the spider scores for the classifier.

    :param y_true: true labels
    :param y_pred: predicted labels
    :param y_prob: predicted probabilities
    :param output_directory: directory to save the plot
    :param clf: classifier object or name of the classifier
    :param final: if True, the plot is for the final classifier
    :param save_plot: if True, save the plot
    :param fig: figure object, if None, create a new figure
    :param ranges: dictionary of ranges for each score
    :param print_scores: if True, print the scores
    :return: figure object
    """
    ranges = {} if ranges is None else ranges
    test_scores = setup_test()
    get_all_scores(test_scores, y_true, y_pred, y_prob)
    for p in test_scores:
        # remove the list structure
        test_scores[p] = test_scores[p][0]
    test_scores = at.dict2rec(test_scores)
    test_scores = at.add_cols(test_scores, ["n_gal_deviation"])
    test_scores["n_gal_deviation"] = (
        test_scores["n_galaxies_true"] - test_scores["n_galaxies_pred"]
    ) / test_scores["n_galaxies_true"]
    test_scores = at.delete_columns(
        test_scores, ["n_galaxies_true", "n_galaxies_pred"]
    )
    if print_scores:
        print(clf)
        for score in test_scores.dtype.names:
            print(f"{score}: {test_scores[score]}")
        print("--------------------")
    if fig is None:
        fig, _ = plt.subplots(figsize=(10, 6), subplot_kw={"polar": True})
    fig, ax = _plot_spider(fig, test_scores, clf, ranges=ranges)
    ax.legend()
    name = get_name(clf, final=final)
    path = os.path.join(output_directory, "clf/figures/")
    file_utils.robust_makedirs(path)
    path += f"stacked_hist_{name}"
    if save_plot:
        fig.savefig(path + ".pdf", bbox_inches="tight")
        with open(path + ".pkl", "wb") as fh:
            pickle.dump(fig, fh)
    return fig


def _plot_spider(fig, data, label, ranges=None):
    """
    Plot the data in a spider plot.

    :param fig: figure object
    :param data: data to plot
    :param label: label for the data
    :param ranges: ranges for the different features
    :return: figure and axis objects
    """
    ranges = {} if ranges is None else ranges

    # Get the default ranges and update them with the given ranges
    r = ranges.copy()
    ranges = get_default_ranges_for_spider()
    ranges.update(r)

    # Prepare the data for the spider plot
    data = scale_data_for_spider(data, ranges)
    values, field_names = at.rec2arr(data, return_names=True)
    values = values.flatten()
    field_names = list(field_names)
    add_range_to_name(field_names, ranges)
    angles = np.linspace(0, 2 * np.pi, len(field_names), endpoint=False)
    values = np.concatenate((values, [values[0]]))  # Close the plot
    angles = np.concatenate((angles, [angles[0]]))  # Close the plot

    # Plot the data
    ax = fig.get_axes()[0]
    ax.plot(angles, values, label=label)

    # Add labels for each variable
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(field_names)
    ax.set_rticks(np.linspace(0, 1, 5))
    ax.set_rlim(0, 1)
    ax.set_yticklabels([])
    return fig, ax


def scale_data_for_spider(data, ranges=None):
    """
    Scale the data for the spider plot such that the chosen range corresponds to the
    0-1 range of the spider plot.

    If the lower value of the range is higher than the upper value, the data is
    inverted.

    :param data: data to scale
    :param ranges: dictionary with the ranges for each variable; if a parameter is
        not in the dictionary, the default range is (0, 1)
    :return: scaled data
    """
    ranges = {} if ranges is None else ranges
    for par in data.dtype.names:
        try:
            low, high = ranges[par]
        except Exception:
            low, high = 0, 1

        if low > high:
            data[par] = 1 - (data[par] - high) / (low - high)
        else:
            data[par] = (data[par] - low) / (high - low)
        data[par] = np.clip(data[par], 0, 1)
    return data
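
# Example (hedged sketch, not part of the module): an accuracy of 0.75 with range
# (0.5, 1) maps to 0.5 on the spider axis, and a Brier score of 0.05 with the
# inverted range (0.1, 0) also maps to 0.5 (lower is better, so the axis is
# flipped). The structured array below is illustrative.
#
#   data = np.array(
#       [(0.75, 0.05)], dtype=[("accuracy", float), ("brier_score", float)]
#   )
#   scaled = scale_data_for_spider(
#       data, {"accuracy": (0.5, 1), "brier_score": (0.1, 0)}
#   )
#   # scaled["accuracy"] == 0.5 and scaled["brier_score"] == 0.5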


def get_default_ranges_for_spider():
    """
    Get the default ranges for the spider plot.

    :return: dictionary with the ranges for each variable
    """
    ranges = {
        "accuracy": (0.5, 1),
        "precision": (0.5, 1),
        "recall": (0.5, 1),
        "f1": (0.5, 1),
        "n_gal_deviation": (-0.1, 0.1),
        "auc_roc_score": (0.5, 1),
        "log_loss_score": (0.5, 0),
        "brier_score": (0.1, 0),
        "auc_pr_score": (0.5, 1),
    }
    return ranges


def add_range_to_name(field_names, ranges):
    """
    Add the range to the name of the variable such that the range is visible in the
    spider plot.

    :param field_names: list with the names of the variables
    :param ranges: dictionary with the ranges for each variable
    """
    for i, f in enumerate(field_names):
        field_names[i] = f + f": {ranges[f]}"


def get_name(clf, final=False):
    """
    Get the name to add to the classifier.

    :param clf: classifier object (from sklearn) or name of the classifier
    :param final: if True, the classifier was tested on the test data.
    :return: name
    """
    from edelweiss.classifier import Classifier, MultiClassifier

    name = ""
    if isinstance(clf, str):
        return clf
    elif isinstance(clf, (Classifier, MultiClassifier)):
        return get_name(clf.pipe, final=final)
    elif isinstance(clf, CalibratedClassifierCV):
        clf_names = ["CalibratedClassifier"]
        if isinstance(clf.estimator, GridSearchCV):
            clf_names.extend(list(clf.estimator.estimator.named_steps.values()))
        else:
            clf_names.extend(list(clf.estimator.named_steps.values()))
    else:
        if isinstance(clf, GridSearchCV):
            clf_names = list(clf.estimator.named_steps.values())
        else:
            try:
                clf_names = list(clf.named_steps.values())
            except Exception:  # pragma: no cover
                clf_names = ["clf"]
    for n in clf_names:
        name += str(n)[:7]
        name += "_"
    if final:
        name = "final"
    return name
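
# Example (hedged sketch, not part of the module): for a plain sklearn pipeline
# the name is built from the first seven characters of each step's repr; the
# pipeline below is illustrative and assumes default sklearn reprs.
#
#   from sklearn.pipeline import Pipeline
#   from sklearn.preprocessing import StandardScaler
#   from sklearn.linear_model import LogisticRegression
#
#   pipe = Pipeline([("scaler", StandardScaler()), ("clf", LogisticRegression())])
#   get_name(pipe)  # -> "Standar_Logisti_"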


def plot_diagnostics(
    clf,
    X_test,
    y_test,
    output_directory=".",
    final=False,
    save_plot=False,
    special_param="mag_i",
):
    """
    Plot the diagnostics for the classifier.

    :param clf: classifier object
    :param X_test: test data
    :param y_test: true labels
    :param output_directory: directory to save the plots
    :param final: if True, the classifier was tested on the test data.
    :param save_plot: if True, save the plots
    :param special_param: param to plot the histogram for
    """
    # Get the predictions
    y_prob = clf.predict_proba(X_test)
    y_pred = clf.predict(X_test)

    # Make sure the labels are boolean
    if not isinstance(y_pred[0], bool):
        y_pred = y_pred.astype(bool)
    if not isinstance(y_test[0], bool):
        y_test = y_test.astype(bool)

    # plot the diagnostics
    param = X_test[special_param]
    plot_hist_fp_fn_tp_tn(
        param, y_test, y_pred, output_directory, clf, final=final, save_plot=save_plot
    )
    plot_hist_n_gal(
        param, y_test, y_pred, output_directory, clf, final=final, save_plot=save_plot
    )
    plot_calibration_curve(
        y_test, y_prob, output_directory, clf, final=final, save_plot=save_plot
    )
    plot_roc_curve(
        y_test, y_prob, output_directory, clf, final=final, save_plot=save_plot
    )
    plot_pr_curve(
        y_test, y_prob, output_directory, clf, final=final, save_plot=save_plot
    )
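
# Example (hedged sketch, not part of the module): running the full diagnostic
# suite on held-out data; `clf` is assumed to be a trained classifier and
# `X_test` a numpy recarray containing a "mag_i" column (both are placeholders).
#
#   plot_diagnostics(
#       clf, X_test, y_test, output_directory="runs/latest",
#       final=True, save_plot=True,
#   )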


def plot_all_scores(scores, path_labels=None):
    """
    Plot all scores for the classifiers. Input can either be directly a recarray with
    the scores or the path to the scores or a list of paths to the scores. If a list
    is given, the scores of the different paths are combined and plotted with
    different colors.

    :param scores: recarray with the scores or path to the scores or list of paths
    :param path_labels: list of labels for the different paths
    """
    # Load scores if path is given
    if isinstance(scores, str):
        # assuming path to main folder
        scores = np.load(os.path.join(scores, "clf/test_scores.npy"))
        colors = None

    elif isinstance(scores, list):
        # assuming list of paths to scores
        scores = [
            np.load(os.path.join(path, "clf/test_scores.npy")) for path in scores
        ]
        colors = []
        default_colors = list(COL.values())
        for i, s in enumerate(scores):
            # prepare colors for the different classifiers
            colors += len(s) * [default_colors[i]]
        scores = np.concatenate(scores)

    else:
        # assuming recarray
        colors = None

    # Setup the recarray
    try:
        names = scores["clf"]
    except Exception:
        names = len(scores) * ["classifier"]
    names = np.array(names)
    n_gal_true = scores["n_galaxies_true"][0]
    scores = at.delete_columns(scores, ["clf", "n_galaxies_true"])

    # Plot all scores
    for param in scores.dtype.names:
        data = scores[param]

        # Sort classifiers
        indices = np.argsort(data)
        data = data[indices]
        current_names = names[indices]
        x = np.arange(len(current_names))
        col = None if colors is None else np.array(colors)[indices]
        fig_width = int(len(current_names) / 3)
        plt.figure(figsize=(fig_width, 2))
        plt.title(param)

        if param == "n_galaxies_pred":
            # Plot the difference to the true number of galaxies
            y = data - n_gal_true
            sign_switch_index = next(
                (i for i, (y1, y2) in enumerate(zip(y, y[1:])) if y1 * y2 <= 0), None
            )
            if sign_switch_index is None:
                sign_switch_index = len(y) if y[0] < 0 else -1

            plt.bar(x, y, color=col)
            plt.xticks(x, current_names)
            plt.axvline(x=sign_switch_index + 0.5, color="k", ls="--")

            # Update the ticks to actual galaxies
            ticks, _ = plt.yticks()
            new_ticks = [int(tick + n_gal_true) for tick in ticks]
            plt.yticks(ticks, new_ticks)

        else:
            # Plot the scores
            plt.bar(x, data, color=col)
            plt.xticks(x, current_names)
            # Set ylim for the scores to relevant parts
            if (param != "log_loss_score") & (param != "brier_score"):
                plt.ylim(0.5, 1)

        plt.grid(axis="y")
        plt.xticks(rotation=90)

        if path_labels is not None:
            # create legend with colors of the different classifiers
            patches = []
            for i, label in enumerate(path_labels):
                patches.append(
                    mpl.patches.Patch(color=default_colors[i], label=label)
                )
            plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left")
        plt.xlim(-0.5, len(current_names) - 0.5)
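
# Example (hedged sketch, not part of the module): comparing the test scores of
# two runs; the paths are placeholders and are assumed to contain the file
# clf/test_scores.npy written during training.
#
#   plot_all_scores(["runs/run_a", "runs/run_b"], path_labels=["run A", "run B"])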


def plot_classifier_comparison(
    clfs,
    conf,
    path,
    spider_ranges=None,
    labels=None,
    print_scores=False,
    special_param="mag_i",
):
    """
    Plot the diagnostics for chosen classifiers. If the classifiers are not all from
    the same path, the conf and path parameters should be lists of the same length
    as clfs.

    :param clfs: list of classifier names
    :param conf: configuration dictionary or list of dictionaries
    :param path: path to the data or list of paths
    :param spider_ranges: dictionary with the ranges for the spider plot
    :param labels: list of labels for the different paths
    :param print_scores: if True, print the scores for the different classifiers
    :param special_param: param to plot the histogram for
    """
    spider_ranges = {} if spider_ranges is None else spider_ranges
    figs = [None, None, None, None, None]
    if isinstance(path, list):
        if not isinstance(conf, list):
            conf = [conf] * len(path)
        for i, p in enumerate(path):
            label = labels[i] if labels is not None else None
            _plot_classifier_comparison(
                clfs[i],
                conf[i],
                p,
                figs,
                spider_ranges,
                label,
                print_scores,
                special_param,
            )
    else:
        _plot_classifier_comparison(
            clfs, conf, path, figs, spider_ranges, labels, print_scores, special_param
        )


def _plot_classifier_comparison(
    clfs,
    conf,
    path,
    figs,
    spider_ranges=None,
    label=None,
    print_scores=False,
    special_param="mag_i",
):
    """
    Plot the diagnostics for chosen classifiers.

    :param clfs: list of classifier names
    :param conf: configuration dictionary
    :param path: path to the data
    :param figs: list of figures
    :param spider_ranges: dictionary with the ranges for the spider plot
    :param label: label for this path
    :param print_scores: if True, print the scores for the different classifiers
    :param special_param: param to plot the histogram for
    """
    spider_ranges = {} if spider_ranges is None else spider_ranges
    n_clfs = len(conf["classifier"])
    n_scalers = len(conf["scaler"])
    for index in range(n_clfs * n_scalers):
        i_clf, i_scaler = np.unravel_index(index, (n_clfs, n_scalers))
        clf_name = f"{conf['classifier'][i_clf]}_{conf['scaler'][i_scaler]}"
        if clf_name in clfs:
            name = label if label is not None else clf_name
            _add_plot(
                figs, index, name, path, spider_ranges, print_scores, special_param
            )


def plot_feature_importances(clf, clf_name="classifier", summed=False):
    """
    Plots the feature importances for the classifier.

    :param clf: classifier object
    :param clf_name: name of the classifier
    :param summed: if True, the summed feature importances are plotted
    """
    if clf.feature_importances is None:
        LOGGER.warning("No feature importances found")
        return
    if summed:
        feat_imp, par = at.rec2arr(clf.summed_feature_importances, return_names=True)
    else:
        feat_imp, par = at.rec2arr(clf.feature_importances, return_names=True)
    par = np.array(list(par))
    feat_imp = feat_imp.flatten()

    plt.figure(figsize=(10, 5))
    plt.title(f"Feature importance for {clf_name}")
    plt.bar(par, feat_imp)
    plt.xticks(rotation=90)


def _add_plot(
    figs,
    index,
    clf_name,
    path=".",
    spider_ranges=None,
    print_scores=False,
    special_param="mag_i",
):
    """
    Add the plots for the classifier to the figure objects.

    :param figs: list of figure objects
    :param index: index of the classifier
    :param clf_name: name of the classifier
    :param path: path to the data
    :param spider_ranges: dictionary with the ranges for the spider plot
    :param print_scores: if True, print the scores for the different classifiers
    :param special_param: param to plot the histogram for
    :return: list of updated figure objects
    """
    spider_ranges = {} if spider_ranges is None else spider_ranges
    # Load the data
    from edelweiss.emulator import load_classifier

    clf = load_classifier(path, utils.get_clf_name(index))
    X_test = np.load(os.path.join(path, f"clf_cv/clf_test_data{index}.npy"))
    y_true = np.load(os.path.join(path, f"clf_cv/clf_test_labels{index}.npy"))

    # Get the predictions
    y_prob = clf.predict_proba(X_test)
    y_pred = clf.predict(X_test)

    # Add the plots
    figs[0] = plot_pr_curve(y_true, y_prob, clf=clf_name, fig=figs[0])
    figs[1] = plot_roc_curve(y_true, y_prob, clf=clf_name, fig=figs[1])
    figs[2] = plot_calibration_curve(y_true, y_prob, clf=clf_name, fig=figs[2])
    figs[3] = plot_hist_n_gal(
        X_test[special_param], y_true, y_pred, clf=clf_name, fig=figs[3]
    )
    figs[4] = plot_spider_scores(
        y_true,
        y_pred,
        y_prob,
        clf=clf_name,
        fig=figs[4],
        ranges=spider_ranges,
        print_scores=print_scores,
    )
    plot_feature_importances(clf, clf_name)

    return figs


def setup_test(multi_class=False):
    """
    Returns a dict where the test scores will be saved.

    :param multi_class: if True, only the scores available for multiclass
        classification are set up
    :return: dict with an empty list for each score
    """
    test = {}
    test["accuracy"] = []
    test["precision"] = []
    test["recall"] = []
    test["f1"] = []
    if multi_class:
        return test
    test["n_galaxies_true"] = []
    test["n_galaxies_pred"] = []
    test["auc_roc_score"] = []
    test["log_loss_score"] = []
    test["brier_score"] = []
    test["auc_pr_score"] = []

    return test


def get_all_scores(test_arr, y_test, y_pred, y_prob):
    """
    Calculates all the scores and appends them to the test_arr dict.

    :param test_arr: dict where the test scores will be saved
    :param y_test: test labels
    :param y_pred: predicted labels
    :param y_prob: probability of being detected
    """
    LOGGER.info("Test scores:")
    LOGGER.info("------------")
    # calculate accuracy score
    accuracy = accuracy_score(y_test, y_pred)
    test_arr["accuracy"].append(accuracy)
    LOGGER.info(f"Accuracy: {accuracy}")

    # calculate precision score
    precision = precision_score(y_test, y_pred)
    test_arr["precision"].append(precision)
    LOGGER.info(f"Precision: {precision}")

    # calculate recall score
    recall = recall_score(y_test, y_pred)
    test_arr["recall"].append(recall)
    LOGGER.info(f"Recall: {recall}")

    # calculate f1 score
    f1 = f1_score(y_test, y_pred)
    test_arr["f1"].append(f1)
    LOGGER.info(f"F1 score: {f1}")

    # calculate number of galaxies
    n_galaxies_true = np.sum(y_test)
    test_arr["n_galaxies_true"].append(n_galaxies_true)
    n_galaxies_pred = np.sum(y_pred)
    test_arr["n_galaxies_pred"].append(n_galaxies_pred)
    LOGGER.info(f"Number of positives: {n_galaxies_pred} / {n_galaxies_true}")

    # calculate roc auc score of probabilities
    auc_roc_score = roc_auc_score(y_test, y_prob)
    test_arr["auc_roc_score"].append(auc_roc_score)
    LOGGER.info(f"ROC AUC score: {auc_roc_score}")

    # calculate log loss
    log_loss_score = log_loss(y_test, y_prob)
    test_arr["log_loss_score"].append(log_loss_score)
    LOGGER.info(f"Log loss score: {log_loss_score}")

    # calculate brier score
    brier_score = brier_score_loss(y_test, y_prob)
    test_arr["brier_score"].append(brier_score)
    LOGGER.info(f"Brier score: {brier_score}")

    # calculate average precision score (AUC-PR)
    auc_pr_score = average_precision_score(y_test, y_prob)
    test_arr["auc_pr_score"].append(auc_pr_score)
    LOGGER.info(f"AUC-PR score: {auc_pr_score}")

    LOGGER.info("------------")
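
# Example (hedged sketch, not part of the module): collecting the scores of a
# single classifier into the dict returned by setup_test; the label and
# probability arrays are illustrative placeholders.
#
#   scores = setup_test()
#   get_all_scores(scores, y_test, y_pred, y_prob)
#   scores["accuracy"][0]  # accuracy of this classifier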


def get_all_scores_multiclass(test_arr, y_test, y_pred, y_prob):
    """
    Calculates all the scores and appends them to the test_arr dict for a multiclass
    classifier.

    :param test_arr: dict where the test scores will be saved
    :param y_test: test labels
    :param y_pred: predicted labels
    :param y_prob: probability of being detected
    """
    LOGGER.info("Test scores:")
    LOGGER.info("------------")
    # calculate accuracy score
    accuracy = accuracy_score(y_test, y_pred)
    test_arr["accuracy"].append(accuracy)
    LOGGER.info(f"Accuracy: {accuracy}")

    # calculate precision score
    precision = precision_score(y_test, y_pred, average="weighted")
    test_arr["precision"].append(precision)
    LOGGER.info(f"Precision: {precision}")

    # calculate recall score
    recall = recall_score(y_test, y_pred, average="weighted")
    test_arr["recall"].append(recall)
    LOGGER.info(f"Recall: {recall}")

    # calculate f1 score
    f1 = f1_score(y_test, y_pred, average="weighted")
    test_arr["f1"].append(f1)
    LOGGER.info(f"F1 score: {f1}")

    LOGGER.info("------------")