Coverage for src/ufig/plugins/run_nflow.py: 92%
38 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-12 19:08 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-12 19:08 +0000
1# Copyright (C) 2023 ETH Zurich
2# Institute for Particle Physics and Astrophysics
3# Author: Silvan Fischbacher
5import contextlib
7import h5py
8from cosmic_toolbox import arraytools as at
9from cosmic_toolbox import logger
10from ivy.plugin.base_plugin import BasePlugin
12from ufig.plugins.match_sextractor_catalog_multiband_read import NO_MATCH_VAL
13from ufig.plugins.write_catalog_for_emu import enrich_star_catalog
15HDF5_COMPRESS = {"compression": "gzip", "compression_opts": 9, "shuffle": True}
17LOGGER = logger.get_logger(__file__)
class Plugin(BasePlugin):
    """Emulate a SExtractor catalog for the current filter with a normalizing flow.

    For ``self.ctx.current_filter`` this plugin:

    1. loads the ``"data"`` dataset from ``par.galaxy_catalog_name_dict[filter]``;
    2. removes the ``x``/``y`` columns and samples a catalog with
       ``par.nflow[filter].sample(...)``;
    3. writes the original ``x``/``y`` into ``XWIN_IMAGE``/``YWIN_IMAGE`` and
       ``X_IMAGE``/``Y_IMAGE``, as if they had been measured by SExtractor;
    4. sets ``z`` to ``NO_MATCH_VAL`` for rows with ``galaxy_type == -1``
       (stars);
    5. saves the result to
       ``par.sextractor_forced_photo_catalog_name_dict[filter]`` with gzip
       compression (``HDF5_COMPRESS``).
    """

    def __call__(self):
        """Run the normalizing-flow emulator for the current filter and write its catalog."""
        par = self.ctx.parameters
        # Named ``filter_name`` rather than ``filter`` to avoid shadowing the builtin.
        filter_name = self.ctx.current_filter

        # Per-filter flow object; only its ``sample`` method is used below.
        nflow = par.nflow[filter_name]

        # Load the input catalog of this filter.
        with h5py.File(par.galaxy_catalog_name_dict[filter_name], "r") as fh5:
            cat = fh5["data"][:]

        # Keep the positions aside and drop them from the flow input.
        # NOTE(review): presumably the flow is not trained on x/y -- confirm.
        x = cat["x"]
        y = cat["y"]
        cat = at.delete_columns(cat, ["x", "y"])

        sexcat = nflow.sample(cat)

        # Re-attach the input positions and pretend they were measured by
        # SExtractor (both windowed and isophotal position columns).
        sexcat = at.add_cols(sexcat, ["XWIN_IMAGE", "YWIN_IMAGE", "X_IMAGE", "Y_IMAGE"])
        sexcat["XWIN_IMAGE"] = x
        sexcat["YWIN_IMAGE"] = y
        sexcat["X_IMAGE"] = x
        sexcat["Y_IMAGE"] = y

        # Parameter names that would be overwritten with NO_MATCH_VAL for stars.
        # NOTE(review): both calls are identical, so the concatenation below
        # lists every parameter twice; the second call was possibly meant to
        # use different arguments -- confirm before re-enabling the masking.
        _, enriched_params1 = enrich_star_catalog(cat=None, par=par)
        _, enriched_params2 = enrich_star_catalog(cat=None, par=par)
        # Assumes the sampled catalog carries a ``galaxy_type`` column; -1 marks stars.
        stars = sexcat["galaxy_type"] == -1
        enriched_params = enriched_params1 + enriched_params2
        # TODO: masking of the enriched star parameters is disabled for now;
        # the list is cleared, so the loop below never executes.
        enriched_params = []
        for p in enriched_params:
            # Parameters absent from the sampled catalog are skipped.
            with contextlib.suppress(KeyError):
                sexcat[p][stars] = NO_MATCH_VAL
        # Redshift is always masked for stars, independent of the TODO above.
        sexcat["z"][stars] = NO_MATCH_VAL

        # Write the emulated catalog to this filter's forced-photometry path.
        filepath = par.sextractor_forced_photo_catalog_name_dict[filter_name]
        at.save_hdf_cols(filepath, sexcat, compression=HDF5_COMPRESS)

    def __str__(self):
        return "run the normalizing flow"