Coverage for src/ufig/plugins/run_emulator.py: 100%
64 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-12 19:08 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-12 19:08 +0000
1# Copyright (C) 2023 ETH Zurich
2# Institute for Particle Physics and Astrophysics
3# Author: Silvan Fischbacher
5import h5py
6from cosmic_toolbox import arraytools as at
7from cosmic_toolbox import logger
8from ivy.plugin.base_plugin import BasePlugin
# Options passed as ``compression=`` to ``at.save_hdf_cols`` when the output
# catalogs are written: gzip at level 9 with the shuffle filter enabled.
HDF5_COMPRESS = {
    "compression": "gzip",
    "compression_opts": 9,
    "shuffle": True,
}

# Module-level logger shared by the plugin below.
LOGGER = logger.get_logger(__file__)
class Plugin(BasePlugin):
    """Run the full emulator on the per-filter input catalogs.

    Builds the classifier input from the catalogs listed in
    ``par.det_clf_catalog_name_dict`` and runs the classifier ``par.clf`` on
    it. The rows it selects are sampled with the normalizing flow
    ``par.nflow``. One output catalog per filter in ``par.emu_filters`` is
    written to ``par.sextractor_forced_photo_catalog_name_dict``.
    """

    def __call__(self):
        LOGGER.info("Running the full emulator")
        par = self.ctx.parameters

        # load classifier and config
        conf = par.emu_conf
        clf_params = list(par.clf.params)

        X, x, y, ids, radec = self._load_inputs(par, conf)

        X = at.dict2rec(X)
        # Reorder the fields to match the order the classifier expects
        X = X[clf_params]

        LOGGER.debug("Running the classifier")
        # used below as a row index into X; presumably a boolean mask of
        # detected objects -- TODO confirm the return type of predict()
        det = par.clf.predict(X)

        LOGGER.debug("Running the normalizing flow")
        sexcat = par.nflow.sample(X[det])
        # rows of sexcat to keep in the written catalogs
        select = at.get_finite_mask(sexcat)

        # positions and ids of the selected rows are identical for every
        # filter, so compute them once instead of once per filter/column
        x_det = x[det]
        y_det = y[det]
        ids_det = ids[det]

        LOGGER.debug("Writing the catalogs")
        for f in par.emu_filters:
            cat = {}
            for p in conf["input_band_dep"]:
                cat[p] = sexcat[p + f"_{f}"]
            for p in conf["input_band_indep"]:
                cat[p] = sexcat[p]
            for p in conf["output"]:
                cat[p] = sexcat[p + f"_{f}"]

            # positions for final catalog
            if radec:
                cat["ra"] = x_det
                cat["dec"] = y_det
            else:
                cat["x"] = x_det
                cat["y"] = y_det

            # NOTE(review): in ra/dec mode these image-coordinate columns are
            # filled with ra/dec values -- confirm downstream expects this.
            cat["XWIN_IMAGE"] = x_det
            cat["YWIN_IMAGE"] = y_det
            cat["X_IMAGE"] = x_det
            cat["Y_IMAGE"] = y_det

            cat["id"] = ids_det

            cat = at.dict2rec(cat)

            # write catalog
            filepath = par.sextractor_forced_photo_catalog_name_dict[f]
            at.save_hdf_cols(filepath, cat[select], compression=HDF5_COMPRESS)

        # remove the classifier and flow from the parameters once they are
        # no longer needed (presumably to free memory -- confirm)
        del self.ctx.parameters.clf
        del self.ctx.parameters.nflow

    @staticmethod
    def _load_inputs(par, conf):
        """Read classifier inputs, positions and ids from the input catalogs.

        Every HDF5 file opened here is closed before returning, including
        when reading a column raises.

        :param par: plugin parameters; uses ``emu_filters`` and
            ``det_clf_catalog_name_dict``
        :param conf: emulator config; uses ``input_band_dep`` and
            ``input_band_indep``
        :return: tuple ``(X, x, y, ids, radec)``.
            ``X`` is a dict from column name to array. Band-dependent columns
            are suffixed ``_<filter>``; band-independent columns are read from
            the first filter's catalog.
            ``x``/``y`` are the ``x``/``y`` columns of the first filter's
            catalog when it has an ``x`` column, otherwise its ``ra``/``dec``
            columns.
            ``ids`` is its ``id`` column.
            ``radec`` is True when ``ra``/``dec`` were used.
        """
        cats = {}
        try:
            for f in par.emu_filters:
                cats[f] = h5py.File(par.det_clf_catalog_name_dict[f], "r")

            # create classifier input (insertion order kept as before, since
            # it determines the field layout produced by dict2rec)
            X = {}
            for p in conf["input_band_dep"]:
                for f in par.emu_filters:
                    X[p + f"_{f}"] = cats[f]["data"][p]
            first = cats[par.emu_filters[0]]["data"]
            for p in conf["input_band_indep"]:
                X[p] = first[p]

            # ucat params that are not part of the emulator
            if "x" in first.dtype.names:
                x = first["x"]
                y = first["y"]
                radec = False
            else:
                radec = True
                x = first["ra"]
                y = first["dec"]
            ids = first["id"]
        finally:
            # close every file that was successfully opened
            for cat_file in cats.values():
                cat_file.close()

        return X, x, y, ids, radec

    def __str__(self):
        return "run the full emulator"