Coverage for src/galsbi/ucat/plugins/write_catalog.py: 88%

40 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-13 03:24 +0000

1# Copyright (C) 2019 ETH Zurich, Institute for Particle Physics and Astrophysics 

2 

3""" 

4Created on Aug 2021 

5author: Tomasz Kacprzak 

6""" 

7 

8import os 

9 

10import numpy as np 

11from cosmic_toolbox import arraytools as at 

12from cosmic_toolbox import logger 

13from ivy.plugin.base_plugin import BasePlugin 

14 

15LOGGER = logger.get_logger(__file__) 

16 

17 

def get_ucat_catalog_filename(catalog_name):
    """
    Derive the ucat catalog file name from a ufig catalog file name.

    :param catalog_name: catalog file name, every occurrence of "ufig" in it
        is substituted
    :return: the name with each "ufig" replaced by "ucat"
    """
    ufig_tag, ucat_tag = "ufig", "ucat"
    ucat_name = catalog_name.replace(ufig_tag, ucat_tag)
    return ucat_name

20 

21 

def catalog_to_rec(catalog):
    """
    Convert a catalog object into a numpy structured array.

    Each name in ``catalog.columns`` is looked up as an attribute of
    ``catalog`` and becomes one field of the output array. Columns whose
    trailing dimensions all have length one (e.g. shape ``(n, 1)``) are
    stored as scalar fields; any other multi-dimensional column is stored as
    a sub-array field with the column's trailing shape.

    :param catalog: object exposing ``columns`` (iterable of column names) and
        one array attribute per column name (``shape``, ``dtype`` and
        ``ravel`` are used)
    :return: structured array with one field per column; an array of length
        zero without fields if the catalog has no columns
    """
    # fetch every column exactly once, keeping the order of catalog.columns
    columns = [(col_name, getattr(catalog, col_name)) for col_name in catalog.columns]

    # the number of rows is taken from the last column (as before); a catalog
    # without columns yields zero rows instead of an UnboundLocalError
    n_obj = len(columns[-1][1]) if columns else 0

    # get dtype first
    dtype_list = []
    to_flatten = set()
    for col_name, col in columns:
        field_shape = tuple(col.shape[1:])
        if all(dim == 1 for dim in field_shape):
            # scalar field, spelled out explicitly rather than via the
            # deprecated (dtype, 1) shorthand for a scalar
            dtype_list.append((col_name, col.dtype))
            if field_shape:
                to_flatten.add(col_name)
        else:
            # full trailing shape so that columns with more than two
            # dimensions fit into their field
            dtype_list.append((col_name, col.dtype, field_shape))

    # create empty array
    rec = np.empty(n_obj, dtype=np.dtype(dtype_list))

    # copy columns to array
    for col_name, col in columns:
        # (n, 1)-like columns are flattened to match their scalar field
        rec[col_name] = col.ravel() if col_name in to_flatten else col

    return rec

47 

48 

class Plugin(BasePlugin):
    """
    Write the catalogs stored in the context to file.

    For each of the context entries "galaxies" and "stars" that is present,
    the catalog is converted with catalog_to_rec and written with
    at.write_to_hdf into the directory par.filepath_tile. The file names are
    derived from par.galaxy_catalog_name and par.star_catalog_name via
    get_ucat_catalog_filename.
    """

    def __call__(self):
        par = self.ctx.parameters

        # make output dirs if needed; exist_ok avoids a FileExistsError when
        # another process creates the directory between a check and the
        # creation
        os.makedirs(par.filepath_tile, exist_ok=True)

        # write catalogs: (context entry, parameter holding the file name);
        # the parameter is only read if the corresponding catalog is present
        catalogs = (
            ("galaxies", "galaxy_catalog_name"),
            ("stars", "star_catalog_name"),
        )
        for ctx_key, name_par in catalogs:
            if ctx_key not in self.ctx:
                continue

            filepath_out = os.path.join(
                par.filepath_tile,
                get_ucat_catalog_filename(getattr(par, name_par)),
            )
            cat = catalog_to_rec(getattr(self.ctx, ctx_key))
            at.write_to_hdf(filepath_out, cat)

    def __str__(self):
        return "write ucat catalog to file"