Coverage for src/galsbi/ucat/sed_templates_util.py: 98%

90 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-18 15:15 +0000

# Copyright (C) 2018 ETH Zurich, Institute for Particle Physics and Astrophysics

2 

3""" 

4Created Aug 2021 

5author: Tomasz Kacprzak 

6code from: 

7Joerg Herbel 

8# ucat/docs/jupyter_notebooks/tabulate_template_integrals.ipynb 

9""" 

10 

11import itertools 

12import os 

13from collections import OrderedDict 

14 

15import h5py 

16 

# See "Fast calculation of magnitudes from spectra" in
# ucat/docs/jupyter_notebooks/tabulate_template_integrals.ipynb

19import numpy as np 

20from cosmic_toolbox import file_utils, logger 

21 

# module-level logger used by the functions in this module
LOGGER = logger.get_logger(__file__)

23 

24 

def get_redshift_extinction_grid(
    z_max, z_stepsize, excess_b_v_max, excess_b_v_stepsize
):
    """
    Build a regular redshift x E(B-V) grid and its flattened cross product.

    Both axes start at 0 and are built with ``np.arange``, with the upper bound
    extended by one step so that the nominal maximum is included.

    :param z_max: maximum redshift of the grid
    :param z_stepsize: redshift step
    :param excess_b_v_max: maximum colour excess E(B-V)
    :param excess_b_v_stepsize: E(B-V) step
    :return: tuple ``(z_grid, excess_b_v_grid, z_cross, excess_b_v_cross)``.
        The ``*_cross`` arrays have length ``len(z_grid) * len(excess_b_v_grid)``
        and enumerate every (z, E(B-V)) pair with redshift varying slowest
        (the order of ``itertools.product(z_grid, excess_b_v_grid)``), so a
        per-pair quantity can be reshaped to
        ``(len(z_grid), len(excess_b_v_grid))``.
    """
    # NOTE(review): np.arange with floating-point steps can yield one extra
    # point just above the nominal maximum due to rounding. The grids are
    # stored alongside the tabulated integrals, so they are left unchanged.

    # Redshift grid
    z_grid = np.arange(0, z_max + z_stepsize, step=z_stepsize)

    # E(B-V) grid
    excess_b_v_grid = np.arange(
        0, excess_b_v_max + excess_b_v_stepsize, excess_b_v_stepsize
    )

    # Vectorised cross product: 'ij' indexing followed by a C-order ravel gives
    # the itertools.product ordering without materialising millions of Python
    # tuples, and also works when one of the axes is empty.
    z_mesh, excess_b_v_mesh = np.meshgrid(z_grid, excess_b_v_grid, indexing="ij")

    # both cross arrays share one common dtype (e.g. float if only one of the
    # axes is integer-valued)
    common_dtype = np.result_type(z_grid, excess_b_v_grid)
    z_cross = z_mesh.ravel().astype(common_dtype, copy=False)
    excess_b_v_cross = excess_b_v_mesh.ravel().astype(common_dtype, copy=False)

    return z_grid, excess_b_v_grid, z_cross, excess_b_v_cross

42 

43 

def get_template_integrals(
    sed_templates, filters, filter_names=None, ids_templates=None, test=False
):
    """
    Tabulate flux integrals of SED templates on a redshift x E(B-V) grid.

    For every requested template, fluxes are computed with
    ``MagCalculatorDirect`` at every (z, E(B-V)) pair of the grid and reshaped
    to ``(len(z_grid), len(excess_b_v_grid))``.

    :param sed_templates: mapping with keys ``"lam"``, ``"amp"`` (indexed by
        template along its first axis) and ``"n_templates"``
    :param filters: filter definitions passed to ``MagCalculatorDirect``; its
        ``keys()`` are used when ``filter_names`` is None
    :param filter_names: filters to compute, defaults to ``filters.keys()``
    :param ids_templates: template indices to compute, defaults to
        ``range(sed_templates["n_templates"])``
    :param test: if True use a coarse grid (step 0.05), otherwise the
        production grid (step 0.001)
    :return: tuple ``(sed_template_integrals, excess_b_v_grid, z_grid)`` where
        ``sed_template_integrals[filter][template_id]`` is a 2D flux array
    """
    # coarse grid for tests, fine grid for production tables; both use
    # z_max=7.462 and an E(B-V) maximum of 0.4
    step = 0.05 if test else 0.001
    grids = get_redshift_extinction_grid(
        z_stepsize=step,
        excess_b_v_max=0.4,
        excess_b_v_stepsize=step,
        z_max=7.462,
    )
    z_grid, excess_b_v_grid, z_cross, excess_b_v_cross = grids

    unit_coeffs = np.ones(len(z_cross))
    if filter_names is None:
        filter_names = filters.keys()
    if ids_templates is None:
        ids_templates = range(sed_templates["n_templates"])

    # imported locally rather than at module level (presumably to avoid a
    # circular import between ucat modules)
    from galsbi.ucat.magnitude_calculator import MagCalculatorDirect

    sed_template_integrals = OrderedDict((name, {}) for name in filter_names)
    grid_shape = [len(z_grid), len(excess_b_v_grid)]

    for i, id_templ in enumerate(ids_templates):
        LOGGER.info(f"SED template id_templ={id_templ} {i + 1}/{len(ids_templates)}")

        # single-template view; fancy index [[id]] keeps the template axis
        single_template = OrderedDict(
            lam=sed_templates["lam"],
            amp=sed_templates["amp"][[id_templ]],
            n_templates=1,
        )
        calculator = MagCalculatorDirect(filters, single_template)
        fluxes = calculator(
            redshifts=z_cross,
            excess_b_v=excess_b_v_cross,
            filter_names=filter_names,
            coeffs=unit_coeffs,
            return_fluxes=True,
        )

        for band in fluxes:
            sed_template_integrals[band][id_templ] = fluxes[band].reshape(grid_shape)

    return sed_template_integrals, excess_b_v_grid, z_grid

97 

98 

def store_sed_integrals(filename, integrals, excess_b_v_grid, z_grid):
    """
    Write tabulated SED template integrals and their grids to an HDF5 file.

    File layout::

        E(B-V)                          -> excess_b_v_grid
        z                               -> z_grid
        integrals/<band>/template_<id>  -> integrals[band][id]

    The parent directory is created if needed; an existing file is overwritten.

    :param filename: output path (relative paths are resolved against the cwd)
    :param integrals: nested mapping ``{band: {template_id: array}}``, e.g. the
        first element returned by ``get_template_integrals``
    :param excess_b_v_grid: E(B-V) grid the integrals are tabulated on
    :param z_grid: redshift grid the integrals are tabulated on
    """
    # get absolute path
    filename = os.path.abspath(filename)
    d = os.path.dirname(filename)
    if not os.path.isdir(d):
        # exist_ok avoids a FileExistsError when several processes race to
        # create the same output directory between the check and makedirs
        os.makedirs(d, exist_ok=True)
        LOGGER.info(f"made dir {d}")

    with h5py.File(filename, "w") as f:
        f["E(B-V)"] = excess_b_v_grid
        f["z"] = z_grid
        for band, templates in integrals.items():
            for template_id in templates:
                f[f"integrals/{band}/template_{template_id}"] = templates[template_id]
    LOGGER.info(f"wrote {filename}")

114 

115 

def load_sed_integrals(
    filepath_sed_integ,
    filter_names=None,
    crop_negative=False,
    sed_templates=None,
    copy_to_cwd=False,
):
    """
    Loads SED integrals, uses cache.

    The file is expected to follow the layout written by
    ``store_sed_integrals`` (datasets ``z``, ``E(B-V)`` and
    ``integrals/<filter>/template_<i>``).

    :param filepath_sed_integ: name of the file containing SED template integrals
    :param filter_names: list of filter names, should be in format Camera_band, but
        just band is also accepted to ensure backwards compatibility; if None,
        all filters stored in the file are loaded
    :param crop_negative: if True, set all negative elements of the loaded
        integrals to zero
    :param sed_templates: OrderedDict used as a buffer for templates; filters
        already present in it (under the requested or the backwards-compatible
        name) are not read again
    :param copy_to_cwd: copy the file to the current working directory (unless a
        file with the same name is already there) and load it from there
    :return: ``sed_templates``, mapping each loaded filter name to a list of
        per-template integral arrays, with the attributes ``n_templates``,
        ``z_grid`` and ``excess_b_v_grid`` set if they were missing
    """

    def filter_name_back_compatibility(filter_name, integrals):
        # backwards compatibility: if the full Camera_band name is not in the
        # file, look it up by the part after the last "_"
        if filter_name not in integrals:
            return filter_name.split("_")[-1]
        else:
            return filter_name

    if sed_templates is None:
        sed_templates = OrderedDict()

    if copy_to_cwd:
        # copy to local directory (local scratch) if not already there
        load_filename = os.path.join(
            os.getcwd(), os.path.basename(filepath_sed_integ)
        )
        if not os.path.exists(load_filename):
            file_utils.robust_copy(filepath_sed_integ, load_filename)
    else:
        load_filename = filepath_sed_integ

    with h5py.File(load_filename, mode="r") as f:
        integrals = f["integrals"]
        if filter_names is None:
            filter_names = list(integrals.keys())

        # last filter actually read from the file; used to count the templates
        last_loaded = None
        for filter_name in filter_names:
            filter_name_use = filter_name_back_compatibility(filter_name, integrals)

            # only load if not already in buffer; entries are stored under the
            # backwards-compatible name, so that one is checked as well
            if filter_name in sed_templates or filter_name_use in sed_templates:
                LOGGER.debug(f"read {filter_name} from sed_templates cache")
                continue

            group = integrals[filter_name_use]
            templates = []
            for i in range(len(group.keys())):
                int_templ = np.array(group[f"template_{i}"])
                if crop_negative:
                    np.clip(int_templ, a_min=0, a_max=None, out=int_templ)
                templates.append(int_templ)
            # only publish the list once it is complete, so a failed read does
            # not leave a partial entry in the cache
            sed_templates[filter_name_use] = templates
            last_loaded = filter_name_use

        if not hasattr(sed_templates, "n_templates"):
            # count templates of the last loaded filter; if everything came
            # from the cache, fall back to the first filter stored in the file
            # NOTE(review): assumes all filters hold the same number of templates
            ref_filter = last_loaded
            if ref_filter is None:
                ref_filter = next(iter(integrals), None)
            if ref_filter is not None:
                sed_templates.n_templates = len(integrals[ref_filter].keys())

        if not hasattr(sed_templates, "z_grid"):
            sed_templates.z_grid = np.array(f["z"])

        if not hasattr(sed_templates, "excess_b_v_grid"):
            sed_templates.excess_b_v_grid = np.array(f["E(B-V)"])
    return sed_templates

188 

189 

def load_template_spectra(filepath_sed_templates, lam_scale=1, amp_scale=1):
    """
    Load SED template spectra from an HDF5 file.

    :param filepath_sed_templates: path to an HDF5 file with the datasets
        ``wavelength`` and ``amplitudes``
    :param lam_scale: factor the wavelengths are multiplied by
    :param amp_scale: factor the amplitudes are multiplied by
    :return: OrderedDict with keys ``amp`` and ``lam`` and the attribute
        ``n_templates`` (length of the first axis of ``amp``)
    """
    with h5py.File(filepath_sed_templates, "r") as f:
        sed_templates_lam = np.array(f["wavelength"])
        sed_templates_amp = np.array(f["amplitudes"])
        LOGGER.info(f"using template set: {filepath_sed_templates}")
        # log any metadata attributes attached to the amplitudes dataset
        for a in f["amplitudes"].attrs:
            LOGGER.info("templates {}:\n{}".format(a, f["amplitudes"].attrs[a]))

    # scale by the desired factors; done out of place because an in-place
    # ``*=`` raises a casting error when an integer-typed dataset is scaled by
    # a fractional factor
    sed_templates_lam = sed_templates_lam * lam_scale
    sed_templates_amp = sed_templates_amp * amp_scale
    sed_templates = OrderedDict(amp=sed_templates_amp, lam=sed_templates_lam)
    sed_templates.n_templates = sed_templates_amp.shape[0]
    return sed_templates