Coverage for src / cosmic_toolbox / file_utils.py: 88%

132 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-31 12:38 +0000

1# Copyright (C) 2017 ETH Zurich 

2# Cosmology Research Group 

3# Author: Joerg Herbel 

4 

5""" 

6File utilities for reading, writing, and copying files. 

7 

8Provides functions for: 

9- Reading/writing pickle and HDF5 files with compression 

10- Robust file/directory operations (makedirs, remove, copy) 

11- Remote file operations via SSH/rsync 

12- YAML file handling 

13 

14""" 

15 

16import os 

17import pickle 

18import shlex 

19import shutil 

20import stat 

21import subprocess 

22 

23import copy_guardian 

24import h5py 

25import yaml 

26 

27from cosmic_toolbox.logger import get_logger 

28 

29DEFAULT_ROOT_PATH = "" 

30LOGGER = get_logger(__file__) 

31 

32 

def robust_remove(path):
    """
    Remove a file or directory, locally or on a remote host.

    Remote paths (user@host:/path) are removed via ``ssh host "rm -rf path"``;
    local files use os.remove and local directories shutil.rmtree. A missing
    local path only logs a warning instead of raising.

    :param path: Path to the file or directory (local or user@host:/path).
    """
    if is_remote(path):
        LOGGER.info(f"Removing remote directory {path}")
        # Split only on the first colon so a remote path containing ":" does
        # not break the unpacking.
        host, path = path.split(":", 1)
        cmd = f'ssh {host} "rm -rf {path}"'
        subprocess.call(shlex.split(cmd))
    elif os.path.isfile(path):
        os.remove(path)
    elif os.path.isdir(path):
        shutil.rmtree(path)
    else:
        LOGGER.warning(f"Cannot remove {path} because it does not exist")

54 

55 

def write_to_pickle(filepath, obj, compression="none"):
    """
    Write an object to a pickle file, optionally compressed.

    :param filepath: Path to the pickle file.
    :param obj: Object to write.
    :param compression: Compression method, one of "none", "lzf" or "bz2"
        (case-insensitive).
    :raises ValueError: If the compression method is unknown.
    """
    # Normalize once instead of calling .lower() per branch.
    compression = compression.lower()
    if compression == "none":
        with open(filepath, "wb") as f:
            pickle.dump(obj, f)
    elif compression == "lzf":
        # NOTE(review): the python-lzf package exposes compress/decompress
        # only, not open() — confirm which "lzf" module is intended here.
        import lzf

        with lzf.open(filepath, "wb") as f:
            pickle.dump(obj, f)
    elif compression == "bz2":
        import bz2

        with bz2.open(filepath, "wb") as f:
            pickle.dump(obj, f)
    else:
        raise ValueError(f"unknown compression {compression} [none, lzf, bz2]")

80 

81 

def read_from_pickle(filepath, compression="none"):
    """
    Read an object from a pickle file, optionally compressed.

    :param filepath: Path to the pickle file.
    :param compression: Compression method, one of "none", "lzf" or "bz2"
        (case-insensitive).
    :return: Object read from the pickle file.
    :raises ValueError: If the compression method is unknown.
    """
    # Normalize once instead of calling .lower() per branch.
    compression = compression.lower()
    if compression == "none":
        with open(filepath, "rb") as f:
            obj = pickle.load(f)
    elif compression == "lzf":
        # NOTE(review): the python-lzf package exposes compress/decompress
        # only, not open() — confirm which "lzf" module is intended here.
        import lzf

        with lzf.open(filepath, "rb") as f:
            obj = pickle.load(f)
    elif compression == "bz2":
        import bz2

        with bz2.open(filepath, "rb") as f:
            obj = pickle.load(f)
    else:
        raise ValueError(f"unknown compression {compression} [none, lzf, bz2]")

    return obj

108 

109 

def write_to_hdf(filepath, obj, name="data", **kwargs):
    """
    Write an object to an HDF5 file as a single dataset.

    :param filepath: Path to the hdf5 file (overwritten if it exists).
    :param obj: Object to write.
    :param name: Name of the dataset.
    :param kwargs: Additional arguments passed to h5py.File.create_dataset.
    """
    hdf_file = h5py.File(filepath, "w")
    try:
        hdf_file.create_dataset(name, data=obj, **kwargs)
    finally:
        hdf_file.close()

122 

123 

def read_from_hdf(filepath, name="data"):
    """
    Read a dataset from an HDF5 file.

    :param filepath: Path to the hdf5 file.
    :param name: Name of the dataset.
    :return: Full contents of the dataset.
    """
    with h5py.File(filepath, "r") as hdf_file:
        # [:] materializes the whole dataset into memory before the file closes.
        return hdf_file[name][:]

137 

138 

def load_from_hdf5(file_name, hdf5_keys, hdf5_path=""):
    """
    Load data stored in a HDF5 file.

    :param file_name: Name of the file (resolved via get_abs_path).
    :param hdf5_keys: Key or list of keys of the arrays to be loaded.
    :param hdf5_path: Path within the HDF5 file prepended to all keys.
    :return: A single array if hdf5_keys was a single string, otherwise a
        list of arrays in the same order as hdf5_keys.
    """
    # Accept a bare string as a convenience for loading a single dataset;
    # isinstance is the robust form of the previous str(x) == x trick.
    single_key = isinstance(hdf5_keys, str)
    if single_key:
        hdf5_keys = [hdf5_keys]

    hdf5_keys = [hdf5_path + hdf5_key for hdf5_key in hdf5_keys]

    path = get_abs_path(file_name)

    with h5py.File(path, mode="r") as hdf5_file:
        hdf5_data = [hdf5_file[hdf5_key][...] for hdf5_key in hdf5_keys]

    # Unwrap the list when the caller asked for a single key.
    return hdf5_data[0] if single_key else hdf5_data

165 

166 

def get_abs_path(path):
    """
    Get the absolute path, handling remote paths and environment variables.

    Remote paths (user@host:/path) and already-absolute paths are returned
    unchanged. Relative paths are joined to $SUBMIT_DIR if set, otherwise to
    the current working directory.

    :param path: relative or absolute path (can be a Path object)
    :return: absolute path string
    """
    path = str(path)  # convert to string if it's a Path object
    # Parenthesized for clarity: remote-looking paths OR absolute paths
    # are already fully resolved.
    if ("@" in path and ":/" in path) or os.path.isabs(path):
        return path
    parent = os.environ.get("SUBMIT_DIR", os.getcwd())
    return os.path.join(parent, path)

184 

185 

def robust_makedirs(path):
    """
    Create a directory (including parents), handling remote paths via SSH.

    :param path: path to create (can be remote in user@host:/path format)
    """
    if is_remote(path):
        LOGGER.info(f"Creating remote directory {path}")
        # Split only on the first colon so a remote path containing ":" does
        # not break the unpacking.
        host, path = path.split(":", 1)
        cmd = f'ssh {host} "mkdir -p {path}"'
        subprocess.call(shlex.split(cmd))
    elif not os.path.isdir(path):
        os.makedirs(path)
        LOGGER.info(f"Created directory {path}")

201 

202 

def robust_copy(
    src,
    dst,
    n_max_connect=50,
    method="CopyGuardian",
    folder_with_many_files=False,
    **kwargs,
):
    """
    Copy files/directories using the specified method.

    :param src: Source file/directory or list thereof.
    :param dst: Destination file/directory. A one-element list/tuple is
        unwrapped; multiple destinations are not supported.
    :param n_max_connect: Maximum number of simultaneous connections.
    :param method: Method to use for copying. Can be "CopyGuardian" or
        "system_cp".
    :param folder_with_many_files: If True, the source is a folder with many
        files (only for CopyGuardian).
    :param kwargs: Additional arguments passed to the copy method.
    :raises ValueError: If multiple destinations are given.
    :raises Exception: If the copy method is unknown.
    """
    src = _ensure_list(src)
    if isinstance(dst, (list, tuple)):
        if len(dst) > 1:
            raise ValueError(
                f"Destination {dst} not supported. Multiple destinations for "
                "multiple sources not implemented."
            )
        # Unwrap a single-element destination so the path operations below
        # operate on a string rather than a list.
        dst = dst[0]

    # In case of a remote destination, rsync will create the directory itself
    if not is_remote(dst):
        parent = os.path.dirname(dst)
        # dirname is "" for a bare filename; os.makedirs("") would raise,
        # and there is nothing to create in that case.
        if parent:
            robust_makedirs(parent)

    if method == "CopyGuardian":
        copy_with_copy_guardian(
            src,
            dst,
            n_max_connect=n_max_connect,
            folder_with_many_files=folder_with_many_files,
        )
    elif method == "system_cp":
        system_copy(sources=src, dest=dst, **kwargs)
    else:
        raise Exception(f"Unknown copy method {method}")

246 

247 

248def _ensure_list(obj): 

249 """ 

250 Ensure an object is a list. 

251 

252 :param obj: object to convert 

253 :return: list containing obj, or obj if already a list 

254 """ 

255 if not isinstance(obj, list): 

256 return [obj] 

257 return obj 

258 

259 

def copy_with_copy_guardian(
    sources,
    destination,
    n_max_connect=10,
    timeout=1000,
    folder_with_many_files=False,
):
    """
    Copy files/directories locally, throttled by a CopyGuardian semaphore.

    All sources are copied while holding a copy_guardian.BoundedSemaphore,
    limiting the number of simultaneous copy operations across processes.

    :param sources: List of source files/directories.
    :param destination: Destination file or directory.
    :param n_max_connect: Maximum number of simultaneous connections.
    :param timeout: time in seconds to wait for a connection to become available
    :param folder_with_many_files: If True, directory sources are copied with
        copy_guardian's copy_local_folder instead of shutil.copytree.
    """
    with copy_guardian.BoundedSemaphore(n_max_connect, timeout=timeout):
        for src in sources:
            LOGGER.info(f"Copying locally: {src} -> {destination}")
            if os.path.isdir(src):
                if folder_with_many_files:
                    LOGGER.debug("Copying folder with many files")
                    # NOTE(review): assumes copy_local_folder creates the
                    # destination itself — confirm against copy_guardian docs.
                    copy_guardian.copy_utils.copy_local_folder(src, destination)
                else:
                    # Merge into an existing destination tree if present.
                    shutil.copytree(src, destination, dirs_exist_ok=True)
            elif os.path.isdir(destination):
                # File into existing directory: copy() preserves permissions
                # and places the file inside the directory.
                shutil.copy(src, destination)
            else:
                # File to file path: copyfile writes the destination path as-is.
                shutil.copyfile(src, destination)

290 

291 

def system_copy(sources, dest, args_str_cp=""):
    """
    Copy files recursively using the system ``cp`` command.

    NOTE(review): the command is assembled by string concatenation and run
    through os.system, so paths containing spaces or shell metacharacters
    will break or be interpreted by the shell — avoid with untrusted paths.

    :param sources: list of source paths
    :param dest: destination path
    :param args_str_cp: additional arguments for the cp command
    """
    # Build exactly the same command string as before: flags, then each
    # source padded with spaces, then the destination.
    source_parts = "".join(f" {src} " for src in sources)
    cmd = "cp -r " + args_str_cp + source_parts + f" {dest}"
    LOGGER.debug(f"Copying {len(sources)} files to {dest}")
    LOGGER.debug(cmd)
    os.system(cmd)

307 

308 

def is_remote(path):
    """
    Check whether a path points to a remote location (user@host:/path format).

    :param path: path to check (can be a Path object)
    :return: True if the path looks remote, False otherwise
    """
    path_str = str(path)  # tolerate Path objects
    # A remote path must contain both a user@host marker and a ":/" separator.
    if "@" not in path_str:
        return False
    return ":/" in path_str

318 

319 

def read_yaml(filename):
    """
    Read and parse a YAML file.

    NOTE(review): yaml.Loader can construct arbitrary Python objects; do not
    feed this function untrusted input (yaml.safe_load would be safer).

    :param filename: path to YAML file
    :return: parsed YAML content
    """
    with open(filename) as stream:
        content = yaml.load(stream, Loader=yaml.Loader)
    return content

330 

331 

def ensure_permissions(path, verb=False):
    """
    Set permissions on a path to rwxr-xr-x (owner rwx, group r-x, others r-x).

    :param path: path to file or directory
    :param verb: if True, log the permission change at debug level
    """
    mode = (
        stat.S_IRWXU
        | stat.S_IRGRP
        | stat.S_IXGRP
        | stat.S_IROTH
        | stat.S_IXOTH
    )
    os.chmod(path, mode)
    if verb:
        LOGGER.debug(f"Changed permissions for {path} to {oct(mode)}")