Coverage for src/ufig/plugins/run_nflow.py: 92%

38 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-12 19:08 +0000

1# Copyright (C) 2023 ETH Zurich 

2# Institute for Particle Physics and Astrophysics 

3# Author: Silvan Fischbacher 

4 

5import contextlib 

6 

7import h5py 

8from cosmic_toolbox import arraytools as at 

9from cosmic_toolbox import logger 

10from ivy.plugin.base_plugin import BasePlugin 

11 

12from ufig.plugins.match_sextractor_catalog_multiband_read import NO_MATCH_VAL 

13from ufig.plugins.write_catalog_for_emu import enrich_star_catalog 

14 

# Compression settings handed to at.save_hdf_cols via its `compression`
# argument when the sampled catalog is written (gzip, level 9, with shuffle).
HDF5_COMPRESS = {"compression": "gzip", "compression_opts": 9, "shuffle": True}

# Module-level logger obtained from cosmic_toolbox.
LOGGER = logger.get_logger(__file__)

18 

19 

class Plugin(BasePlugin):
    """
    Produce the forced-photometry catalog of the current filter by sampling
    the normalizing flow configured for that filter.

    The plugin reads the galaxy catalog of the current filter, lets the flow
    sample a SExtractor-like catalog from it, attaches the input positions as
    if they had been measured by SExtractor and writes the result to disk.
    """

    def __call__(self):
        """
        Run the normalizing flow for ``self.ctx.current_filter``.

        Reads ``par.galaxy_catalog_name_dict[<filter>]`` (dataset ``"data"``)
        and writes ``par.sextractor_forced_photo_catalog_name_dict[<filter>]``.
        """
        par = self.ctx.parameters

        # deliberately not called `filter`, which would shadow the builtin
        band = self.ctx.current_filter

        # get the normalizing flow of this filter
        nflow = par.nflow[band]

        # load catalog
        with h5py.File(par.galaxy_catalog_name_dict[band], "r") as fh5:
            cat = fh5["data"][:]

        # x and y are set aside and removed from the catalog before sampling
        x = cat["x"]
        y = cat["y"]
        cat = at.delete_columns(cat, ["x", "y"])

        sexcat = nflow.sample(cat)

        # add x and y and pretend they are measured by sextractor
        sexcat = at.add_cols(sexcat, ["XWIN_IMAGE", "YWIN_IMAGE", "X_IMAGE", "Y_IMAGE"])
        sexcat["XWIN_IMAGE"] = x
        sexcat["YWIN_IMAGE"] = y
        sexcat["X_IMAGE"] = x
        sexcat["Y_IMAGE"] = y

        # add NO_MATCH_VAL to stars (rows with galaxy_type == -1)
        # TODO: only the redshift is overwritten for now. Overwriting the
        # further columns listed by enrich_star_catalog(cat=None, par=par) is
        # intentionally switched off, so the list is not computed here.
        stars = sexcat["galaxy_type"] == -1
        sexcat["z"][stars] = NO_MATCH_VAL

        # write catalog
        filepath = par.sextractor_forced_photo_catalog_name_dict[band]
        at.save_hdf_cols(filepath, sexcat, compression=HDF5_COMPRESS)

    def __str__(self):
        """Return the short description of this plugin."""
        return "run the normalizing flow"