class ZipZarrContainers(task.SingleTask):
    """Zip up a Zarr container into a single file.

    This is useful to save on file quota and speed up IO by combining the chunk
    data into a single file. Note that the file cannot really be updated after
    this process has been performed.

    As this process is IO limited in most cases, it will attempt to parallelise
    the work across different distinct nodes. That means at most only one rank
    per node will participate.

    Attributes
    ----------
    containers : list
        The names of the Zarr containers to zip. Each must end in `.zarr` and
        be an existing directory. The zipped files will have the same names
        with `.zip` appended.
    remove : bool
        Remove the original data when finished. Defaults to True.
    """

    containers = config.Property(proptype=list)
    remove = config.Property(proptype=bool, default=True)

    def setup(self, _=None):
        """Locate `7z` and work out which ranks take part in the zipping.

        This routine does nothing at all with the input, but it means the
        process won't run until the (optional) requirement is received. This
        can be used to delay evaluation until you know that all the files are
        available.

        Parameters
        ----------
        _ : object, optional
            Ignored. Only present so the task can be made to wait on a
            requirement.

        Raises
        ------
        RuntimeError
            If the `7z` executable cannot be found on the PATH.
        """
        import socket

        # See if we can find 7z
        path_7z = shutil.which("7z")
        if path_7z is None:
            raise RuntimeError("Could not find 7z on the PATH")
        self._path_7z = path_7z

        # Get the rank -> hostname mapping for all ranks
        my_host = socket.gethostname()
        my_rank = self.comm.rank
        all_host_ranks = self.comm.allgather((my_host, my_rank))

        # Identify the lowest rank running on each node
        lowest_rank = {}
        for host, rank in all_host_ranks:
            lowest_rank[host] = min(rank, lowest_rank.get(host, rank))

        self._num_hosts = len(lowest_rank)

        # Figure out if this rank is one that needs to do anything
        if lowest_rank[my_host] != my_rank:
            # This is not the lowest rank on the host, so we don't do anything
            self._host_rank = None
        else:
            # This is the lowest rank, so its index in the sorted list of
            # hostnames becomes its offset into the container list
            self._host_rank = sorted(lowest_rank).index(my_host)
            self.log.debug(f"Lowest rank on {my_host}")

    def process(self):
        """Zip the listed zarr containers.

        Only the lowest rank on each node will participate. Containers are
        dealt out round-robin between the participating ranks.

        Raises
        ------
        ValueError
            If a container name does not end in `.zarr` or is not a directory.
        RuntimeError
            If `7z` exits with a non-zero status.
        pipeline.PipelineStopIteration
            Always raised once every rank has passed the barrier.
        """
        if self._host_rank is not None:
            import subprocess

            # Get the set of containers this rank is responsible for zipping
            my_containers = self.containers[self._host_rank :: self._num_hosts]

            for container in my_containers:
                self.log.info(f"Zipping {container}")

                if not container.endswith(".zarr") or not os.path.isdir(container):
                    raise ValueError(f"{container} is not a valid .zarr directory")

                # Run 7z to zip up the file. `-mx=0` is 7z's lowest level
                # (store only). The trailing "/." is presumably there so the
                # archive root holds the container's contents rather than the
                # directory itself -- TODO confirm against the 7z docs.
                dest_file = container + ".zip"
                src_dir = container + "/."
                command = [self._path_7z, "a", "-tzip", "-mx=0", dest_file, src_dir]
                status = subprocess.run(command, capture_output=True, check=False)

                if status.returncode != 0:
                    # Report at error level, and as decoded text, so the cause
                    # is visible without having to re-run with debug logging.
                    stdout = status.stdout.decode(errors="replace")
                    stderr = status.stderr.decode(errors="replace")
                    self.log.error(
                        f"7z exited with status {status.returncode} "
                        f"while zipping {container}."
                    )
                    self.log.error(f"stdout={stdout}")
                    self.log.error(f"stderr={stderr}")
                    raise RuntimeError(f"Error occurred while zipping {container}.")

                self.log.info(f"Done zipping. Generated {dest_file}.")

                if self.remove:
                    shutil.rmtree(container)
                    self.log.info(f"Removed original container {container}.")

        # Every rank, participating or not, waits here until all zipping is done.
        # NOTE(review): a rank that raises above never reaches this barrier;
        # confirm the launcher aborts the remaining ranks in that case.
        self.comm.Barrier()

        raise pipeline.PipelineStopIteration