Coverage for src/galsbi/ucat/sed_templates_util.py: 92%

90 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-13 03:24 +0000

1# Copyright (C) 2018 ETH Zurich, Institute for Particle Physics and Astrophysics 

2 

3""" 

4Created Aug 2021 

5author: Tomasz Kacprzak 

6code from: 

7Joerg Herbel 

8# ucat/docs/jupyter_notebooks/tabulate_template_integrals.ipynb 

9""" 

10 

11import itertools 

12import os 

13from collections import OrderedDict 

14 

15import h5py 

16 

17# See Fast calculation of magnitudes from spectra 

18# ucat/docs/jupyter_notebooks/tabulate_template_integrals.ipynb 

19import numpy as np 

20from cosmic_toolbox import file_utils, logger 

21 

# Module-level logger used by all functions in this file.
LOGGER = logger.get_logger(__file__)

23 

24 

def get_redshift_extinction_grid(
    z_max, z_stepsize, excess_b_v_max, excess_b_v_stepsize
):
    """
    Build the redshift and E(B-V) grids and their flattened Cartesian product.

    Both 1-D grids start at 0 and are built with ``np.arange`` using the
    exclusive stop ``max + stepsize``. The last grid point is therefore the
    first multiple of the step that reaches the requested maximum. It can lie
    above the maximum when the maximum is not a multiple of the step, and
    floating-point rounding may occasionally add one more point.

    :param z_max: maximum redshift to cover
    :param z_stepsize: spacing of the redshift grid
    :param excess_b_v_max: maximum E(B-V) to cover
    :param excess_b_v_stepsize: spacing of the E(B-V) grid
    :return: tuple ``(z_grid, excess_b_v_grid, z_cross, excess_b_v_cross)``.
        ``z_cross`` and ``excess_b_v_cross`` have length
        ``len(z_grid) * len(excess_b_v_grid)`` and enumerate every
        (z, E(B-V)) pair. Redshift is the slow (outer) index, so both arrays
        reshape to ``(len(z_grid), len(excess_b_v_grid))`` in C order.
    """
    # Redshift grid
    z_grid = np.arange(0, z_max + z_stepsize, step=z_stepsize)

    # E(B-V) grid
    excess_b_v_grid = np.arange(
        0, excess_b_v_max + excess_b_v_stepsize, excess_b_v_stepsize
    )

    # Vectorised Cartesian product. With indexing="ij" and a C-order ravel the
    # redshift index varies slowest, which is the same ordering as
    # itertools.product(z_grid, excess_b_v_grid). This avoids materialising
    # millions of Python tuples for the fine production grid.
    z_mesh, excess_b_v_mesh = np.meshgrid(z_grid, excess_b_v_grid, indexing="ij")

    # Keep the common dtype that stacking (z, E(B-V)) pairs into one array
    # would produce.
    common_dtype = np.result_type(z_grid, excess_b_v_grid)
    z_cross = z_mesh.ravel().astype(common_dtype)
    excess_b_v_cross = excess_b_v_mesh.ravel().astype(common_dtype)

    return z_grid, excess_b_v_grid, z_cross, excess_b_v_cross

42 

43 

def get_template_integrals(
    sed_templates,
    filters,
    filter_names=None,
    ids_templates=None,
    test=False,
    z_max=7.462,
    excess_b_v_max=0.4,
    z_stepsize=None,
    excess_b_v_stepsize=None,
):
    """
    Tabulate filter-integrated fluxes of SED templates on a (z, E(B-V)) grid.

    For every requested template, fluxes are computed with
    ``MagCalculatorDirect`` (all coefficients set to 1) at each point of the
    Cartesian product of the redshift and E(B-V) grids.

    :param sed_templates: mapping with keys ``"lam"``, ``"amp"`` and
        ``"n_templates"``; ``amp`` is indexed by template id along its first axis
    :param filters: filter definitions passed to ``MagCalculatorDirect``; its
        keys are used as filter names when ``filter_names`` is None
    :param filter_names: filters to integrate (default: all keys of ``filters``)
    :param ids_templates: template ids to process
        (default: ``range(sed_templates["n_templates"])``)
    :param test: if True use a coarse 0.05 grid step (fast, for tests),
        otherwise a fine 0.001 step
    :param z_max: maximum redshift of the grid
    :param excess_b_v_max: maximum E(B-V) of the grid
    :param z_stepsize: redshift step; overrides the step implied by ``test``
    :param excess_b_v_stepsize: E(B-V) step; overrides the step implied by
        ``test``
    :return: tuple ``(sed_template_integrals, excess_b_v_grid, z_grid)`` where
        ``sed_template_integrals[filter][template_id]`` is an array of shape
        ``(len(z_grid), len(excess_b_v_grid))``
    """
    # Coarse grid for fast tests, fine grid for production tables.
    default_step = 0.05 if test else 0.001
    if z_stepsize is None:
        z_stepsize = default_step
    if excess_b_v_stepsize is None:
        excess_b_v_stepsize = default_step

    (
        z_grid,
        excess_b_v_grid,
        z_cross,
        excess_b_v_cross,
    ) = get_redshift_extinction_grid(
        z_stepsize=z_stepsize,
        excess_b_v_max=excess_b_v_max,
        excess_b_v_stepsize=excess_b_v_stepsize,
        z_max=z_max,
    )
    grid_shape = (len(z_grid), len(excess_b_v_grid))

    coeffs = np.ones(len(z_cross))
    filter_names = filters.keys() if filter_names is None else filter_names
    ids_templates = (
        range(sed_templates["n_templates"]) if ids_templates is None else ids_templates
    )
    n_ids = len(ids_templates)

    # Function-level import, kept from the original; presumably avoids an
    # import cycle inside galsbi.ucat (TODO confirm).
    from galsbi.ucat.magnitude_calculator import MagCalculatorDirect

    sed_template_integrals = OrderedDict({f: {} for f in filter_names})

    for i, id_templ in enumerate(ids_templates):
        LOGGER.info(f"SED template id_templ={id_templ} {i + 1}/{n_ids}")
        # Single-template set; indexing with [[id_templ]] keeps the leading
        # template axis.
        sed_templates_current = OrderedDict(
            lam=sed_templates["lam"],
            amp=sed_templates["amp"][[id_templ]],
            n_templates=1,
        )
        mag_calc = MagCalculatorDirect(filters, sed_templates_current)
        fluxes = mag_calc(
            redshifts=z_cross,
            excess_b_v=excess_b_v_cross,
            filter_names=filter_names,
            coeffs=coeffs,
            return_fluxes=True,
        )

        # fluxes is keyed by filter name. Each entry is presumably a flat
        # array ordered like z_cross (TODO confirm), which is folded back onto
        # the (z, E(B-V)) grid.
        for f in fluxes:
            sed_template_integrals[f][id_templ] = fluxes[f].reshape(grid_shape)

    return sed_template_integrals, excess_b_v_grid, z_grid

97 

98 

def store_sed_integrals(filename, integrals, excess_b_v_grid, z_grid):
    """
    Write tabulated SED template integrals and their grids to an HDF5 file.

    Layout written:

    - dataset ``E(B-V)``: the E(B-V) grid;
    - dataset ``z``: the redshift grid;
    - one dataset per filter and template at
      ``integrals/<filter>/template_<id>``.

    An existing file is overwritten (mode ``"w"``). Missing parent directories
    are created.

    :param filename: output path (relative paths are resolved against the cwd)
    :param integrals: nested mapping ``integrals[filter][template_id] -> array``
    :param excess_b_v_grid: E(B-V) grid
    :param z_grid: redshift grid
    """
    # get absolute path
    filename = os.path.abspath(filename)
    d = os.path.dirname(filename)
    if not os.path.isdir(d):
        # exist_ok=True avoids a FileExistsError race when several processes
        # create the same output directory concurrently.
        os.makedirs(d, exist_ok=True)
        LOGGER.info(f"made dir {d}")

    with h5py.File(filename, "w") as f:
        f["E(B-V)"] = excess_b_v_grid
        f["z"] = z_grid
        for band, templates in integrals.items():
            for t, integral in templates.items():
                f[f"integrals/{band}/template_{t}"] = integral
    LOGGER.info(f"wrote {filename}")

114 

115 

def load_sed_integrals(
    filepath_sed_integ,
    filter_names=None,
    crop_negative=False,
    sed_templates=None,
    copy_to_cwd=False,
):
    """
    Loads SED integrals, uses cache.

    Integrals are read from an HDF5 file with the layout
    ``integrals/<filter>/template_<i>`` plus ``z`` and ``E(B-V)`` datasets.
    They are stored into ``sed_templates``, which acts as a cache: filters
    already present there are not read again.

    :param filepath_sed_integ: name of the file containing SED template integrals
    :param filter_names: filter names, should be in format Camera_band. For
        backwards compatibility with files that store integrals under the bare
        band name, the part after the last underscore is used when the full
        name is not in the file (the data is then cached under that bare name).
        If None, all filters stored in the file are loaded.
    :param crop_negative: if True, set all negative integral values to zero
    :param sed_templates: OrderedDict used as a buffer/cache for the templates;
        a new one is created if None
    :param copy_to_cwd: copy the file to the current working directory (if not
        already there) and load it from there
    :return: ``sed_templates``: mapping filter name -> list of integral arrays
        (one per template). The attributes ``n_templates``, ``z_grid`` and
        ``excess_b_v_grid`` are set if not already present.
    """

    def filter_name_back_compatibility(filter_name, integrals):
        # Older files store integrals under the bare band name (the part after
        # the last underscore); fall back to it if the full name is missing.
        if filter_name not in integrals:
            return filter_name.split("_")[-1]
        else:
            return filter_name

    if sed_templates is None:
        sed_templates = OrderedDict()

    filepath_sed_integ_local = os.path.join(
        os.getcwd(), os.path.basename(filepath_sed_integ)
    )

    if copy_to_cwd:
        # copy to local directory (local scratch) if not already there
        if not os.path.exists(filepath_sed_integ_local):
            file_utils.robust_copy(filepath_sed_integ, filepath_sed_integ_local)
        load_filename = filepath_sed_integ_local
    else:
        load_filename = filepath_sed_integ

    with h5py.File(load_filename, mode="r") as f:
        integrals = f["integrals"]
        if filter_names is None:
            # no explicit selection: load every filter stored in the file
            filter_names = list(integrals.keys())

        last_loaded = None
        for filter_name in filter_names:
            filter_name_use = filter_name_back_compatibility(filter_name, integrals)

            # Only load if not already in the buffer. Data is stored under
            # filter_name_use, so check that key as well. Otherwise
            # backwards-compatible names would be re-read on every call.
            if filter_name in sed_templates or filter_name_use in sed_templates:
                LOGGER.debug(f"read {filter_name} from sed_templates cache")
                continue

            group = integrals[filter_name_use]
            # assumes datasets are named template_0 ... template_{n-1}
            templates = []
            for i in range(len(group.keys())):
                int_templ = np.array(group[f"template_{i}"])
                if crop_negative:
                    np.clip(int_templ, a_min=0, a_max=None, out=int_templ)
                templates.append(int_templ)
            # assign only after a complete read, so a failure part-way does
            # not leave a truncated entry in the cache
            sed_templates[filter_name_use] = templates
            last_loaded = filter_name_use

        if not hasattr(sed_templates, "n_templates"):
            # Template count of the last filter read from this file. If every
            # filter came from the cache, fall back to the first filter stored
            # in the file.
            ref_name = last_loaded if last_loaded is not None else next(iter(integrals))
            sed_templates.n_templates = len(integrals[ref_name].keys())

        if not hasattr(sed_templates, "z_grid"):
            sed_templates.z_grid = np.array(f["z"])

        if not hasattr(sed_templates, "excess_b_v_grid"):
            sed_templates.excess_b_v_grid = np.array(f["E(B-V)"])
    return sed_templates

187 

188 

def load_template_spectra(filepath_sed_templates, lam_scale=1, amp_scale=1):
    """
    Read an SED template set (wavelengths and amplitudes) from an HDF5 file.

    :param filepath_sed_templates: path to an HDF5 file with datasets
        ``wavelength`` and ``amplitudes``
    :param lam_scale: factor applied in place to the wavelength array
    :param amp_scale: factor applied in place to the amplitude array
    :return: OrderedDict with keys ``amp`` and ``lam``, plus an attribute
        ``n_templates`` equal to the length of the first axis of ``amp``
    """
    with h5py.File(filepath_sed_templates, "r") as h5file:
        lam = np.array(h5file["wavelength"])
        amp_dataset = h5file["amplitudes"]
        amp = np.array(amp_dataset)
        LOGGER.info(f"using template set: {filepath_sed_templates}")
        # log every metadata attribute attached to the amplitudes dataset
        attributes = amp_dataset.attrs
        for attr_name in attributes:
            LOGGER.info(f"templates {attr_name}:\n{attributes[attr_name]}")

    # in-place scaling, so the arrays keep the dtype read from the file
    lam *= lam_scale
    amp *= amp_scale

    templates = OrderedDict(amp=amp, lam=lam)
    templates.n_templates = amp.shape[0]
    return templates