Coverage for src/ufig/plugins/run_emulator.py: 100%

64 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-12 19:08 +0000

1# Copyright (C) 2023 ETH Zurich 

2# Institute for Particle Physics and Astrophysics 

3# Author: Silvan Fischbacher 

4 

5import h5py 

6from cosmic_toolbox import arraytools as at 

7from cosmic_toolbox import logger 

8from ivy.plugin.base_plugin import BasePlugin 

9 

# Keyword arguments forwarded as ``compression=`` to ``at.save_hdf_cols``
# when writing the output catalogs: gzip at its maximum level (9) with the
# shuffle filter enabled.
HDF5_COMPRESS = dict(compression="gzip", compression_opts=9, shuffle=True)

# Module-level logger, keyed on this file's path.
LOGGER = logger.get_logger(__file__)

13 

14 

class Plugin(BasePlugin):
    """Run the full emulator: a detection classifier followed by a flow.

    Workflow:

    - For every filter ``f`` in ``par.emu_filters``, read the ``"data"``
      dataset of the HDF5 file ``par.det_clf_catalog_name_dict[f]``.
    - Build the classifier input from the columns listed in
      ``par.emu_conf``.
    - Run ``par.clf.predict`` on that input.
    - Sample the selected rows with ``par.nflow.sample``.
    - Write one catalog per filter to
      ``par.sextractor_forced_photo_catalog_name_dict[f]``.

    ``par.clf`` and ``par.nflow`` are deleted from the parameters once all
    catalogs are written.
    """

    def __call__(self):
        LOGGER.info("Running the full emulator")
        par = self.ctx.parameters

        # load classifier and config
        conf = par.emu_conf
        clf_params = list(par.clf.params)

        X, x, y, ids, radec = self._read_inputs(par, conf)
        # Reorder the fields to match the order expected by the classifier
        X = X[clf_params]

        LOGGER.debug("Running the classifier")
        # presumably a boolean mask over the rows of X — TODO confirm
        det = par.clf.predict(X)

        LOGGER.debug("Running the normalizing flow")
        sexcat = par.nflow.sample(X[det])
        # row mask applied to every written catalog; presumably keeps rows
        # whose sampled values are all finite — confirm in cosmic_toolbox
        select = at.get_finite_mask(sexcat)

        # Positions and ids of the selected objects are the same for every
        # filter, so apply the mask once instead of once per filter/column.
        x_det = x[det]
        y_det = y[det]
        ids_det = ids[det]

        LOGGER.debug("Writing the catalogs")
        for f in par.emu_filters:
            cat = self._build_catalog(conf, sexcat, f, x_det, y_det, ids_det, radec)
            filepath = par.sextractor_forced_photo_catalog_name_dict[f]
            at.save_hdf_cols(filepath, cat[select], compression=HDF5_COMPRESS)

        # Drop the models from the shared parameters once they are no longer
        # needed by this plugin.
        del self.ctx.parameters.clf
        del self.ctx.parameters.nflow

    @staticmethod
    def _read_inputs(par, conf):
        """Read the classifier input and pass-through columns from disk.

        :param par: pipeline parameters; uses ``emu_filters`` and
            ``det_clf_catalog_name_dict``.
        :param conf: emulator config; uses the ``"input_band_dep"`` and
            ``"input_band_indep"`` column lists.
        :return: tuple ``(X, x, y, ids, radec)``. ``X`` is the classifier
            input record array (band-dependent columns named
            ``<column>_<filter>``). ``x``/``y`` are the positions (``ra``/
            ``dec`` when the first filter's catalog has no ``x`` column).
            ``ids`` is the ``id`` column. ``radec`` tells which kind of
            positions were read.
        """
        cats = {}
        try:
            for f in par.emu_filters:
                cats[f] = h5py.File(par.det_clf_catalog_name_dict[f], "r")

            X = {}
            for p in conf["input_band_dep"]:
                for f in par.emu_filters:
                    X[f"{p}_{f}"] = cats[f]["data"][p]

            # Band-independent columns, positions and ids are taken from the
            # catalog of the first filter.
            ref = cats[par.emu_filters[0]]["data"]
            for p in conf["input_band_indep"]:
                X[p] = ref[p]

            # Get ucat params that are not part of the emulator
            radec = "x" not in ref.dtype.names
            if radec:
                x = ref["ra"]
                y = ref["dec"]
            else:
                x = ref["x"]
                y = ref["y"]
            ids = ref["id"]
        finally:
            # Close every file that was opened, even if reading failed part-way.
            for cat_file in cats.values():
                cat_file.close()

        return at.dict2rec(X), x, y, ids, radec

    @staticmethod
    def _build_catalog(conf, sexcat, f, x_det, y_det, ids_det, radec):
        """Assemble the output record array for filter ``f``.

        :param conf: emulator config; uses ``"input_band_dep"``,
            ``"input_band_indep"`` and ``"output"``.
        :param sexcat: flow samples; band-dependent columns are named
            ``<column>_<filter>``.
        :param f: filter name.
        :param x_det: positions of the rows selected by the classifier.
        :param y_det: positions of the rows selected by the classifier.
        :param ids_det: ids of the rows selected by the classifier.
        :param radec: whether ``x_det``/``y_det`` are ``ra``/``dec``.
        :return: record array built with ``at.dict2rec``.
        """
        cat = {}
        for p in conf["input_band_dep"]:
            cat[p] = sexcat[f"{p}_{f}"]
        for p in conf["input_band_indep"]:
            cat[p] = sexcat[p]
        for p in conf["output"]:
            cat[p] = sexcat[f"{p}_{f}"]

        # positions for final catalog
        if radec:
            cat["ra"] = x_det
            cat["dec"] = y_det
        else:
            cat["x"] = x_det
            cat["y"] = y_det

        # NOTE(review): the image-coordinate columns get the same values as
        # the positions above, i.e. ra/dec when the input has no x/y columns —
        # confirm this is intended.
        cat["XWIN_IMAGE"] = x_det
        cat["YWIN_IMAGE"] = y_det
        cat["X_IMAGE"] = x_det
        cat["Y_IMAGE"] = y_det

        cat["id"] = ids_det

        return at.dict2rec(cat)

    def __str__(self):
        return "run the full emulator"