Coverage for src/ufig/plugins/run_detection_classifier.py: 100%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-12 19:08 +0000

1# Copyright (C) 2023 ETH Zurich 

2# Institute for Particle Physics and Astrophysics 

3# Author: Silvan Fischbacher 

4 

from contextlib import ExitStack

import h5py
from cosmic_toolbox import arraytools as at
from cosmic_toolbox import file_utils, logger
from ivy.plugin.base_plugin import BasePlugin

9 

# Keyword arguments passed to file_utils.write_to_hdf when saving the detection
# catalogs. They look like h5py dataset-creation options (gzip level 9 plus the
# shuffle filter); presumably write_to_hdf forwards them to h5py -- confirm.
HDF5_COMPRESS = dict(compression="gzip", compression_opts=9, shuffle=True)

11 

# Module-level logger from cosmic_toolbox, keyed by this module's file path;
# used to report when the detection catalogs have been saved.
LOGGER = logger.get_logger(__file__)

13 

14 

class Plugin(BasePlugin):
    """
    Apply the detection classifier to per-band catalogs and save, for each
    band, the objects selected by the classifier.

    Parameters read from ``self.ctx.parameters``:

    - ``clf``: classifier exposing ``params`` (the input field names, in the
      order the classifier expects) and ``predict``.
    - ``emu_conf``: mapping with the lists ``"input_band_dep"`` (per-band
      features, fed to the classifier as ``<name>_<band>``) and
      ``"input_band_indep"`` (features read from a single reference band).
    - ``emu_filters``: bands whose catalogs are read and written.
    - ``filters``: its first entry is the reference band for band-independent
      features and for the ``x``/``y`` positions.
    - ``det_clf_catalog_name_dict``: band -> input HDF5 catalog path.
    - ``galaxy_catalog_name_dict``: band -> output catalog path.
    """

    def __call__(self):
        """Classify the input objects and write one detection catalog per band."""
        par = self.ctx.parameters

        # load classifier and config
        clf = par.clf
        conf = par.emu_conf
        clf_params = list(clf.params)

        # read classifier inputs and positions; all catalogs are closed on return
        X, x, y = self._load_inputs(par, conf)

        # check if parameters match: keep only the classifier's fields, in the
        # order given by clf.params. Fields not in clf.params are dropped here,
        # and the write step reads from this X, so every configured input must
        # also be a classifier parameter.
        X = at.dict2rec(X)
        X = X[clf_params]

        # run classifier; the result is used as a row index below,
        # presumably a boolean "detected" mask -- TODO confirm
        det = clf.predict(X)

        self._write_catalogs(par, conf, X, x, y, det)
        LOGGER.info("Saved detection catalogs")

    @staticmethod
    def _load_inputs(par, conf):
        """
        Read classifier features and x/y positions from the per-band catalogs.

        :param par: context parameters (see the class docstring)
        :param conf: emulator config with "input_band_dep"/"input_band_indep"
        :return: ``(X, x, y)``. ``X`` maps each per-band feature to its
            ``<name>_<band>`` array and each band-independent feature to its
            plain-name array. ``x``/``y`` are the reference-band positions.
        """
        X = {}
        # ExitStack closes every catalog opened so far even if opening a later
        # file or reading a field raises, so no HDF5 handle is leaked.
        with ExitStack() as stack:
            cats = {
                f: stack.enter_context(
                    h5py.File(par.det_clf_catalog_name_dict[f], "r")
                )
                for f in par.emu_filters
            }

            # Assumes "data" is a compound dataset: indexing it by field name
            # reads that column into memory, so the arrays stay valid after
            # the files are closed -- confirm against the catalog writer.
            for p in conf["input_band_dep"]:
                for f in par.emu_filters:
                    X[p + f"_{f}"] = cats[f]["data"][p]

            # NOTE(review): the reference band is par.filters[0], but only the
            # par.emu_filters catalogs are opened. This raises KeyError unless
            # par.filters[0] is one of them.
            ref = cats[par.filters[0]]["data"]
            for p in conf["input_band_indep"]:
                X[p] = ref[p]

            # Get positions
            x = ref["x"]
            y = ref["y"]
        return X, x, y

    @staticmethod
    def _write_catalogs(par, conf, X, x, y, det):
        """
        Write, for every band in ``par.emu_filters``, the rows selected by ``det``.

        Per-band features are stored under their plain name (without the band
        suffix). Band-independent features and x/y are copied into every band.
        """
        for f in par.emu_filters:
            cat = {}
            for p in conf["input_band_dep"]:
                cat[p] = X[p + f"_{f}"][det]
            for p in conf["input_band_indep"]:
                cat[p] = X[p][det]

            # positions for final catalog
            cat["x"] = x[det]
            cat["y"] = y[det]

            file_utils.write_to_hdf(
                par.galaxy_catalog_name_dict[f],
                at.dict2rec(cat),
                "data",
                **HDF5_COMPRESS,
            )

    def __str__(self):
        return "run detection classifier"