Coverage for src/edelweiss/emulator.py: 100%

14 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-18 17:09 +0000

1# Copyright (C) 2023 ETH Zurich 

2# Institute for Particle Physics and Astrophysics 

3# Author: Silvan Fischbacher 

4 

5from cosmic_toolbox import logger 

6 

7from edelweiss.classifier import load_classifier, load_multiclassifier 

8from edelweiss.nflow import load_nflow 

9 

10LOGGER = logger.get_logger(__file__) 

11 

12 

def load_emulator(
    path,
    bands=("g", "r", "i", "z", "y"),
    multiclassifier=False,
    subfolder_clf=None,
    subfolder_nflow=None,
):
    """
    Load the classifier and the normalizing flow(s) of an emulator.

    The classifier is always loaded first. When ``bands`` is None a single
    normalizing flow is loaded; otherwise one normalizing flow is loaded per
    band and they are returned in a dictionary keyed by band.

    :param path: path to the folder containing the emulator
    :param bands: iterable of bands to load a normalizing flow for, or None to
        load a single normalizing flow without specifying a band
    :param multiclassifier: if True, the classifier is loaded with
        ``load_multiclassifier`` instead of ``load_classifier``
    :param subfolder_clf: subfolder passed on to the classifier loader
    :param subfolder_nflow: subfolder passed on to the normalizing flow loader
    :return: tuple ``(clf, nflow)`` where ``nflow`` is either a single
        normalizing flow (``bands`` is None) or a dict mapping each band to its
        normalizing flow
    """
    # Choose the loader once; both share the same call signature here.
    clf_loader = load_multiclassifier if multiclassifier else load_classifier
    clf = clf_loader(path, subfolder=subfolder_clf)

    # Single-flow emulator: no band is passed to the loader.
    if bands is None:
        return clf, load_nflow(path, subfolder=subfolder_nflow)

    # One flow per requested band, keyed by the band itself.
    flows_by_band = {
        band: load_nflow(path, band=band, subfolder=subfolder_nflow)
        for band in bands
    }
    return clf, flows_by_band