Coverage for src/galsbi/ucat/plugins/write_catalog.py: 100%
91 statements
« prev ^ index » next coverage.py v7.10.5, created at 2025-08-26 10:41 +0000
« prev ^ index » next coverage.py v7.10.5, created at 2025-08-26 10:41 +0000
1# Copyright (C) 2019 ETH Zurich, Institute for Particle Physics and Astrophysics
3"""
4Created on Aug 2021
5author: Tomasz Kacprzak
6"""
8import h5py
9import numpy as np
10from cosmic_toolbox import arraytools as at
11from cosmic_toolbox import logger
12from ivy.plugin.base_plugin import BasePlugin
# Module-level logger; the Plugin below uses it for debug/diagnostic messages.
LOGGER = logger.get_logger(__file__)
def catalog_to_rec(catalog):
    """
    Convert a column-based catalog into a numpy structured (record) array.

    Every name in ``catalog.columns`` becomes one field of the output. The field
    layout is derived from the shape of the column array:

    - 1-D columns of length ``n`` become scalar fields,
    - columns of shape ``(n, 1)`` are flattened into scalar fields,
    - columns of shape ``(n, d1, d2, ...)`` become sub-array fields of shape
      ``(d1, d2, ...)``.

    :param catalog: object exposing an iterable ``columns`` of names, each of which
        is readable as ``getattr(catalog, name)`` and is array-like with the object
        index on the first axis (all columns are expected to have the same length)
    :return: structured array with one row per object; a zero-length array with no
        fields if the catalog has no columns
    """
    names = []
    arrays = []
    dtype_list = []
    for col_name in catalog.columns:
        col = np.asarray(getattr(catalog, col_name))
        if col.ndim == 1 or col.shape[1:] == (1,):
            # Declare width-1 columns explicitly as scalar fields and flatten the
            # data to match: numpy's interpretation of an explicit sub-array length
            # of 1 differs between versions, which would make dtype and data disagree.
            dtype_list.append((col_name, col.dtype))
            col = col.reshape(-1)
        else:
            # keep all trailing axes so N-D columns map onto N-D sub-array fields
            dtype_list.append((col_name, col.dtype, col.shape[1:]))
        names.append(col_name)
        arrays.append(col)

    # an empty catalog yields an empty record array instead of an unbound n_obj
    n_obj = len(arrays[0]) if arrays else 0
    rec = np.empty(n_obj, dtype=np.dtype(dtype_list))
    for col_name, col in zip(names, arrays):
        rec[col_name] = col
    return rec
def sed_catalog_to_rec(catalog):
    """
    Build the reduced structured array used to save the galaxy SEDs.

    To save the SEDs, we don't need the full catalog, but only the ID, SED and
    redshift. The ID is necessary to link the SEDs to the galaxies in the
    photometric catalog. The redshift is necessary to adapt the wavelengths of the
    SEDs to the observed frame.

    Field layout follows the column shapes: 1-D and ``(n, 1)`` columns become scalar
    fields, ``(n, d1, ...)`` columns become sub-array fields of shape ``(d1, ...)``.

    :param catalog: object with ``id``, ``sed`` and ``z`` array attributes whose
        first axis indexes the galaxies
    :return: structured array with fields ``id``, ``sed`` and ``z``
    """
    columns = ("id", "sed", "z")
    arrays = []
    dtype_list = []
    for col in columns:
        arr = np.asarray(getattr(catalog, col))
        if arr.ndim == 1 or arr.shape[1:] == (1,):
            # width-1 columns are stored as scalar fields with flattened data, so
            # the declared dtype always matches what is assigned below
            dtype_list.append((col, arr.dtype))
            arr = arr.reshape(-1)
        else:
            dtype_list.append((col, arr.dtype, arr.shape[1:]))
        arrays.append(arr)

    rec = np.empty(len(arrays[0]), dtype=np.dtype(dtype_list))
    for col, arr in zip(columns, arrays):
        rec[col] = arr
    return rec
def save_sed(filepath_out, cat, restframe_wavelength_in_A):
    """
    Save the galaxy SEDs to an HDF5 file.

    The SEDs are stored in the rest frame together with their rest-frame wavelength
    grid. No conversion to the observed frame is done here; the redshift stored in
    ``cat`` is what allows a reader to shift the wavelengths to the observed frame.

    :param filepath_out: path of the HDF5 file to create (opened in mode ``"w"``,
        so an existing file is overwritten)
    :param cat: array written to the ``"data"`` dataset, typically the structured
        array with ``id``, ``sed`` and ``z`` fields built by ``sed_catalog_to_rec``
    :param restframe_wavelength_in_A: rest-frame wavelength grid of the SEDs
        (Angstrom), written to the ``"restframe_wavelength_in_A"`` dataset
    """
    with h5py.File(filepath_out, "w") as f:
        f.create_dataset("data", data=cat)
        f.create_dataset("restframe_wavelength_in_A", data=restframe_wavelength_in_A)
class Plugin(BasePlugin):
    """
    Write the ucat galaxy and star catalogs, and optionally the galaxy SEDs, to disk.
    """

    def __call__(self):
        """
        Write the galaxy catalog (enriched with derived columns), the star catalog
        and, if ``par.save_SEDs`` is set, the galaxy SED catalog to the file names
        configured in the parameters.
        """
        par = self.ctx.parameters

        if hasattr(self.ctx, "current_filter"):
            # Image generation sets a current filter and runs this plugin per band:
            # write the SED file only once, on the pass for the reference band.
            save_seds_if_requested = self.ctx.current_filter == par.reference_band
        else:
            save_seds_if_requested = True

        has_galaxies = "galaxies" in self.ctx

        # write catalogs
        if has_galaxies:
            cat = catalog_to_rec(self.ctx.galaxies)
            cat = self.enrich_catalog(cat)
            at.write_to_hdf(par.galaxy_catalog_name, cat)

        if "stars" in self.ctx:
            cat = catalog_to_rec(self.ctx.stars)
            at.write_to_hdf(par.star_catalog_name, cat)

        if par.save_SEDs and save_seds_if_requested:
            if has_galaxies:
                sed_cat = sed_catalog_to_rec(self.ctx.galaxies)
                save_sed(
                    par.galaxy_sed_catalog_name,
                    sed_cat,
                    self.ctx.restframe_wavelength_for_SED,
                )
            else:
                # SEDs are taken from the galaxy catalog; without it there is
                # nothing to write, so report it instead of failing on ctx.galaxies
                LOGGER.warning(
                    "save_SEDs is enabled but the context has no galaxies; "
                    "no SED catalog is written."
                )

    def enrich_catalog(self, cat):
        """
        Add derived columns to the galaxy catalog on a best-effort basis.

        Each column is attempted independently; if its inputs are missing or
        incompatible it is skipped with a debug message:

        - ``e_abs``: ``sqrt(e1**2 + e2**2)``,
        - ``bkg_noise_amp``: ``par.bkg_noise_amp`` repeated for every object,
        - ``bkg_noise_std``: only when the catalog has neither ``ra`` nor ``dec``;
          either the constant ``par.bkg_noise_std`` or, if it is an array, its
          value at each object's integer pixel position ``(y, x)``.

        :param cat: structured array as returned by ``catalog_to_rec``
        :return: the catalog with the extra columns, or ``cat`` unchanged if
            ``par.enrich_catalog`` is ``False``
        """
        par = self.ctx.parameters
        if par.enrich_catalog is False:
            LOGGER.debug("Enriching catalog is disabled.")
            return cat
        try:
            cat = at.add_cols(
                cat, ["e_abs"], data=np.sqrt(cat["e1"] ** 2 + cat["e2"] ** 2)
            )
        except (ValueError, KeyError) as e:
            LOGGER.debug(f"e_abs could not be calculated: {e}")
        # add noise levels
        try:
            cat = at.add_cols(
                cat, ["bkg_noise_amp"], data=np.ones(len(cat)) * par.bkg_noise_amp
            )
        except AttributeError as e:
            LOGGER.debug(f"bkg_noise_amp could not be calculated: {e}")

        try:
            if "ra" not in cat.dtype.names and "dec" not in cat.dtype.names:
                # integer pixel indices (float -> int conversion truncates toward 0)
                y = np.array(cat["y"], dtype=int)
                x = np.array(cat["x"], dtype=int)
                bkg_noise_std = par.bkg_noise_std
                if np.ndim(bkg_noise_std) == 0:
                    # constant noise level; also covers numpy scalars, which have
                    # a (empty) shape but cannot be indexed with [y, x]
                    data = bkg_noise_std
                else:
                    # noise map: clip indices so objects whose position falls
                    # outside the map take the nearest edge value instead of
                    # wrapping around (negative index) or raising IndexError
                    noise_map = np.asarray(bkg_noise_std)
                    y = np.clip(y, 0, noise_map.shape[0] - 1)
                    x = np.clip(x, 0, noise_map.shape[1] - 1)
                    data = noise_map[y, x]
                cat = at.add_cols(cat, ["bkg_noise_std"], data=data)
        except (ValueError, KeyError, AttributeError, IndexError) as e:
            LOGGER.debug(f"bkg_noise_std could not be calculated: {e}")
        return cat

    def __str__(self):
        return "write ucat catalog to file"