Coverage for src/edelweiss/emulator.py: 100%
14 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-18 17:09 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-18 17:09 +0000
1# Copyright (C) 2023 ETH Zurich
2# Institute for Particle Physics and Astrophysics
3# Author: Silvan Fischbacher
5from cosmic_toolbox import logger
7from edelweiss.classifier import load_classifier, load_multiclassifier
8from edelweiss.nflow import load_nflow
# Module-level logger for this file, obtained from cosmic_toolbox's logger helper.
LOGGER = logger.get_logger(__file__)
def load_emulator(
    path,
    bands=("g", "r", "i", "z", "y"),
    multiclassifier=False,
    subfolder_clf=None,
    subfolder_nflow=None,
):
    """
    Load a classifier together with its normalizing flow(s) from ``path``.

    The classifier is loaded with ``load_multiclassifier`` when
    ``multiclassifier`` is true and with ``load_classifier`` otherwise. When
    ``bands`` is None, a single normalizing flow (no band) is loaded; otherwise
    one normalizing flow is loaded per band and collected into a dictionary
    keyed by band name.

    :param path: path to the folder containing the emulator
    :param bands: iterable of band names to load a normalizing flow for; if None,
        assumes that there is only one nflow and loads it without a band
    :param multiclassifier: if True, load a multiclassifier instead of a classifier
    :param subfolder_clf: subfolder of the emulator folder where the classifier is
        stored
    :param subfolder_nflow: subfolder of the emulator folder where the normalizing
        flow(s) are stored
    :return: tuple ``(clf, nflow)``; ``nflow`` is a single normalizing flow when
        ``bands`` is None, otherwise a dict mapping each band to its flow
    """
    # Pick the classifier loader first so the classifier is loaded before any flow.
    classifier_loader = load_multiclassifier if multiclassifier else load_classifier
    clf = classifier_loader(path, subfolder=subfolder_clf)

    # Single-flow case: no per-band flows were requested.
    if bands is None:
        return clf, load_nflow(path, subfolder=subfolder_nflow)

    # One flow per band, loaded in the order the bands are given.
    flows_by_band = {
        band: load_nflow(path, band=band, subfolder=subfolder_nflow)
        for band in bands
    }
    return clf, flows_by_band