Coverage for src/galsbi/ucat/plugins/write_catalog.py: 100%

91 statements  

« prev     ^ index     » next       coverage.py v7.10.5, created at 2025-08-26 10:41 +0000

1# Copyright (C) 2019 ETH Zurich, Institute for Particle Physics and Astrophysics 

2 

3""" 

4Created on Aug 2021 

5author: Tomasz Kacprzak 

6""" 

7 

8import h5py 

9import numpy as np 

10from cosmic_toolbox import arraytools as at 

11from cosmic_toolbox import logger 

12from ivy.plugin.base_plugin import BasePlugin 

13 

# Module-level logger; used below for debug/warning messages while writing catalogs.
LOGGER = logger.get_logger(__file__)

15 

16 

def catalog_to_rec(catalog):
    """
    Convert a catalog object into a numpy structured (record) array.

    Every name listed in ``catalog.columns`` becomes one field of the output:

    - 1-d columns become scalar fields,
    - columns of shape (n_obj, 1) are flattened into scalar fields,
    - columns of shape (n_obj, k, ...) become sub-array fields of shape (k, ...).

    :param catalog: object exposing a ``columns`` iterable of attribute names;
        each attribute is expected to be a numpy array with the objects along
        axis 0
    :return: structured array with one entry per object (length 0 if the
        catalog has no columns)
    """
    fields = []
    arrays = []
    for col_name in catalog.columns:
        col = getattr(catalog, col_name)
        if len(col.shape) == 2 and col.shape[1] == 1:
            # Declare (n, 1) columns as scalar fields explicitly. The legacy
            # (name, dtype, 1) field spec is deprecated by numpy and may be
            # interpreted as a shape-(1,) sub-array, which a flattened column
            # cannot be assigned to.
            col = col.ravel()
        if len(col.shape) == 1:
            fields.append((col_name, col.dtype))
        else:
            fields.append((col_name, col.dtype, col.shape[1:]))
        arrays.append((col_name, col))

    # All columns share the same length; a catalog without columns yields an
    # empty array instead of failing on an undefined length.
    n_obj = len(arrays[0][1]) if arrays else 0
    rec = np.empty(n_obj, dtype=np.dtype(fields))
    for col_name, col in arrays:
        rec[col_name] = col
    return rec

42 

43 

def sed_catalog_to_rec(catalog):
    """
    Build the minimal structured array needed to save the galaxy SEDs.

    To save the SEDs, we don't need the full catalog, but only the ID, SED and
    redshift. The ID is necessary to link the SEDs to the galaxies in the
    photometric catalog. The redshift is necessary to adapt the wavelengths of
    the SEDs to the observed frame.

    Columns of shape (n_obj, 1) are stored as scalar fields; columns of shape
    (n_obj, k, ...) are stored as sub-array fields of shape (k, ...).

    :param catalog: object with ``id``, ``sed`` and ``z`` array attributes,
        objects along axis 0
    :return: structured array with fields ``id``, ``sed`` and ``z``
    """
    columns = ["id", "sed", "z"]
    fields = []
    arrays = []
    for col_name in columns:
        col = getattr(catalog, col_name)
        if len(col.shape) == 2 and col.shape[1] == 1:
            # Declare (n, 1) columns as scalar fields explicitly; the legacy
            # (name, dtype, 1) spec is deprecated by numpy and may be read as
            # a shape-(1,) sub-array that a flattened column cannot fill.
            col = col.ravel()
        if len(col.shape) == 1:
            fields.append((col_name, col.dtype))
        else:
            fields.append((col_name, col.dtype, col.shape[1:]))
        arrays.append((col_name, col))

    n_obj = len(arrays[0][1])
    rec = np.empty(n_obj, dtype=np.dtype(fields))
    for col_name, col in arrays:
        rec[col_name] = col
    return rec

70 

71 

def save_sed(filepath_out, cat, restframe_wavelength_in_A):
    """
    Write the SED catalog and its wavelength grid to an HDF5 file.

    Any existing file at ``filepath_out`` is overwritten. Two datasets are
    created:

    - ``data``: the SED catalog ``cat``
    - ``restframe_wavelength_in_A``: the rest-frame wavelengths (Angstrom)
      accompanying the SEDs

    The SEDs are stored in the rest frame; converting the wavelengths to the
    observed frame is left to the reader (when ``cat`` comes from
    ``sed_catalog_to_rec`` it carries the redshift ``z`` needed for that).
    """
    datasets = {
        "data": cat,
        "restframe_wavelength_in_A": restframe_wavelength_in_A,
    }
    with h5py.File(filepath_out, mode="w") as fh:
        for dataset_name, values in datasets.items():
            fh.create_dataset(dataset_name, data=values)

81 

82 

class Plugin(BasePlugin):
    """
    Write the galaxy and star catalogs found in the context to HDF5 files and,
    if the ``save_SEDs`` parameter is set, the galaxy SEDs as well.
    """

    def __call__(self):
        par = self.ctx.parameters

        if hasattr(self.ctx, "current_filter"):
            # A current filter is set during image generation. The SED output
            # does not reference the filter, so it is only written while the
            # reference band is processed (presumably to avoid rewriting the
            # same file once per band).
            save_seds_if_requested = self.ctx.current_filter == par.reference_band
        else:
            save_seds_if_requested = True

        # write catalogs
        if "galaxies" in self.ctx:
            cat = catalog_to_rec(self.ctx.galaxies)
            cat = self.enrich_catalog(cat)
            at.write_to_hdf(par.galaxy_catalog_name, cat)

        if "stars" in self.ctx:
            cat = catalog_to_rec(self.ctx.stars)
            at.write_to_hdf(par.star_catalog_name, cat)

        if par.save_SEDs and save_seds_if_requested:
            if "galaxies" in self.ctx:
                cat = sed_catalog_to_rec(self.ctx.galaxies)
                save_sed(
                    par.galaxy_sed_catalog_name,
                    cat,
                    self.ctx.restframe_wavelength_for_SED,
                )
            else:
                # Without galaxies there are no SEDs to save; skip instead of
                # failing on the missing ctx.galaxies.
                LOGGER.warning(
                    "save_SEDs is set but there are no galaxies in the context; "
                    "no SED catalog is written."
                )

    def enrich_catalog(self, cat):
        """
        Add derived columns to the galaxy catalog on a best-effort basis.

        Each column is attempted independently; when its inputs are missing,
        the column is skipped with a debug message instead of aborting:

        - ``e_abs``: ellipticity modulus ``sqrt(e1**2 + e2**2)``
        - ``bkg_noise_amp``: the ``bkg_noise_amp`` parameter repeated per object
        - ``bkg_noise_std``: only for catalogs without ``ra`` and ``dec``
          columns; if the parameter has a ``shape`` it is indexed at each
          object's integer (y, x) position, otherwise its value is used as is

        :param cat: galaxy catalog as a numpy structured array
        :return: the catalog with the extra columns, or ``cat`` unchanged when
            the ``enrich_catalog`` parameter is False
        """
        par = self.ctx.parameters
        if par.enrich_catalog is False:
            LOGGER.debug("Enriching catalog is disabled.")
            return cat
        try:
            cat = at.add_cols(
                cat, ["e_abs"], data=np.sqrt(cat["e1"] ** 2 + cat["e2"] ** 2)
            )
        except (ValueError, KeyError) as e:
            LOGGER.debug(f"e_abs could not be calculated: {e}")
        # add noise levels
        try:
            cat = at.add_cols(
                cat, ["bkg_noise_amp"], data=np.ones(len(cat)) * par.bkg_noise_amp
            )
        except AttributeError as e:
            LOGGER.debug(f"bkg_noise_amp could not be calculated: {e}")

        try:
            if "ra" not in cat.dtype.names and "dec" not in cat.dtype.names:
                # NOTE(review): dtype=int truncates toward zero; negative
                # positions wrap around and positions beyond the array raise
                # IndexError, which is not caught below -- confirm that object
                # positions always lie inside the image.
                y = np.array(cat["y"], dtype=int)
                x = np.array(cat["x"], dtype=int)
                if hasattr(par.bkg_noise_std, "shape"):
                    cat = at.add_cols(
                        cat, ["bkg_noise_std"], data=par.bkg_noise_std[y, x]
                    )
                else:
                    cat = at.add_cols(cat, ["bkg_noise_std"], data=par.bkg_noise_std)
        except (ValueError, KeyError, AttributeError) as e:
            LOGGER.debug(f"bkg_noise_std could not be calculated: {e}")
        return cat

    def __str__(self):
        return "write ucat catalog to file"