Coverage for src / cosmic_toolbox / file_utils.py: 88%
132 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-31 12:38 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-31 12:38 +0000
1# Copyright (C) 2017 ETH Zurich
2# Cosmology Research Group
3# Author: Joerg Herbel
5"""
6File utilities for reading, writing, and copying files.
8Provides functions for:
9- Reading/writing pickle and HDF5 files with compression
10- Robust file/directory operations (makedirs, remove, copy)
11- Remote file operations via SSH/rsync
12- YAML file handling
14"""
16import os
17import pickle
18import shlex
19import shutil
20import stat
21import subprocess
23import copy_guardian
24import h5py
25import yaml
27from cosmic_toolbox.logger import get_logger
# Logger shared by every helper in this module, obtained from the project's
# logging helper and keyed on this file's path.
LOGGER = get_logger(__file__)

# Default root path (empty string).
# NOTE(review): not referenced in the visible part of this module; presumably
# used by importers — confirm before changing.
DEFAULT_ROOT_PATH = ""
def robust_remove(path):
    """
    Remove a file or directory, locally or on a remote host.

    Remote paths (``user@host:/path``, as detected by ``is_remote``) are
    deleted by running ``rm -rf`` over ``ssh``. The remote path is
    shell-quoted, so it is treated literally (no glob expansion). Local files
    are removed with ``os.remove`` and local directories recursively with
    ``shutil.rmtree``. A missing local path only logs a warning.

    :param path: Path to the file or directory (str or Path object).
    """
    if is_remote(path):
        LOGGER.info(f"Removing remote directory {path}")
        # Split only on the first ":" so a remote path that itself contains a
        # colon does not break the unpacking.
        host, remote_path = str(path).split(":", 1)
        # The command string is interpreted by the remote shell. Quote the path
        # so spaces or shell metacharacters cannot turn one target into
        # several arguments to ``rm -rf``.
        subprocess.call(["ssh", host, f"rm -rf {shlex.quote(remote_path)}"])

    elif os.path.isfile(path):
        os.remove(path)

    elif os.path.isdir(path):
        shutil.rmtree(path)

    else:
        LOGGER.warning(f"Cannot remove {path} because it does not exist")
def write_to_pickle(filepath, obj, compression="none"):
    """
    Write an object to a pickle file, optionally compressed.

    :param filepath: Path to the pickle file.
    :param obj: Object to write.
    :param compression: Compression method to use (case-insensitive). Can be
        "none", "lzf" or "bz2".
    :raises ValueError: if ``compression`` is not a supported method.
    """
    method = compression.lower()
    if method == "none":
        with open(filepath, "wb") as f:
            pickle.dump(obj, f)
    elif method == "lzf":
        # Imported lazily: only required when this method is requested.
        import lzf

        with lzf.open(filepath, "wb") as f:
            pickle.dump(obj, f)
    elif method == "bz2":
        import bz2

        with bz2.open(filepath, "wb") as f:
            pickle.dump(obj, f)
    else:
        # ValueError subclasses Exception, so existing handlers still match.
        raise ValueError(f"unknown compression {compression} [none, lzf, bz2]")
def read_from_pickle(filepath, compression="none"):
    """
    Read an object from a (possibly compressed) pickle file.

    Security note: unpickling can execute arbitrary code. Only read files
    from trusted sources.

    :param filepath: Path to the pickle file.
    :param compression: Compression method to use (case-insensitive). Can be
        "none", "lzf" or "bz2".
    :return: Object read from the pickle file.
    :raises ValueError: if ``compression`` is not a supported method.
    """
    method = compression.lower()
    if method == "none":
        with open(filepath, "rb") as f:
            obj = pickle.load(f)
    elif method == "lzf":
        # Imported lazily: only required when this method is requested.
        import lzf

        with lzf.open(filepath, "rb") as f:
            obj = pickle.load(f)
    elif method == "bz2":
        import bz2

        with bz2.open(filepath, "rb") as f:
            obj = pickle.load(f)
    else:
        # ValueError subclasses Exception, so existing handlers still match.
        raise ValueError(f"unknown compression {compression} [none, lzf, bz2]")

    return obj
def write_to_hdf(filepath, obj, name="data", **kwargs):
    """
    Store ``obj`` as a single dataset in an HDF5 file.

    The file is opened in h5py mode ``"w"`` (create, truncating any existing
    file at ``filepath``).

    :param filepath: Path to the hdf5 file.
    :param obj: Object to write (passed as ``data`` to ``create_dataset``).
    :param name: Name of the dataset.
    :param kwargs: Additional arguments passed to h5py.File.create_dataset.
    """
    with h5py.File(filepath, mode="w") as h5_out:
        h5_out.create_dataset(name, data=obj, **kwargs)
def read_from_hdf(filepath, name="data"):
    """
    Load one dataset from an HDF5 file.

    The dataset is sliced with ``[:]`` while the file is still open, so the
    returned value does not depend on the (closed) file handle.

    :param filepath: Path to the hdf5 file.
    :param name: Name of the dataset.
    :return: Object read from the hdf5 file.
    """
    with h5py.File(filepath, mode="r") as h5_in:
        return h5_in[name][:]
def load_from_hdf5(file_name, hdf5_keys, hdf5_path=""):
    """
    Load one or several datasets stored in an HDF5 file.

    :param file_name: Name of the file; relative names are resolved with
        ``get_abs_path``.
    :param hdf5_keys: A single key (str) or an iterable of keys to load.
    :param hdf5_path: Path within the HDF5 file that is prepended to every key.
    :return: The loaded array if ``hdf5_keys`` is a single string, otherwise a
        list of arrays in the same order as ``hdf5_keys``.
    """
    single_key = isinstance(hdf5_keys, str)
    if single_key:
        hdf5_keys = [hdf5_keys]

    full_keys = [hdf5_path + hdf5_key for hdf5_key in hdf5_keys]

    path = get_abs_path(file_name)

    with h5py.File(path, mode="r") as hdf5_file:
        # Read every dataset with [...] while the file is still open.
        hdf5_data = [hdf5_file[key][...] for key in full_keys]

    return hdf5_data[0] if single_key else hdf5_data
def get_abs_path(path):
    """
    Resolve ``path`` to an absolute path string.

    Remote-looking paths (containing both "@" and ":/") and absolute paths are
    returned unchanged. Relative paths are joined onto ``$SUBMIT_DIR`` when
    that environment variable is set, otherwise onto the current working
    directory.

    :param path: relative or absolute path (can be Path object)
    :return: absolute path string
    """
    path_str = str(path)  # Path objects are accepted as well

    looks_remote = "@" in path_str and ":/" in path_str
    if looks_remote or os.path.isabs(path_str):
        return path_str

    if "SUBMIT_DIR" in os.environ:
        base_dir = os.environ["SUBMIT_DIR"]
    else:
        base_dir = os.getcwd()
    return os.path.join(base_dir, path_str)
def robust_makedirs(path):
    """
    Create a directory (including missing parents), locally or remotely.

    Remote paths (``user@host:/path``, as detected by ``is_remote``) are
    created by running ``mkdir -p`` over ``ssh``, with the remote path
    shell-quoted. Locally, an existing directory is left untouched. An empty
    path (for example ``os.path.dirname("file.txt")``) is a no-op because it
    denotes the current working directory.

    :param path: path to create (can be remote with user@host:path format)
    """
    if is_remote(path):
        LOGGER.info(f"Creating remote directory {path}")
        # Split only on the first ":" so remote paths containing a colon work.
        host, remote_path = str(path).split(":", 1)
        # Quote the path: it is interpreted by the remote shell.
        subprocess.call(["ssh", host, f"mkdir -p {shlex.quote(remote_path)}"])
        return

    if not os.fspath(path):
        # os.makedirs("") raises FileNotFoundError. The empty path is the cwd,
        # which already exists.
        return

    if not os.path.isdir(path):
        # exist_ok avoids a crash if another process creates the directory
        # between the isdir() check and this call. A non-directory at ``path``
        # still raises FileExistsError.
        os.makedirs(path, exist_ok=True)
        LOGGER.info(f"Created directory {path}")
def robust_copy(
    src,
    dst,
    n_max_connect=50,
    method="CopyGuardian",
    folder_with_many_files=False,
    **kwargs,
):
    """
    Copy files/directories using the specified method.

    The parent directory of a local destination is created first. For a
    remote destination this step is skipped (rsync creates it).

    :param src: Source file/directory, or a list of them.
    :param dst: Destination file/directory. A list/tuple with exactly one
        element is accepted and unwrapped.
    :param n_max_connect: Maximum number of simultaneous connections.
    :param method: Method to use for copying. Can be "CopyGuardian" or "system_cp".
    :param folder_with_many_files: If True, the source is a folder with many files
        (only for CopyGuardian).
    :param kwargs: Additional arguments passed to the copy method (currently
        only forwarded for "system_cp").
    :raises ValueError: for multiple/empty destinations or an unknown method.
    """
    src = _ensure_list(src)

    if isinstance(dst, (list, tuple)):
        if len(dst) > 1:
            raise ValueError(
                f"Destination {dst} not supported. Multiple destinations for "
                "multiple sources not implemented."
            )
        if not dst:
            raise ValueError("No destination given.")
        # A single destination wrapped in a sequence would otherwise reach
        # os.path.dirname() and fail with a TypeError.
        dst = dst[0]

    # Validate before touching the filesystem, so a typo in ``method`` does not
    # leave freshly created directories behind. ValueError subclasses
    # Exception, so existing handlers still match.
    if method not in ("CopyGuardian", "system_cp"):
        raise ValueError(f"Unknown copy method {method}")

    # In case of a remote destination, rsync will create the directory itself
    if not is_remote(dst):
        robust_makedirs(os.path.dirname(dst))

    if method == "CopyGuardian":
        copy_with_copy_guardian(
            src,
            dst,
            n_max_connect=n_max_connect,
            folder_with_many_files=folder_with_many_files,
        )
    else:
        system_copy(sources=src, dest=dst, **kwargs)
248def _ensure_list(obj):
249 """
250 Ensure an object is a list.
252 :param obj: object to convert
253 :return: list containing obj, or obj if already a list
254 """
255 if not isinstance(obj, list):
256 return [obj]
257 return obj
def copy_with_copy_guardian(
    sources,
    destination,
    n_max_connect=10,
    timeout=1000,
    folder_with_many_files=False,
):
    """
    Copy local files/directories while holding a CopyGuardian semaphore.

    All sources are copied inside one ``copy_guardian.BoundedSemaphore``
    context. Each source is handled as follows:

    - directory, ``folder_with_many_files`` set:
      ``copy_guardian.copy_utils.copy_local_folder(source, destination)``
    - directory otherwise: ``shutil.copytree(..., dirs_exist_ok=True)``, which
      merges into ``destination``
    - file, ``destination`` is an existing directory: ``shutil.copy`` into it
    - file otherwise: ``shutil.copyfile`` to the path ``destination``

    :param sources: List of source files/directories.
    :param destination: Destination directory.
    :param n_max_connect: Maximum number of simultaneous connections.
    :param timeout: time in seconds to wait for a connection to become available
    :param folder_with_many_files: If True, the source is a folder with many files
    """
    with copy_guardian.BoundedSemaphore(n_max_connect, timeout=timeout):
        for source in sources:
            LOGGER.info(f"Copying locally: {source} -> {destination}")

            if not os.path.isdir(source):
                if os.path.isdir(destination):
                    shutil.copy(source, destination)
                else:
                    shutil.copyfile(source, destination)
                continue

            if not folder_with_many_files:
                shutil.copytree(source, destination, dirs_exist_ok=True)
                continue

            LOGGER.debug("Copying folder with many files")
            copy_guardian.copy_utils.copy_local_folder(source, destination)
def system_copy(sources, dest, args_str_cp=""):
    """
    Copy files using the system ``cp -r`` command.

    The command line is built as a single string and executed through the
    shell (``os.system``). Shell syntax in the arguments, such as globs, is
    therefore expanded, and paths are NOT quoted: paths containing spaces
    break, and untrusted input must never be passed here.

    :param sources: list of source paths (a single path is also accepted)
    :param dest: destination path
    :param args_str_cp: additional arguments for cp command, inserted verbatim
    :return: the status returned by ``os.system`` (0 on success)
    """
    if isinstance(sources, (str, os.PathLike)):
        # Iterating a bare string would pass every character as a source.
        sources = [sources]

    cmd = "cp -r " + args_str_cp + "".join(f" {f} " for f in sources) + f" {dest}"
    LOGGER.debug(f"Copying {len(sources)} files to {dest}")
    LOGGER.debug(cmd)

    status = os.system(cmd)
    if status != 0:
        # Previously ignored. Log it instead of raising, to keep the
        # best-effort behaviour callers rely on.
        LOGGER.warning(f"Copy command failed with status {status}: {cmd}")
    return status
def is_remote(path):
    """
    Heuristically decide whether ``path`` refers to a remote location.

    A path counts as remote when its string form contains both "@" and ":/",
    matching the ``user@host:/path`` convention.

    :param path: path to check (can be Path object)
    :return: True if remote, False otherwise
    """
    text = str(path)
    has_user_host = "@" in text
    has_colon_slash = ":/" in text
    return has_user_host and has_colon_slash
def read_yaml(filename):
    """
    Read a YAML file.

    The file is decoded as UTF-8 (the YAML default) rather than with the
    platform-dependent locale encoding.

    Security note: ``yaml.Loader`` can construct arbitrary Python objects
    from tags in the document. Only use this on trusted files.
    ``yaml.safe_load`` is the safe alternative.

    :param filename: path to YAML file
    :return: parsed YAML content
    """
    with open(filename, encoding="utf-8") as f:
        # NOTE(review): full Loader kept for backward compatibility. Switch to
        # yaml.SafeLoader if no caller relies on Python-specific tags.
        content = yaml.load(f, Loader=yaml.Loader)
    return content
def ensure_permissions(path, verb=False):
    """
    Set the mode of ``path`` to 0o755 (user rwx, group r-x, others r-x).

    :param path: path to file or directory
    :param verb: if True, log the new mode at debug level
    """
    owner_bits = stat.S_IRWXU
    group_bits = stat.S_IRGRP | stat.S_IXGRP
    other_bits = stat.S_IROTH | stat.S_IXOTH
    mode = owner_bits | group_bits | other_bits

    os.chmod(path, mode)
    if not verb:
        return
    LOGGER.debug(f"Changed permissions for {path} to {oct(mode)}")