Coverage for src/galsbi/ucat/sed_templates_util.py: 98%
90 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-18 15:15 +0000
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-18 15:15 +0000
1# Copyright (C) 2018 ETH Zurich, Institute for Particle Physics and Astrophysics
3"""
4Created Aug 2021
5author: Tomasz Kacprzak
6code from:
7Joerg Herbel
8# ucat/docs/jupyter_notebooks/tabulate_template_integrals.ipynb
9"""
11import itertools
12import os
13from collections import OrderedDict
15import h5py
17# See Fast calculation of magnitudes from spectra
18# ucat/docs/jupyter_notebooks/tabulate_template_integrals.ipynb
19import numpy as np
20from cosmic_toolbox import file_utils, logger
22LOGGER = logger.get_logger(__file__)
def get_redshift_extinction_grid(
    z_max, z_stepsize, excess_b_v_max, excess_b_v_stepsize
):
    """
    Build the redshift and E(B-V) grids and their flattened Cartesian product.

    :param z_max: maximum redshift covered by the grid
    :param z_stepsize: redshift grid spacing
    :param excess_b_v_max: maximum E(B-V) covered by the grid
    :param excess_b_v_stepsize: E(B-V) grid spacing
    :return: tuple ``(z_grid, excess_b_v_grid, z_cross, excess_b_v_cross)``.
        ``z_cross`` and ``excess_b_v_cross`` have length
        ``len(z_grid) * len(excess_b_v_grid)`` and enumerate all pairs with
        redshift as the outer (slow) index and E(B-V) as the inner (fast) index,
        i.e. the same order as ``itertools.product(z_grid, excess_b_v_grid)``.
        A flat array over the pairs can therefore be reshaped to
        ``(len(z_grid), len(excess_b_v_grid))``.
    """
    # Both grids start at 0. The stop value is max + step so that the last grid
    # point reaches (or, if max is not a multiple of step, passes) the maximum.
    z_grid = np.arange(0, z_max + z_stepsize, step=z_stepsize)
    excess_b_v_grid = np.arange(
        0, excess_b_v_max + excess_b_v_stepsize, excess_b_v_stepsize
    )

    # Vectorised equivalent of itertools.product(z_grid, excess_b_v_grid):
    # repeat every redshift once per E(B-V) value, tile the E(B-V) grid once per
    # redshift. This avoids materialising one Python tuple per grid pair (millions
    # for the fine production grid) and also works when a grid is empty.
    # The common dtype mirrors what np.array() would infer from the pairs.
    dtype = np.result_type(z_grid, excess_b_v_grid)
    z_cross = np.repeat(z_grid, len(excess_b_v_grid)).astype(dtype, copy=False)
    excess_b_v_cross = np.tile(excess_b_v_grid, len(z_grid)).astype(
        dtype, copy=False
    )

    return z_grid, excess_b_v_grid, z_cross, excess_b_v_cross
def get_template_integrals(
    sed_templates, filters, filter_names=None, ids_templates=None, test=False
):
    """
    Tabulate per-template fluxes on a redshift x E(B-V) grid.

    For every selected template, ``MagCalculatorDirect`` is evaluated on all
    (redshift, E(B-V)) pairs with unit coefficients. The returned fluxes are
    reshaped to ``(len(z_grid), len(excess_b_v_grid))``.

    :param sed_templates: mapping with ``lam``, ``amp`` (indexed by template id
        along the first axis) and ``n_templates``
    :param filters: filter definitions passed to ``MagCalculatorDirect``; its
        keys are used as filter names when ``filter_names`` is None
    :param filter_names: filters to integrate (default: ``filters.keys()``)
    :param ids_templates: template ids to process (default: all templates);
        must support ``len()``
    :param test: if True use a coarse grid (step 0.05) instead of the production
        grid (step 0.001); both cover z up to 7.462 and E(B-V) up to 0.4
    :return: ``(sed_template_integrals, excess_b_v_grid, z_grid)`` where
        ``sed_template_integrals[filter][id_templ]`` is the 2-D flux table
    """
    # Same step in redshift and E(B-V): coarse for tests, fine for production.
    grid_step = 0.05 if test else 0.001
    z_grid, excess_b_v_grid, z_cross, excess_b_v_cross = get_redshift_extinction_grid(
        z_stepsize=grid_step,
        excess_b_v_max=0.4,
        excess_b_v_stepsize=grid_step,
        z_max=7.462,
    )

    coeffs = np.ones(len(z_cross))
    if filter_names is None:
        filter_names = filters.keys()
    if ids_templates is None:
        ids_templates = range(sed_templates["n_templates"])

    from galsbi.ucat.magnitude_calculator import MagCalculatorDirect

    integrals_by_filter = OrderedDict((band, {}) for band in filter_names)
    n_ids = len(ids_templates)
    table_shape = (len(z_grid), len(excess_b_v_grid))

    for counter, id_templ in enumerate(ids_templates, start=1):
        LOGGER.info(f"SED template id_templ={id_templ} {counter}/{n_ids}")
        # a one-template view of the template set (fancy index keeps the axis)
        single_template = OrderedDict(
            lam=sed_templates["lam"],
            amp=sed_templates["amp"][[id_templ]],
            n_templates=1,
        )
        calculator = MagCalculatorDirect(filters, single_template)
        fluxes = calculator(
            redshifts=z_cross,
            excess_b_v=excess_b_v_cross,
            filter_names=filter_names,
            coeffs=coeffs,
            return_fluxes=True,
        )
        for band in fluxes:
            integrals_by_filter[band][id_templ] = fluxes[band].reshape(table_shape)

    return integrals_by_filter, excess_b_v_grid, z_grid
def store_sed_integrals(filename, integrals, excess_b_v_grid, z_grid):
    """
    Write tabulated SED template integrals to an HDF5 file.

    The file gets the datasets ``E(B-V)`` and ``z`` for the grids. Each table is
    stored at ``integrals/<filter>/template_<id>``. The parent directory is
    created if needed and an existing file is overwritten (mode ``"w"``).

    :param filename: output path (made absolute)
    :param integrals: mapping filter name -> mapping template id -> table
    :param excess_b_v_grid: E(B-V) grid of the tables
    :param z_grid: redshift grid of the tables
    """
    filename = os.path.abspath(filename)
    dirpath = os.path.dirname(filename)
    if not os.path.isdir(dirpath):
        # exist_ok: another process (e.g. a parallel job) may create the
        # directory between the isdir check and makedirs.
        os.makedirs(dirpath, exist_ok=True)
        LOGGER.info(f"made dir {dirpath}")

    with h5py.File(filename, "w") as f:
        f["E(B-V)"] = excess_b_v_grid
        f["z"] = z_grid
        for band, templates in integrals.items():
            for t in templates:
                f[f"integrals/{band}/template_{t}"] = templates[t]
    LOGGER.info(f"wrote {filename}")
def load_sed_integrals(
    filepath_sed_integ,
    filter_names=None,
    crop_negative=False,
    sed_templates=None,
    copy_to_cwd=False,
):
    """
    Load tabulated SED template integrals from an HDF5 file, using a cache.

    Filters already present in ``sed_templates`` are not read again. Newly read
    filters are added to it in place, and the same object is returned.

    :param filepath_sed_integ: name of the file containing SED template integrals
    :param filter_names: list of filter names, should be in format Camera_band,
        but just band is also accepted to ensure backwards compatibility.
        If None, all filters stored in the file are loaded.
    :param crop_negative: if True, set all negative elements of the loaded
        integrals to zero
    :param sed_templates: OrderedDict used as a buffer/cache for the integrals;
        a new one is created if None
    :param copy_to_cwd: copy the file to the current working directory (if not
        already there) and load it from there
    :return: ``sed_templates``, mapping each loaded filter (keyed by the name
        found in the file) to a list of integral arrays, one per template. The
        attributes ``n_templates``, ``z_grid`` and ``excess_b_v_grid`` are set
        if they were missing.
    """

    def filter_name_back_compatibility(filter_name, integrals):
        # Backwards compatibility: if the full name is not in the file, fall
        # back to the part after the last underscore (the band alone).
        if filter_name not in integrals:
            return filter_name.split("_")[-1]
        return filter_name

    if sed_templates is None:
        sed_templates = OrderedDict()

    filepath_sed_integ_local = os.path.join(
        os.getcwd(), os.path.basename(filepath_sed_integ)
    )

    if copy_to_cwd:
        # copy to local directory (local scratch) if not already there
        if not os.path.exists(filepath_sed_integ_local):
            file_utils.robust_copy(filepath_sed_integ, filepath_sed_integ_local)
        load_filename = filepath_sed_integ_local
    else:
        load_filename = filepath_sed_integ

    with h5py.File(load_filename, mode="r") as f:
        integrals = f["integrals"]
        if filter_names is None:
            filter_names = list(integrals.keys())

        filter_name_use = None
        for filter_name in filter_names:
            if filter_name in sed_templates:
                LOGGER.debug(f"read {filter_name} from sed_templates cache")
                continue
            filter_name_use = filter_name_back_compatibility(filter_name, integrals)
            # Data is cached under the name used in the file, so check that
            # name too; otherwise back-compat names are re-read on every call.
            if filter_name_use in sed_templates:
                LOGGER.debug(f"read {filter_name} from sed_templates cache")
                continue

            # NOTE: assumes the datasets are named template_0 .. template_{n-1}
            n_templ = len(integrals[filter_name_use].keys())
            templates = []
            for i in range(n_templ):
                int_templ = np.array(integrals[filter_name_use][f"template_{i}"])
                if crop_negative:
                    np.clip(int_templ, a_min=0, a_max=None, out=int_templ)
                templates.append(int_templ)
            # only insert into the cache once the filter is fully read, so a
            # failed read does not leave a partial entry behind
            sed_templates[filter_name_use] = templates

        if not hasattr(sed_templates, "n_templates"):
            # no filter resolved in this call: count templates of any stored filter
            if filter_name_use is None or filter_name_use not in integrals:
                filter_name_use = next(iter(integrals.keys()), None)
            if filter_name_use is not None:
                sed_templates.n_templates = len(integrals[filter_name_use].keys())

        if not hasattr(sed_templates, "z_grid"):
            sed_templates.z_grid = np.array(f["z"])

        if not hasattr(sed_templates, "excess_b_v_grid"):
            sed_templates.excess_b_v_grid = np.array(f["E(B-V)"])
    return sed_templates
def load_template_spectra(filepath_sed_templates, lam_scale=1, amp_scale=1):
    """
    Read SED template spectra from an HDF5 file and rescale them.

    :param filepath_sed_templates: path to an HDF5 file with the datasets
        ``wavelength`` and ``amplitudes``
    :param lam_scale: factor applied in place to the wavelength array
    :param amp_scale: factor applied in place to the amplitude array
    :return: OrderedDict with keys ``amp`` and ``lam``. The size of the first
        axis of ``amplitudes`` is attached as the attribute ``n_templates``.
    """
    with h5py.File(filepath_sed_templates, "r") as fh:
        lam = np.array(fh["wavelength"])
        amp_dataset = fh["amplitudes"]
        amp = np.array(amp_dataset)
        LOGGER.info(f"using template set: {filepath_sed_templates}")
        # log every attribute attached to the amplitudes dataset
        for attr_name, attr_value in amp_dataset.attrs.items():
            LOGGER.info(f"templates {attr_name}:\n{attr_value}")

    # in-place scaling: the arrays keep the dtype they were read with
    lam *= lam_scale
    amp *= amp_scale

    templates = OrderedDict(amp=amp, lam=lam)
    templates.n_templates = amp.shape[0]
    return templates