Coverage for src/galsbi/ucat/sed_templates_util.py: 92%
90 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-10 11:12 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-10 11:12 +0000
1# Copyright (C) 2018 ETH Zurich, Institute for Particle Physics and Astrophysics
3"""
4Created Aug 2021
5author: Tomasz Kacprzak
6code from:
7Joerg Herbel
8# ucat/docs/jupyter_notebooks/tabulate_template_integrals.ipynb
9"""
11import itertools
12import os
13from collections import OrderedDict
15import h5py
17# See Fast calculation of magnitudes from spectra
18# ucat/docs/jupyter_notebooks/tabulate_template_integrals.ipynb
19import numpy as np
20from cosmic_toolbox import file_utils, logger
# Module-level logger for this file, obtained from cosmic_toolbox.logger.
LOGGER = logger.get_logger(__file__)
def _inclusive_grid(stop, step):
    """
    Return ``[0, step, 2*step, ...]`` ending at the first grid point >= ``stop``.

    Intended to match ``np.arange(0, stop + step, step)``, but the number of
    points is derived from ``stop / step`` rounded to 9 decimals. This stops
    floating-point noise (e.g. ``1.1 / 0.1 == 11.000000000000002``) from
    appending a spurious extra point beyond ``stop``. The grid values are
    ``i * step``, the same values np.arange produces.

    :param stop: last value that must be covered by the grid
    :param step: grid spacing, must be > 0
    :return: 1D numpy array (empty if ``stop`` is below ``-step``)
    :raises ValueError: if ``step`` is not positive
    """
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")
    n_steps = int(np.ceil(round(stop / step, 9)))
    return np.arange(max(n_steps + 1, 0)) * step


def get_redshift_extinction_grid(
    z_max, z_stepsize, excess_b_v_max, excess_b_v_stepsize
):
    """
    Build the redshift and E(B-V) grids and their flattened Cartesian product.

    Both grids start at 0, use the given spacing and extend up to the first
    grid point >= the requested maximum.

    :param z_max: maximum redshift to cover
    :param z_stepsize: redshift grid spacing (> 0)
    :param excess_b_v_max: maximum E(B-V) to cover
    :param excess_b_v_stepsize: E(B-V) grid spacing (> 0)
    :return: z_grid, excess_b_v_grid, z_cross, excess_b_v_cross; the cross
        arrays have length len(z_grid) * len(excess_b_v_grid), with redshift
        varying slowest (same ordering as
        itertools.product(z_grid, excess_b_v_grid))
    """
    # Redshift grid
    z_grid = _inclusive_grid(z_max, z_stepsize)

    # E(B-V) grid
    excess_b_v_grid = _inclusive_grid(excess_b_v_max, excess_b_v_stepsize)

    # np.repeat / np.tile reproduce the itertools.product ordering, vectorised,
    # and still return well-defined (empty) arrays when a grid is empty.
    z_cross = np.repeat(z_grid, len(excess_b_v_grid))
    excess_b_v_cross = np.tile(excess_b_v_grid, len(z_grid))

    return z_grid, excess_b_v_grid, z_cross, excess_b_v_cross
def get_template_integrals(
    sed_templates, filters, filter_names=None, ids_templates=None, test=False
):
    """
    Tabulate the flux of individual SED templates through a set of filters on a
    regular redshift x E(B-V) grid.

    :param sed_templates: mapping with keys "lam", "amp" and "n_templates";
        "amp" is indexed by template id along its first axis
    :param filters: filter definitions handed to MagCalculatorDirect; its keys
        are used as filter names when filter_names is None
    :param filter_names: names of the filters to integrate (default: all keys
        of filters)
    :param ids_templates: template ids to process (default:
        range(sed_templates["n_templates"])); must support len()
    :param test: if True use a coarse 0.05 grid spacing (for tests), otherwise
        a fine 0.001 spacing
    :return: tuple (integrals, excess_b_v_grid, z_grid) where
        integrals[filter][template_id] is an array of shape
        (len(z_grid), len(excess_b_v_grid))
    """
    # Same grid extent in both modes; only the resolution differs.
    grid_step = 0.05 if test else 0.001
    z_grid, excess_b_v_grid, z_cross, excess_b_v_cross = (
        get_redshift_extinction_grid(
            z_stepsize=grid_step,
            excess_b_v_max=0.4,
            excess_b_v_stepsize=grid_step,
            z_max=7.462,
        )
    )

    # Unit coefficients: each template is evaluated on its own.
    coeffs = np.ones(len(z_cross))
    if filter_names is None:
        filter_names = filters.keys()
    if ids_templates is None:
        ids_templates = range(sed_templates["n_templates"])

    from galsbi.ucat.magnitude_calculator import MagCalculatorDirect

    sed_template_integrals = OrderedDict((name, {}) for name in filter_names)
    grid_shape = (len(z_grid), len(excess_b_v_grid))
    n_templates_total = len(ids_templates)

    for counter, template_id in enumerate(ids_templates, start=1):
        LOGGER.info(
            f"SED template id_templ={template_id} {counter}/{n_templates_total}"
        )
        # Single-template view; the list index [[id]] keeps the leading
        # template axis (presumably "amp" is an ndarray).
        single_template = OrderedDict(
            lam=sed_templates["lam"],
            amp=sed_templates["amp"][[template_id]],
            n_templates=1,
        )
        calculator = MagCalculatorDirect(filters, single_template)
        fluxes = calculator(
            redshifts=z_cross,
            excess_b_v=excess_b_v_cross,
            filter_names=filter_names,
            coeffs=coeffs,
            return_fluxes=True,
        )

        # Fold the flat (z, E(B-V)) cross product back onto the 2D grid.
        for name in fluxes:
            sed_template_integrals[name][template_id] = fluxes[name].reshape(
                grid_shape
            )

    return sed_template_integrals, excess_b_v_grid, z_grid
def store_sed_integrals(filename, integrals, excess_b_v_grid, z_grid):
    """
    Write tabulated SED template integrals to an HDF5 file.

    Layout written:
        "E(B-V)"                            -> excess_b_v_grid
        "z"                                 -> z_grid
        "integrals/<filter>/template_<id>"  -> integrals[filter][id]

    Missing parent directories are created; the file is opened in mode "w",
    so an existing file is replaced.

    :param filename: output path (converted to an absolute path)
    :param integrals: nested mapping {filter_name: {template_id: array}}
    :param excess_b_v_grid: E(B-V) grid values
    :param z_grid: redshift grid values
    """
    # get absolute path
    filename = os.path.abspath(filename)
    dirname = os.path.dirname(filename)
    if not os.path.isdir(dirname):
        # exist_ok guards against another process (e.g. a parallel job)
        # creating the directory between the isdir check and makedirs.
        os.makedirs(dirname, exist_ok=True)
        LOGGER.info(f"made dir {dirname}")

    with h5py.File(filename, "w") as f:
        f["E(B-V)"] = excess_b_v_grid
        f["z"] = z_grid
        for band, templates in integrals.items():
            for template_id, values in templates.items():
                f[f"integrals/{band}/template_{template_id}"] = values
    LOGGER.info(f"wrote {filename}")
def load_sed_integrals(
    filepath_sed_integ,
    filter_names=None,
    crop_negative=False,
    sed_templates=None,
    copy_to_cwd=False,
):
    """
    Loads SED integrals, uses cache.

    Reads the HDF5 datasets "integrals/<filter>/template_<i>", "z" and
    "E(B-V)" from the file.

    :param filepath_sed_integ: name of the file containing SED template integrals
    :param filter_names: list of filter names, should be in format Camera_band,
        but just band is also accepted to ensure backwards compatibility.
        If None, every filter stored in the file is loaded.
    :param crop_negative: if to set all negative elements of the integrals to zero
    :param sed_templates: OrderedDict used as buffer/cache for templates;
        filters already present in it are not read again
    :param copy_to_cwd: copy the file to the current working directory (if not
        already there) and read it from there
    :return: sed_templates, holding for each loaded filter a list of integral
        arrays (template_0, template_1, ...) under the group name used in the
        file. The attributes n_templates, z_grid and excess_b_v_grid are set if
        they were missing.
    """

    def filter_name_back_compatibility(filter_name, integrals):
        # Older files store integrals under the bare band name instead of
        # "Camera_band": fall back to the part after the last underscore.
        if filter_name not in integrals:
            return filter_name.split("_")[-1]
        else:
            return filter_name

    if sed_templates is None:
        sed_templates = OrderedDict()

    if copy_to_cwd:
        filepath_sed_integ_local = os.path.join(
            os.getcwd(), os.path.basename(filepath_sed_integ)
        )
        # copy to local directory (local scratch) if not already there
        if not os.path.exists(filepath_sed_integ_local):
            file_utils.robust_copy(filepath_sed_integ, filepath_sed_integ_local)
        load_filename = filepath_sed_integ_local
    else:
        load_filename = filepath_sed_integ

    with h5py.File(load_filename, mode="r") as f:
        integrals = f["integrals"]
        if filter_names is None:
            # previously crashed with TypeError; now means "all filters"
            filter_names = list(integrals.keys())

        # group name of the last filter resolved in this call (for n_templates)
        last_group = None
        for filter_name in filter_names:
            # only load if not already in buffer
            if filter_name in sed_templates:
                LOGGER.debug(f"read {filter_name} from sed_templates cache")
                continue

            # backwards compatibility check
            filter_name_use = filter_name_back_compatibility(filter_name, integrals)
            last_group = filter_name_use
            if filter_name_use in sed_templates:
                # a "Camera_band" request resolved to a band that is already
                # cached: do not read it from disk again
                LOGGER.debug(f"read {filter_name} from sed_templates cache")
                continue

            group = integrals[filter_name_use]
            templates = []
            for i in range(len(group.keys())):
                int_templ = np.array(group[f"template_{i}"])
                if crop_negative:
                    np.clip(int_templ, a_min=0, a_max=None, out=int_templ)
                templates.append(int_templ)
            # store only after a complete read, so a failure part-way does not
            # leave a partial entry in the cache
            sed_templates[filter_name_use] = templates

        if not hasattr(sed_templates, "n_templates"):
            if last_group is None:
                # nothing resolved in this call (empty request or everything
                # cached): count templates of the first filter in the file
                last_group = next(iter(integrals), None)
            if last_group is not None:
                sed_templates.n_templates = len(integrals[last_group].keys())

        if not hasattr(sed_templates, "z_grid"):
            sed_templates.z_grid = np.array(f["z"])

        if not hasattr(sed_templates, "excess_b_v_grid"):
            sed_templates.excess_b_v_grid = np.array(f["E(B-V)"])

    return sed_templates
def load_template_spectra(filepath_sed_templates, lam_scale=1, amp_scale=1):
    """
    Read SED template spectra from an HDF5 file.

    Reads the datasets "wavelength" and "amplitudes" and logs every attribute
    attached to "amplitudes".

    :param filepath_sed_templates: path of the HDF5 template file
    :param lam_scale: factor applied in place to the wavelength array
    :param amp_scale: factor applied in place to the amplitude array
    :return: OrderedDict with keys "amp" and "lam" and the attribute
        n_templates (length of the first axis of the amplitudes)
    """
    with h5py.File(filepath_sed_templates, "r") as h5:
        lam = np.array(h5["wavelength"])
        amp = np.array(h5["amplitudes"])
        LOGGER.info(f"using template set: {filepath_sed_templates}")
        amp_attrs = h5["amplitudes"].attrs
        for attr_name in amp_attrs:
            LOGGER.info("templates {}:\n{}".format(attr_name, amp_attrs[attr_name]))

    # Scale by the desired factors; in place, so the dtype read from the
    # file is kept.
    lam *= lam_scale
    amp *= amp_scale

    sed_templates = OrderedDict(amp=amp, lam=lam)
    sed_templates.n_templates = amp.shape[0]
    return sed_templates