Coverage for src/ufig/plugins/run_detection_classifier.py: 100%
41 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-12 19:08 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-12 19:08 +0000
# Copyright (C) 2023 ETH Zurich
# Institute for Particle Physics and Astrophysics
# Author: Silvan Fischbacher
5import h5py
6from cosmic_toolbox import arraytools as at
7from cosmic_toolbox import file_utils, logger
8from ivy.plugin.base_plugin import BasePlugin
# Keyword arguments forwarded to file_utils.write_to_hdf when writing the
# detection catalogs: gzip at level 9 with the shuffle filter enabled.
HDF5_COMPRESS = dict(compression="gzip", compression_opts=9, shuffle=True)

# Module-level logger obtained from cosmic_toolbox, named after this file.
LOGGER = logger.get_logger(__file__)
class Plugin(BasePlugin):
    """
    Run the detection classifier on the per-band catalogs and write the
    detected objects to the galaxy catalogs.

    For every filter ``f`` in ``par.emu_filters`` the HDF5 file
    ``par.det_clf_catalog_name_dict[f]`` is read. The classifier input is
    assembled from the band-dependent columns (``conf["input_band_dep"]``,
    stored as ``<column>_<filter>``) and the band-independent columns
    (``conf["input_band_indep"]``), reordered to ``clf.params`` and passed to
    ``clf.predict``. For each filter, the rows selected by the prediction,
    plus their ``x``/``y`` positions, are written to
    ``par.galaxy_catalog_name_dict[f]``.
    """

    def __call__(self):
        par = self.ctx.parameters

        # load classifier and config
        clf = par.clf
        conf = par.emu_conf
        clf_params = list(clf.params)

        band_dep_cols = list(conf["input_band_dep"])
        band_indep_cols = list(conf["input_band_indep"])

        # load catalogs (each file is closed as soon as its columns are read,
        # also when a requested column is missing)
        ref_filter = self._reference_filter(par)
        columns = {}
        for f in par.emu_filters:
            wanted = list(band_dep_cols)
            if f == ref_filter:
                wanted += band_indep_cols + ["x", "y"]
            columns[f] = self._read_columns(par.det_clf_catalog_name_dict[f], wanted)
        ref = columns[ref_filter]

        # create classifier input
        X = {}
        for p in band_dep_cols:
            for f in par.emu_filters:
                X[f"{p}_{f}"] = columns[f][p]
        for p in band_indep_cols:
            X[p] = ref[p]

        # positions for the final catalog
        x = ref["x"]
        y = ref["y"]

        # Convert to a record array and reorder its fields into the order
        # given by clf.params (raises if a required field is missing). Only
        # predict() sees this reordered view; the write-out below uses the
        # full column dict, so the classifier may use a subset of the inputs.
        # NOTE(review): det is used as an index mask below; this assumes
        # clf.predict returns a boolean array -- confirm.
        det = clf.predict(at.dict2rec(X)[clf_params])

        # write detection catalog
        for f in par.emu_filters:
            cat = {p: X[f"{p}_{f}"][det] for p in band_dep_cols}
            for p in band_indep_cols:
                cat[p] = X[p][det]
            cat["x"] = x[det]
            cat["y"] = y[det]
            file_utils.write_to_hdf(
                par.galaxy_catalog_name_dict[f],
                at.dict2rec(cat),
                "data",
                **HDF5_COMPRESS,
            )
        LOGGER.info("Saved detection catalogs")

    @staticmethod
    def _reference_filter(par):
        """
        Return the filter whose catalog supplies the band-independent columns
        and the ``x``/``y`` positions.

        ``par.filters[0]`` is used when it is one of ``par.emu_filters`` (the
        only catalogs that are read). Otherwise the first emulator filter is
        used instead of failing with a KeyError.

        :raises ValueError: if ``par.emu_filters`` is empty.
        """
        # NOTE(review): the fallback assumes band-independent columns and
        # positions agree across the per-band catalogs -- confirm.
        emu_filters = list(par.emu_filters)
        if not emu_filters:
            raise ValueError("par.emu_filters is empty; no catalogs to classify")
        if par.filters[0] in emu_filters:
            return par.filters[0]
        return emu_filters[0]

    @staticmethod
    def _read_columns(path, names):
        """
        Read the named fields of the ``data`` dataset of the HDF5 file at
        ``path`` into memory.

        Returns a dict mapping field name to array. The file is closed before
        returning.
        """
        with h5py.File(path, "r") as fh:
            data = fh["data"]
            return {name: data[name] for name in names}

    def __str__(self):
        return "run detection classifier"