From 7149b4e81cd6e9d69ba3ac6c85521f26a26d74c1 Mon Sep 17 00:00:00 2001 From: Rick Nitsche Date: Wed, 28 Apr 2021 15:01:11 -0700 Subject: [PATCH 1/7] feat(memh5): zarr compression support Add support for serialising data into Zarr format files. This enables compression for distributed containers. Co-authored-by: Anja Kefala Co-authored-by: Tristan Pinsonneault-Marotte Co-authored-by: Richard Shaw --- .github/workflows/main.yml | 16 +- caput/config.py | 56 ++ caput/fileformats.py | 300 +++++++++ caput/memh5.py | 640 ++++++++++++++++--- caput/misc.py | 4 +- caput/mpiarray.py | 288 +++++++-- caput/pipeline.py | 103 +++- caput/tests/conftest.py | 51 +- caput/tests/test_lint.py | 8 +- caput/tests/test_memh5.py | 814 +++++++++++++++---------- caput/tests/test_memh5_parallel.py | 398 ++++++------ caput/tests/test_mpiarray.py | 561 ++++++++--------- caput/tests/test_selection.py | 61 +- caput/tests/test_selection_parallel.py | 55 +- caput/tod.py | 10 +- doc/conf.py | 2 +- requirements.txt | 12 +- setup.py | 10 +- 18 files changed, 2394 insertions(+), 995 deletions(-) create mode 100644 caput/fileformats.py diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b8fa6a2c..7cc2a13d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -21,13 +21,14 @@ jobs: - name: Install apt dependencies run: | - sudo apt-get update && sudo apt-get install -y libopenmpi-dev openmpi-bin + sudo apt-get install -y libopenmpi-dev openmpi-bin libhdf5-serial-dev - name: Install pip dependencies run: | - pip install pylint==2.7.0 pylint-ignore flake8 pytest black mpi4py pyinstrument psutil + pip install pylint==2.7.0 pylint-ignore flake8 pytest black mpi4py pyinstrument pytest-lazy-fixture pip install -r requirements.txt python setup.py develop + pip install .[compression] - name: Run flake8 run: flake8 --show-source --ignore=E501,E741,E203,W503,E266 caput @@ -47,7 +48,7 @@ jobs: - name: Install apt dependencies run: | - sudo apt-get update && sudo apt-get install -y libopenmpi-dev openmpi-bin + sudo apt-get install -y libopenmpi-dev openmpi-bin libhdf5-serial-dev - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v1 @@ -56,8 +57,11 @@ jobs: - name: Install pip dependencies run: | + pip install --no-binary=h5py h5py pip install -r requirements.txt - pip install mpi4py pytest psutil + pip install zarr==2.8.1 + pip install mpi4py numcodecs==0.7.3 bitshuffle@git+https://github.com/kiyo-masui/bitshuffle.git psutil + pip install pytest pytest-lazy-fixture python setup.py develop - name: Run serial tests @@ -82,6 +86,10 @@ jobs: with: python-version: 3.9 + - name: Install apt dependencies + run: | + sudo apt-get install -y libhdf5-serial-dev + - name: Install pip dependencies run: | pip install -r requirements.txt diff --git a/caput/config.py b/caput/config.py index bb579fed..58bc3212 100644 --- a/caput/config.py +++ b/caput/config.py @@ -482,6 +482,62 @@ def _prop(config): return prop +def file_format(default=None): + """A property type that accepts only "zarr", or "hdf5". + + Returns the selected `caput.fileformat.FileFormat` subclass or `caput.fileformats.HDF5` if `value == default`. + + Parameters + ---------- + default : optional + The optional default value. + + Returns + ------- + prop : Property + A property instance setup to validate a file format. + + Raises + ------ + ValueError + If the default value is not `"hdf5"` or `"zarr"`. 
+ + Examples + -------- + Should be used like:: + + class Project: + + mode = file_format(default='zarr') + """ + options = ("hdf5", "zarr") + + def _prop(val): + from . import fileformats + + if val is None: + return None + + if not isinstance(val, str): + CaputConfigError( + f"Input {repr(val)} is of type {type(val).__name__} (expected str or None)." + ) + + val = val.lower() + if val == "hdf5": + return fileformats.HDF5 + if val == "zarr": + return fileformats.Zarr + raise CaputConfigError(f"Input {repr(val)} needs to be one of {options})") + + if default is not None and ( + (not isinstance(default, str)) or (default.lower() not in options) + ): + raise CaputConfigError(f"Default value {repr(default)} must be in {options}") + + return Property(proptype=_prop, default=default) + + class _line_dict(dict): """A private dict subclass that also stores line numbers for debugging.""" diff --git a/caput/fileformats.py b/caput/fileformats.py new file mode 100644 index 00000000..4e632daf --- /dev/null +++ b/caput/fileformats.py @@ -0,0 +1,300 @@ +"""Interface for file formats supported by caput: HDF5 and Zarr.""" +import logging +import os +import shutil + +import h5py + +logger = logging.getLogger(__name__) + +try: + import zarr +except ImportError as err: + logger.info(f"zarr support disabled. Install zarr to change this: {err}") + zarr_available = False +else: + zarr_available = True + +try: + from bitshuffle.h5 import H5FILTER, H5_COMPRESS_LZ4 + import numcodecs +except ModuleNotFoundError as e: + logger.debug( + f"Install with 'compression' extra_require to use bitshuffle/numcodecs compression filters.: {e}" + ) + compression_enabled = False + H5FILTER, H5_COMPRESS_LZ4 = None, None +else: + compression_enabled = True + + +class FileFormat: + """Abstract base class for file formats supported by this module.""" + + module = None + + @staticmethod + def open(*args, **vargs): + """ + Open a file. + + Not implemented in base class + """ + raise NotImplementedError + + @staticmethod + def compression_kwargs(compression=None, compression_opts=None, compressor=None): + """ + Sort compression arguments in a format expected by file format module. + + Parameters + ---------- + compression : str or int + Name or identifier of HDF5 compression filter. + compression_opts + See HDF5 documentation for compression filters. + compressor : `numcodecs` compressor + As required by `zarr`. + + Returns + ------- + dict + Compression arguments as required by the file format module. + """ + if compressor and (compression or compression_opts): + raise ValueError( + f"Found more than one kind of compression args: compression ({compression}, {compression_opts}) " + f"and compressor {compressor}." 
+ ) + + +class HDF5(FileFormat): + """Interface for using HDF5 file format from caput.""" + + module = h5py + + @staticmethod + def compression_enabled(): + """Disable compression and chunking due to bug: https://github.com/chime-experiment/Pipeline/issues/33""" + return False + + @staticmethod + def open(*args, **kwargs): + """Open an HDF5 file using h5py.""" + return h5py.File(*args, **kwargs) + + @staticmethod + def compression_kwargs(compression=None, compression_opts=None, compressor=None): + """Format compression arguments for h5py API.""" + super(HDF5, HDF5).compression_kwargs(compression, compression_opts, compressor) + if compressor: + raise NotImplementedError + if compression in ("bitshuffle", H5FILTER, str(H5FILTER)): + if not compression_enabled: + raise ValueError( + "Install with 'compression' extra_require to use bitshuffle/numcodecs compression filters." + ) + compression = H5FILTER + try: + blocksize, c = compression_opts + except ValueError as e: + raise ValueError( + f"Failed to interpret compression_opts: {e}\ncompression_opts: {compression_opts}." + ) from e + if blocksize is None: + blocksize = 0 + if c in (str(H5_COMPRESS_LZ4), "lz4"): + c = H5_COMPRESS_LZ4 + compression_opts = (blocksize, c) + + if compression is not None: + return {"compression": compression, "compression_opts": compression_opts} + return {} + + +class Zarr(FileFormat): + """Interface for using zarr file format from caput.""" + + if zarr_available: + module = zarr + else: + module = None + + @staticmethod + def open(*args, **kwargs): + """Open a zarr file.""" + if not zarr_available: + raise RuntimeError("Can't open zarr file. Please install zarr.") + return zarr.open_group(*args, **kwargs) + + @staticmethod + def compression_kwargs(compression=None, compression_opts=None, compressor=None): + """Format compression arguments for zarr API.""" + super(Zarr, Zarr).compression_kwargs(compression, compression_opts, compressor) + if compression: + if not compression_enabled: + raise ValueError( + "Install with 'compression' extra_require to use bitshuffle/numcodecs compression filters." + ) + if compression == "gzip": + return {"compressor": numcodecs.gzip.GZip(level=compression_opts)} + if compression in (H5FILTER, str(H5FILTER), "bitshuffle"): + try: + blocksize, c = compression_opts + except ValueError as e: + raise ValueError( + f"Failed to interpret compression_opts: {e}\ncompression_opts: {compression_opts}" + ) from e + if c in (H5_COMPRESS_LZ4, str(H5_COMPRESS_LZ4)): + c = "lz4" + if blocksize is None: + blocksize = 0 + return { + "compressor": numcodecs.Blosc( + c, + shuffle=numcodecs.blosc.BITSHUFFLE, + blocksize=int(blocksize) if blocksize is not None else None, + ) + } + else: + raise ValueError( + f"Compression filter not supported in zarr: {compression}" + ) + else: + return {"compressor": compressor} + + +class ZarrProcessSynchronizer: + """ + A context manager for Zarr's ProcessSynchronizer that removes the lock files when done. + + If an MPI communicator is supplied, only rank 0 will attempt to remove files. + + Parameters + ---------- + name : str + Name of the lockfile directory. + comm : + MPI communicator (optional). + """ + + def __init__(self, name, comm=None): + if not zarr_available: + raise RuntimeError( + "Can't use zarr process synchronizer. Please install zarr." 
+ ) + self.name = name + self._comm = comm + + def __enter__(self): + return zarr.ProcessSynchronizer(self.name) + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._comm is None or self._comm.rank == 0: + remove_file_or_dir(self.name) + + +def remove_file_or_dir(name: str): + """ + Remove the file or directory with the given name. + + Parameters + ---------- + name : str + File or directory name to remove. + """ + if os.path.isdir(name): + try: + shutil.rmtree(name) + except FileNotFoundError: + pass + else: + try: + os.remove(name) + except FileNotFoundError: + pass + + +def guess_file_format(name, default=HDF5): + """ + Guess the file format from the file name. + + Parameters + ---------- + name : str + File name. + default : FileFormat or None + Fallback value if format can't be guessed. Default `fileformats.HDF5`. + + Returns + ------- + format : `FileFormat` + File format guessed. + """ + import pathlib + + if name.endswith(".zarr") or pathlib.Path(name).is_dir(): + return Zarr + if name.endswith(".h5") or name.endswith(".hdf5"): + return HDF5 + return default + + +def check_file_format(filename, file_format, data): + """ + Compare file format with guess from filename and data. Return concluded format. + + Parameters + ---------- + filename : str + File name. + file_format : FileFormat or None + File format. None if it should be guessed. + data : any + If this is an h5py.Group or zarr.Group, it will be used to guess or confirm the file format. + + Returns + ------- + file_format : HDF5 or Zarr + File format. + """ + + # check value + if file_format not in (None, HDF5, Zarr): + raise ValueError( + f"Unexpected value for : {file_format} " + f"(expected caput.fileformats.HDF5, caput.fileformats.Zarr or None)." + ) + + # guess file format from + if isinstance(data, h5py.Group): + file_format_guess_output = HDF5 + elif zarr_available and isinstance(data, zarr.Group): + file_format_guess_output = Zarr + else: + file_format_guess_output = None + + # guess file format from + file_format_guess_name = guess_file_format(filename, None) + + # make sure guesses don't mismatch and decide on the format + if ( + file_format_guess_output + and file_format_guess_name + and file_format_guess_name != file_format_guess_output + ): + raise ValueError( + f" ({file_format}) and ({filename}) don't seem to match." + ) + file_format_guess = ( + file_format_guess_output if file_format_guess_output else file_format_guess_name + ) + if file_format is None: + file_format = file_format_guess + elif file_format != file_format_guess: + raise ValueError( + f"Value of ({file_format}) doesn't match ({filename}) " + f"and type of data ({type(data).__name__})." + ) + + return file_format diff --git a/caput/memh5.py b/caput/memh5.py index a627875d..e864fc9b 100644 --- a/caput/memh5.py +++ b/caput/memh5.py @@ -60,6 +60,7 @@ import numpy as np import h5py +from . import fileformats from . import mpiutil from . import mpiarray from . import misc @@ -67,6 +68,13 @@ logger = logging.getLogger(__name__) +try: + import zarr +except ImportError as err: + logger.info(f"zarr support disabled. Install zarr to change this: {err}") + zarr_available = False +else: + zarr_available = True # Basic Classes # ------------- @@ -390,7 +398,7 @@ def from_group(cls, group): """Create a new instance by deep copying an existing group. Agnostic as to whether the group to be copied is a `MemGroup` or an - `h5py.Group` (which includes `hdf5.File` objects). + `h5py.Group` (which includes `h5py.File` and `zarr.File` objects). 
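+        A string or bytes argument is treated as a file name; the format is
+        guessed with `fileformats.guess_file_format` and the group is loaded
+        via `from_file`.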
""" @@ -398,8 +406,14 @@ def from_group(cls, group): self = cls() deep_group_copy(group, self) return self + elif isinstance(group, (str, bytes)): + file_format = fileformats.guess_file_format(group) + return cls.from_file(group, file_format=file_format) else: - return cls.from_hdf5(group) + raise RuntimeError( + f"Can't create an instance from type {type(group).__name__} " + f"(expected MemGroup, str or bytes)." + ) @classmethod def from_hdf5( @@ -411,7 +425,7 @@ def from_hdf5( selections=None, convert_dataset_strings=False, convert_attribute_strings=True, - **kwargs + **kwargs, ): """Create a new instance by copying from an hdf5 group. @@ -443,6 +457,63 @@ def from_hdf5( Root group of loaded file. """ + return cls.from_file( + filename, + distributed, + hints, + comm, + selections, + convert_dataset_strings, + convert_attribute_strings, + file_format=fileformats.HDF5, + **kwargs, + ) + + @classmethod + def from_file( + cls, + filename, + distributed=False, + hints=True, + comm=None, + selections=None, + convert_dataset_strings=False, + convert_attribute_strings=True, + file_format=fileformats.HDF5, + **kwargs, + ): + """Create a new instance by copying from a file group. + + Any keyword arguments are passed on to the constructor for `h5py.File` or `zarr.File`. + + Parameters + ---------- + filename : string + Name of file to load. + distributed : boolean, optional + Whether to load file in distributed mode. + hints : boolean, optional + If in distributed mode use hints to determine whether datasets are + distributed or not. + comm : MPI.Comm, optional + MPI communicator to distributed over. If :obj:`None` use + :obj:`MPI.COMM_WORLD`. + selections : dict + If this is not None, it should map dataset names to axis selections as valid + numpy indexes. + convert_attribute_strings : bool, optional + Try and convert attribute string types to unicode. Default is `True`. + convert_dataset_strings : bool, optional + Try and convert dataset string types to unicode. Default is `False`. + file_format : `fileformats.FileFormat` + File format to use. Default `fileformats.HDF5`. + + Returns + ------- + group : memh5.Group + Root group of loaded file. + """ + if comm is None: comm = mpiutil.world @@ -456,7 +527,7 @@ def from_hdf5( if not distributed or not hints: kwargs["mode"] = "r" - with h5py.File(filename, **kwargs) as f: + with file_format.open(filename, **kwargs) as f: self = cls(distributed=distributed, comm=comm) deep_group_copy( f, @@ -464,15 +535,17 @@ def from_hdf5( selections=selections, convert_attribute_strings=convert_attribute_strings, convert_dataset_strings=convert_dataset_strings, + file_format=file_format, ) else: - self = _distributed_group_from_hdf5( + self = _distributed_group_from_file( filename, comm=comm, hints=hints, selections=selections, convert_attribute_strings=convert_attribute_strings, convert_dataset_strings=convert_dataset_strings, + file_format=file_format, ) return self @@ -484,7 +557,7 @@ def to_hdf5( hints=True, convert_attribute_strings=True, convert_dataset_strings=False, - **kwargs + **kwargs, ): """Replicate object on disk in an hdf5 file. @@ -502,17 +575,58 @@ def to_hdf5( understands. Default is `True`. convert_dataset_strings : bool, optional Try and convert dataset string types to bytestrings. Default is `False`. 
+ """ + self.to_file( + filename, + mode, + hints, + convert_attribute_strings, + convert_dataset_strings, + fileformats.HDF5, + **kwargs, + ) + + def to_file( + self, + filename, + mode="w", + hints=True, + convert_attribute_strings=True, + convert_dataset_strings=False, + file_format=fileformats.HDF5, + **kwargs, + ): + """Replicate object on disk in an hdf5 or zarr file. + + Any keyword arguments are passed on to the constructor for `h5py.File` or `zarr.File`. + + Parameters + ---------- + filename : str + File to save into. + hints : boolean, optional + Whether to write hints into the file that described whether datasets + are distributed, or not. + convert_attribute_strings : bool, optional + Try and convert attribute string types to a unicode type that HDF5 + understands. Default is `True`. + convert_dataset_strings : bool, optional + Try and convert dataset string types to bytestrings. Default is `False`. + file_format : `fileformats.FileFormat` + File format to use. Default `fileformats.HDF5`. + """ if not self.distributed: - with h5py.File(filename, mode, **kwargs) as f: + with file_format.open(filename, mode, **kwargs) as f: deep_group_copy( self, f, convert_attribute_strings=convert_attribute_strings, convert_dataset_strings=convert_dataset_strings, + file_format=file_format, ) - else: + elif file_format == fileformats.HDF5: if h5py.get_config().mpi: _distributed_group_to_hdf5_parallel( self, @@ -529,6 +643,14 @@ def to_hdf5( convert_attribute_strings=convert_attribute_strings, convert_dataset_strings=convert_dataset_strings, ) + else: + _distributed_group_to_zarr( + self, + filename, + mode, + convert_attribute_strings=convert_attribute_strings, + convert_dataset_strings=convert_dataset_strings, + ) def create_group(self, name): """Create a group within the storage tree.""" @@ -574,7 +696,7 @@ def create_dataset( chunks=None, compression=None, compression_opts=None, - **kwargs + **kwargs, ): """Create a new dataset. @@ -593,6 +715,11 @@ def create_dataset( distributed_axis : int, optional Axis to distribute the data over. If specified with initialisation data this will cause create a copy with the correct distribution. + compression : str or int + Name or identifier of HDF5 or Zarr compression filter. + compression_opts + See HDF5 and Zarr documentation for compression filters. + Compression options for the dataset. Returns ------- @@ -895,22 +1022,48 @@ def resize(self): @property def shape(self): + """ + Shape of the dataset. + + Not implemented in base class. + """ raise NotImplementedError("Not implemented in base class.") @property def dtype(self): + """ + numpy data type of the dataset. + + Not implemented in base class. + """ raise NotImplementedError("Not implemented in base class.") @property def chunks(self): + """ + Chunk shape of the dataset. + + Not implemented in base class. + """ raise NotImplementedError("Not implemented in base class.") @property def compression(self): + """ + Name or identifier of HDF5 compression filter for the dataset. + + Not implemented in base class. + """ raise NotImplementedError("Not implemented in base class.") @property def compression_opts(self): + """ + Compression options for the dataset. + + See HDF5 documentation for compression filters. + Not implemented in base class. 
+ """ raise NotImplementedError("Not implemented in base class.") def __getitem__(self, obj): @@ -951,7 +1104,7 @@ def __init__( chunks=None, compression=None, compression_opts=None, - **kwargs + **kwargs, ): super().__init__(**kwargs) @@ -970,6 +1123,11 @@ def from_numpy_array( ---------- data : np.ndarray Array to initialise from. + compression : str or int + Name or identifier of HDF5 or Zarr compression filter. + compression_opts + See HDF5 and Zarr documentation for compression filters. + Compression options for the dataset. Returns ------- @@ -1024,7 +1182,17 @@ def chunks(self): @chunks.setter def chunks(self, val): - self._chunks = val + if val is None: + chunks = val + elif len(val) != len(self.shape): + raise ValueError( + f"Chunk size {val} is not compatible with dataset shape {self.shape}." + ) + else: + chunks = () + for i, l in enumerate(self.shape): + chunks += (min(val[i], l),) + self._chunks = chunks @property def compression(self): @@ -1105,7 +1273,7 @@ def __init__( chunks=None, compression=None, compression_opts=None, - **kwargs + **kwargs, ): super().__init__(**kwargs) @@ -1152,10 +1320,20 @@ def shape(self): @property def global_shape(self): + """ + Global shape of the distributed dataset. + + The shape of the whole array that is distributed between multiple nodes. + """ return self._data.global_shape @property def local_shape(self): + """ + Local shape of the distributed dataset. + + The shape of the part of the distributed array that is allocated to *this* node. + """ return self._data.local_shape @property @@ -1164,15 +1342,27 @@ def local_offset(self): @property def dtype(self): + """The numpy data type of the dataset""" return self._data.dtype @property def chunks(self): + """The chunk shape of the dataset.""" return self._chunks @chunks.setter def chunks(self, val): - self._chunks = val + if val is None: + chunks = val + elif len(val) != len(self.shape): + raise ValueError( + f"Chunk size {val} is not compatible with dataset shape {self.shape}." + ) + else: + chunks = () + for i, l in enumerate(self.shape): + chunks += (min(val[i], l),) + self._chunks = chunks @property def compression(self): @@ -1192,6 +1382,7 @@ def compression_opts(self, val): @property def distributed_axis(self): + """The index of the axis over which this dataset is distributed.""" return self._data.axis @property @@ -1282,9 +1473,11 @@ class MemDiskGroup(_BaseGroup): detect_subclass: boolean, optional If *data_group* is specified, whether to inspect for a '__memh5_subclass' attribute which specifies a subclass to return. + file_format : `fileformats.FileFormat` + File format to use. File format will be guessed if not supplied. Default `None`. """ - def __init__(self, data_group=None, distributed=False, comm=None): + def __init__(self, data_group=None, distributed=False, comm=None, file_format=None): toclose = False @@ -1306,14 +1499,20 @@ def __init__(self, data_group=None, distributed=False, comm=None): # Otherwise, presume it is an HDF5 Group-like object (which includes # MemGroup and h5py.Group). else: - data_group, toclose = get_h5py_File(data_group, mode="a") - - if distributed and isinstance(data_group, h5py.Group): - raise ValueError( - "Distributed MemDiskGroup cannot be created around h5py objects." 
+ data_group, toclose = get_file( + data_group, mode="a", file_format=file_format ) + # Zarr arrays are automatically flushed and closed + toclose = False if file_format == fileformats.HDF5 else toclose + # Check the distribution settings - elif distributed: + if distributed: + if isinstance(data_group, h5py.Group) or ( + zarr_available and isinstance(data_group, zarr.Group) + ): + raise ValueError( + "Distributed MemDiskGroup cannot be created around h5py or zarr objects." + ) # Check parallel distribution is the same if not data_group.distributed: raise ValueError( @@ -1404,6 +1603,9 @@ def close(self): if self.ondisk and hasattr(self, "_toclose") and self._toclose: self._storage_root.close() + if hasattr(self, "_lockfile") and (self.comm is None or self.comm.rank is None): + fileformats.remove_file_or_dir(self._lockfile) + def __getitem__(self, name): """Retrieve an object. @@ -1440,8 +1642,9 @@ def __iter__(self): @property def ondisk(self): """Whether the data is stored on disk as opposed to in memory.""" - return hasattr(self, "_storage_root") and isinstance( - self._storage_root, h5py.File + return hasattr(self, "_storage_root") and ( + isinstance(self._storage_root, h5py.File) + or (zarr_available and isinstance(self._storage_root, zarr.Group)) ) @classmethod @@ -1475,7 +1678,8 @@ def from_file( detect_subclass=True, convert_attribute_strings=None, convert_dataset_strings=None, - **kwargs + file_format=fileformats.HDF5, + **kwargs, ): """Create data object from analysis hdf5 file, store in memory or on disk. @@ -1512,10 +1716,14 @@ def from_file( _sel : list or slice Axis selections can be given to only read a subset of the containers. A slice can be given, or a list of specific array indices for that axis. + file_format : `fileformats.FileFormat` + File format to use. Default `fileformats.HDF5`. **kwargs : any other arguments Any additional keyword arguments are passed to :class:`h5py.File`'s constructor if *file_* is a filename and silently ignored otherwise. """ + if file_format == fileformats.Zarr and not zarr_available: + raise RuntimeError("Unable to read zarr file, please install zarr.") # Get a value for the conversion parameters, looking up on the class type if # not supplied @@ -1524,8 +1732,12 @@ def from_file( if convert_dataset_strings is None: convert_dataset_strings = getattr(cls, "convert_dataset_strings", False) + lockfile = None + if not ondisk: - if isinstance(file_, h5py.Group): + if (zarr_available and isinstance(file_, zarr.Group)) or isinstance( + file_, h5py.Group + ): file_ = file_.filename if "mode" in kwargs: @@ -1547,14 +1759,15 @@ def from_file( # Map selections to datasets sel = cls._make_selections(sel_args) - data = MemGroup.from_hdf5( + data = MemGroup.from_file( file_, distributed=distributed, comm=comm, selections=sel, convert_attribute_strings=convert_attribute_strings, convert_dataset_strings=convert_dataset_strings, - **kwargs + file_format=file_format, + **kwargs, ) toclose = False else: @@ -1564,8 +1777,11 @@ def from_file( toclose = False else: kwargs.setdefault("mode", "a") - data = h5py.File(file_, **kwargs) - toclose = True + if distributed and file_format == fileformats.Zarr: + lockfile = f"{file_}.sync" + kwargs["synchronizer"] = zarr.ProcessSynchronizer(lockfile) + data = file_format.open(file_, **kwargs) + toclose = file_format == fileformats.HDF5 # Here we explicitly avoid calling __init__ on any derived class. Like # with a pickle we want to restore the saved state only. 
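
As a usage sketch for the reworked `from_file` above (the file names here are
illustrative, not part of this change)::

    from caput import memh5, fileformats

    # Load an HDF5 file fully into memory (HDF5 is the default file format).
    data = memh5.MemDiskGroup.from_file("example.h5")

    # Open a Zarr store on disk without reading it all into memory.
    data_z = memh5.MemDiskGroup.from_file(
        "example.zarr", ondisk=True, file_format=fileformats.Zarr
    )
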
@@ -1575,11 +1791,16 @@ def from_file( self._finish_setup() self._toclose = toclose + + if lockfile is not None: + self._comm = comm + self._lockfile = lockfile return self # Methods for manipulating and building the class. # - def group_name_allowed(self, name): + @staticmethod + def group_name_allowed(name): """Used by subclasses to restrict creation of and access to groups. This method is called by :meth:`create_group`, :meth:`require_group`, @@ -1603,7 +1824,8 @@ def group_name_allowed(self, name): """ return True - def dataset_name_allowed(self, name): + @staticmethod + def dataset_name_allowed(name): """Used by subclasses to restrict creation of and access to datasets. This method is called by :meth:`create_dataset`, @@ -1663,7 +1885,7 @@ def dataset_common_to_distributed(self, name, distributed_axis=0): return self._data.dataset_common_to_distributed(name, distributed_axis) else: raise RuntimeError( - "Can not convert a h5py dataset %s to distributed" % name + "Can not convert a h5py or zarr dataset %s to distributed" % name ) def dataset_distributed_to_common(self, name): @@ -1683,7 +1905,7 @@ def dataset_distributed_to_common(self, name): return self._data.dataset_distributed_to_common(name) else: raise RuntimeError( - "Can not convert a h5py dataset %s to distributed" % name + "Can not convert a h5py or zarr dataset %s to distributed" % name ) def create_group(self, name): @@ -1704,8 +1926,23 @@ def to_memory(self): else: return self.__class__.from_file(self._data) - def to_disk(self, filename, **kwargs): - """Return a version of this data that lives on disk.""" + def to_disk(self, filename, file_format=fileformats.HDF5, **kwargs): + """ + Return a version of this data that lives on disk. + + Parameters + ---------- + filename : str + File name. + file_format : `fileformats.FileFormat` + File format to use. Default `fileformats.HDF5`. + **kwargs + Keyword arguments passed through to the file creating, e.g. `mode`. + + Returns + ------- + Instance of this data object that is written to disk. + """ if not isinstance(self._data, MemGroup): msg = "This data already lives on disk. Copying to new file anyway." @@ -1715,8 +1952,10 @@ def to_disk(self, filename, **kwargs): "Cannot run to_disk on a distributed object. Try running save instead." ) - self.save(filename) - return self.__class__.from_file(filename, ondisk=True, **kwargs) + self.save(filename, file_format=file_format) + return self.__class__.from_file( + filename, ondisk=True, file_format=file_format, **kwargs + ) def flush(self): """Flush the buffers of the underlying hdf5 file if on disk.""" @@ -1729,9 +1968,10 @@ def save( filename, convert_attribute_strings=None, convert_dataset_strings=None, - **kwargs + file_format=fileformats.HDF5, + **kwargs, ): - """Save data to hdf5 file. + """Save data to hdf5/zarr file. Parameters ---------- @@ -1745,9 +1985,13 @@ def save( Try and convert dataset string types to bytestrings before saving to HDF5. If not specified, look up the name as a class attribute to find a default, and otherwise use `False`. + file_format : `fileformats.FileFormat` + File format to use. Default `fileformats.HDF5`. **kwargs Keyword arguments passed through to the file creating, e.g. `mode`. 
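+
+        For example, an instance can be written to a Zarr store like this
+        (the file name is illustrative)::
+
+            from caput import fileformats
+
+            # `data` is an instance of a MemDiskGroup subclass
+            data.save("output.zarr", file_format=fileformats.Zarr)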
""" + if file_format == fileformats.Zarr and not zarr_available: + raise RuntimeError("Unable to write to zarr file, please install zarr.") # Get a value for the conversion parameters, looking up on the instance if # not supplied @@ -1763,15 +2007,18 @@ def save( self.attrs["__memh5_subclass"] = clspath - if isinstance(self._data, h5py.File): - with h5py.File(filename, **kwargs) as f: + if (zarr_available and isinstance(self._data, zarr.Group)) or isinstance( + self._data, h5py.File + ): + with file_format.open(filename, **kwargs) as f: deep_group_copy(self._data, f) else: - self._data.to_hdf5( + self._data.to_file( filename, convert_attribute_strings=convert_attribute_strings, convert_dataset_strings=convert_dataset_strings, - **kwargs + file_format=file_format, + **kwargs, ) @@ -2046,18 +2293,25 @@ def is_group(obj): def get_h5py_File(f, **kwargs): - """Checks if input is an `h5py.File` or filename and returns the former. + """Convenience function in order to not break old functionality.""" + return get_file(f, file_format=fileformats.HDF5, **kwargs) + + +def get_file(f, file_format=None, **kwargs): + """Checks if input is a `zarr`/`h5py.File` or filename and returns the former. Parameters ---------- - f : h5py Group or filename string + f : h5py/zarr Group or filename string + file_format : `fileformats.FileFormat` + File format to use. File format will be guessed if not supplied. Default `None`. **kwargs : all keyword arguments - Passed to :class:`h5py.File` constructor. If `f` is already an open file, + Passed to :class:`h5py.File` constructor or `zarr.open_group`. If `f` is already an open file, silently ignores all keywords. Returns ------- - f : hdf5 group + f : hdf5 or zarr group opened : bool Whether the a file was opened or not (i.e. was already open). @@ -2066,28 +2320,28 @@ def get_h5py_File(f, **kwargs): # Figure out if F is a file or a filename, and whether the file should be # closed. if is_group(f): - opened = False - # if kwargs: - # msg = "Got some keyword arguments but File is alrady open." - # warnings.warn(msg) + return f, False else: - opened = True + if file_format is None: + file_format = fileformats.guess_file_format(f) + if file_format == fileformats.Zarr and not zarr_available: + raise RuntimeError("Unable to open zarr file. Please install zarr.") try: - f = h5py.File(f, **kwargs) + f = file_format.open(f, **kwargs) except IOError as e: msg = "Opening file %s caused an error: " % str(f) raise IOError(msg + str(e)) from e - return f, opened + return f, True def copyattrs(a1, a2, convert_strings=False): - """Copy attributes from one h5py/memh5 attribute object to another. + """Copy attributes from one h5py/zarr/memh5 attribute object to another. Parameters ---------- - a1 : h5py/memh5 object + a1 : h5py/zarr/memh5 object Attributes to copy from. - a1 : h5py/memh5 object + a1 : h5py/zarr/memh5 object Attributes to copy into. convert_strings : bool, optional Convert string attributes (or lists/arrays of them) to ensure that they are @@ -2147,7 +2401,14 @@ def default(self, o): # Let the default method raise the TypeError return json.JSONEncoder.default(self, o) - if isinstance(value, dict) and isinstance(a2, h5py.AttributeManager): + if ( + isinstance(value, (dict, np.ndarray, datetime.datetime)) + and zarr_available + and isinstance(a2, zarr.attrs.Attributes) + ) or ( + isinstance(value, (dict, datetime.datetime)) + and isinstance(a2, h5py.AttributeManager) + ): # Save to JSON converting datetimes. 
encoder = Memh5JSONEncoder() value = json_prefix + encoder.encode(value) @@ -2159,6 +2420,8 @@ def default(self, o): for key in sorted(a1): val = _map_unicode(a1[key]) val = _map_json(val) + if isinstance(val, np.generic): # zarr can't handle numpy types + val = val.item() a2[key] = val @@ -2168,6 +2431,7 @@ def deep_group_copy( selections=None, convert_dataset_strings=False, convert_attribute_strings=True, + file_format=fileformats.HDF5, ): """ Copy full data tree from one group to another. @@ -2185,9 +2449,9 @@ def deep_group_copy( Parameters ---------- - g1 : h5py.Group + g1 : h5py.Group or zarr.Group Deep copy from this group. - g2 : h5py.Group + g2 : h5py.Group or zarr.Group Deep copy to this group. selections : dict If this is not None, it should have a subset of the same hierarchical structure @@ -2198,6 +2462,8 @@ def deep_group_copy( unicode. convert_dataset_strings : bool, optional Convert strings within datasets to ensure that they are unicode. + file_format : `fileformats.FileFormat` + File format to use. Default `fileformats.HDF5`. """ copyattrs(g1.attrs, g2.attrs, convert_strings=convert_attribute_strings) @@ -2213,6 +2479,7 @@ def deep_group_copy( selections, convert_dataset_strings=convert_dataset_strings, convert_attribute_strings=convert_attribute_strings, + file_format=file_format, ) else: # look for selection for this dataset (also try withouth the leading "/") @@ -2223,30 +2490,53 @@ def deep_group_copy( except AttributeError: selection = slice(None) + # only the case if zarr is not installed + if file_format.module is None: + raise RuntimeError( + "Can't deep_group_copy zarr file. Please install zarr." + ) + if convert_dataset_strings: # Convert unicode strings back into ascii byte strings. This will break # if there are characters outside of the ascii range - if isinstance(g2, h5py.Group): + if isinstance(g2, file_format.module.Group): data = ensure_bytestring(entry[selection]) # Convert strings in an HDF5 dataset into unicode else: data = ensure_unicode(entry[selection]) - - elif isinstance(g2, h5py.Group): + elif isinstance(g2, file_format.module.Group): data = check_unicode(entry) data = data[selection] else: data = entry[selection] + # get compression options/chunking for this dataset + chunks = getattr(entry, "chunks", None) + compression = getattr(entry, "compression", None) + compression_opts = getattr(entry, "compression_opts", None) + + # TODO: Am I missing something or is this branch not necessary? + # I guess I'm still confused as to why a file_format is + # required even for the in-memory case + if isinstance(g2, file_format.module.Group): + compression_kwargs = file_format.compression_kwargs( + compression=compression, + compression_opts=compression_opts, + compressor=getattr(entry, "compressor", None), + ) + else: + # in-memory case; use HDF5 compression args format for this case + compression_kwargs = fileformats.HDF5.compression_kwargs( + compression=compression, compression_opts=compression_opts + ) g2.create_dataset( key, shape=data.shape, dtype=data.dtype, data=data, - chunks=entry.chunks, - compression=entry.compression, - compression_opts=entry.compression_opts, + chunks=chunks, + **compression_kwargs, ) copyattrs( entry.attrs, g2[key].attrs, convert_strings=convert_attribute_strings @@ -2276,7 +2566,7 @@ def _distributed_group_to_hdf5_serial( hints=True, convert_dataset_strings=False, convert_attribute_strings=True, - **kwargs + **kwargs, ): """Private routine to copy full data tree from distributed memh5 object into an HDF5 file. 
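
For reference, the two `compression_kwargs` implementations used by
`deep_group_copy` above translate the same high-level description into
format-specific keyword arguments, roughly as follows (return values
paraphrased from the code)::

    from caput import fileformats

    # h5py takes the names directly.
    fileformats.HDF5.compression_kwargs(compression="gzip", compression_opts=4)
    # -> {"compression": "gzip", "compression_opts": 4}

    # zarr expects a numcodecs compressor object instead.
    fileformats.Zarr.compression_kwargs(compression="gzip", compression_opts=4)
    # -> {"compressor": numcodecs.gzip.GZip(level=4)}
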
@@ -2294,8 +2584,7 @@ def _distributed_group_to_hdf5_serial( # Create group (or file) if comm.rank == 0: - # If this is the root group, create the file and copy the file level - # attrs + # If this is the root group, create the file and copy the file level attrs if group.name == "/": with h5py.File(fname, mode, **kwargs) as f: copyattrs( @@ -2330,7 +2619,7 @@ def _distributed_group_to_hdf5_serial( mode, convert_dataset_strings=convert_dataset_strings, convert_attribute_strings=convert_attribute_strings, - **kwargs + **kwargs, ) # Write out distributed datasets (only the data, the attributes are written below) @@ -2338,12 +2627,27 @@ def _distributed_group_to_hdf5_serial( arr = check_unicode(entry) + if fileformats.HDF5.compression_enabled(): + ( + chunks, + compression_kwargs, + ) = entry.chunks, fileformats.HDF5.compression_kwargs( + compression=entry.compression, + compression_opts=entry.compression_opts, + ) + else: + # disable compression if not enabled for HDF5 files + # https://github.com/chime-experiment/Pipeline/issues/33 + chunks, compression_kwargs = None, { + "compression": None, + "compression_opts": None, + } + arr.to_hdf5( fname, entry.name, - chunks=entry.chunks, - compression=entry.compression, - compression_opts=entry.compression_opts, + chunks=chunks, + **compression_kwargs, ) comm.Barrier() @@ -2365,12 +2669,15 @@ def _distributed_group_to_hdf5_serial( else: data = check_unicode(entry) + # allow chunks and compression bc serialised IO dset = f.create_dataset( entry.name, data=data, chunks=entry.chunks, - compression=entry.compression, - compression_opts=entry.compression_opts, + **fileformats.HDF5.compression_kwargs( + compression=entry.compression, + compression_opts=entry.compression_opts, + ), ) copyattrs( entry.attrs, @@ -2411,7 +2718,7 @@ def _distributed_group_to_hdf5_parallel( hints=True, convert_dataset_strings=False, convert_attribute_strings=True, - **kwargs + **_, ): """Private routine to copy full data tree from distributed memh5 object into an HDF5 file. 
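
The chunking and compression options consumed by these write routines (and
currently skipped for HDF5 while the issue linked in the comments above is
open) are the ones attached when a dataset is created. A sketch with
illustrative values, assuming the optional compression dependencies
(bitshuffle/numcodecs) are installed::

    import numpy as np
    from caput import memh5, fileformats

    g = memh5.MemGroup()
    g.create_dataset(
        "vis",
        data=np.zeros((16, 32), dtype=np.float32),
        chunks=(4, 8),
        compression="bitshuffle",
        compression_opts=(None, "lz4"),
    )
    g.to_file("vis.zarr", file_format=fileformats.Zarr)
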
@@ -2445,13 +2752,25 @@ def _copy_to_file(memgroup, h5group): data = check_unicode(item) # Write to file from MPIArray + if fileformats.HDF5.compression_enabled(): + chunks, compression, compression_opts = ( + item.chunks, + item.compression, + item.compression_opts, + ) + else: + # disable compression if not enabled for HDF5 files + # https://github.com/chime-experiment/Pipeline/issues/33 + chunks, compression, compression_opts = None, None, None + data.to_hdf5( h5group, key, - chunks=item.chunks, - compression=item.compression, - compression_opts=item.compression_opts, + chunks=chunks, + compression=compression, + compression_opts=compression_opts, ) + dset = h5group[key] if hints: @@ -2467,13 +2786,27 @@ def _copy_to_file(memgroup, h5group): else: data = check_unicode(item) + if fileformats.HDF5.compression_enabled(): + ( + chunks, + compression_kwargs, + ) = item.chunks, fileformats.HDF5.compression_kwargs( + item.compression, item.compression_opts + ) + else: + # disable compression if not enabled for HDF5 files + # https://github.com/chime-experiment/Pipeline/issues/33 + chunks, compression_kwargs = None, { + "compression": None, + "compression_opts": None, + } + dset = h5group.create_dataset( key, shape=data.shape, dtype=data.dtype, - chunks=item.chunks, - compression=item.compression, - compression_opts=item.compression_opts, + chunks=chunks, + **compression_kwargs, ) # Write common data from rank 0 @@ -2503,16 +2836,134 @@ def _copy_to_file(memgroup, h5group): group.comm.Barrier() -def _distributed_group_from_hdf5( +def _distributed_group_to_zarr( + group, fname, - comm=None, + mode, hints=True, convert_dataset_strings=False, convert_attribute_strings=True, - **kwargs + **_, +): + """Private routine to copy full data tree from distributed memh5 object into a Zarr file. + + This paralellizes all IO.""" + + if not zarr_available: + raise RuntimeError("Can't write to zarr file. 
Please install zarr.") + + # == Create some internal functions for doing the read == + # Function to perform a recursive clone of the tree structure + def _copy_to_file(memgroup, group): + + # Copy over attributes + if memgroup.comm.rank == 0: + copyattrs( + memgroup.attrs, group.attrs, convert_strings=convert_attribute_strings + ) + + # Sort the items to ensure we insert in a consistent order across ranks + for key in sorted(memgroup): + + item = memgroup[key] + + # If group, create the entry and the recurse into it + if is_group(item): + if memgroup.comm.rank == 0: + group.create_group(key) + memgroup.comm.Barrier() + _copy_to_file(item, group[key]) + + # If dataset, create dataset + else: + # Check if we are in a distributed dataset + if isinstance(item, MemDatasetDistributed): + + data = check_unicode(item) + + logger.error(f"chunk settings: {item.chunks}") + + # Write to file from MPIArray + data.to_file( + group, + key, + chunks=item.chunks, + compression=item.compression, + compression_opts=item.compression_opts, + file_format=fileformats.Zarr, + ) + dset = group[key] + + if memgroup.comm.rank == 0 and hints: + dset.attrs["__memh5_distributed_dset"] = True + + # Create common dataset (collective) + else: + + # Convert from unicode to bytestring + if convert_dataset_strings: + data = ensure_bytestring(item.data) + else: + data = check_unicode(item) + + # Write common data from rank 0 + if memgroup.comm.rank == 0: + dset = group.create_dataset( + key, + shape=data.shape, + dtype=data.dtype, + chunks=item.chunks, + **fileformats.Zarr.compression_kwargs( + item.compression, item.compression_opts + ), + ) + + dset[:] = data + + if hints: + dset.attrs["__memh5_distributed_dset"] = False + + # Copy attributes over into dataset + if memgroup.comm.rank == 0: + copyattrs( + item.attrs, + dset.attrs, + convert_strings=convert_attribute_strings, + ) + + # Make sure file exists + if group.comm.rank == 0: + zarr.open_group(store=fname, mode=mode) + group.comm.Barrier() + + # Open file on all ranks + + with fileformats.ZarrProcessSynchronizer( + f".{fname}.sync", group.comm + ) as synchronizer, zarr.open_group( + store=fname, mode="r+", synchronizer=synchronizer + ) as f: + # Start recursive file write + _copy_to_file(group, f) + + if hints and group.comm.rank == 0: + f.attrs["__memh5_distributed_file"] = True + + # Final synchronisation + group.comm.Barrier() + + +def _distributed_group_from_file( + fname, + comm=None, + _=True, # usually `hints`, but hints do not do anything in this method + convert_dataset_strings=False, + convert_attribute_strings=True, + file_format=fileformats.HDF5, + **kwargs, ): """ - Restore full tree from an HDF5 file into a distributed memh5 object. + Restore full tree from an HDF5 file or Zarr group into a distributed memh5 object. A `selections=` parameter may be supplied as parts of 'kwargs'. See `_deep_group_copy' for a description. 
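
Taken together, the zarr write and read routines above give a distributed
round trip. A sketch assuming an MPI environment (the array contents and
file name are illustrative)::

    from caput import fileformats, memh5, mpiarray, mpiutil

    arr = mpiarray.MPIArray((4 * mpiutil.size, 3), axis=0)
    arr[:] = mpiutil.rank

    g = memh5.MemGroup(distributed=True)
    g.create_dataset("vis", data=arr, distributed=True)
    g.to_file("dist.zarr", file_format=fileformats.Zarr)

    g2 = memh5.MemGroup.from_file(
        "dist.zarr", distributed=True, file_format=fileformats.Zarr
    )
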
@@ -2565,8 +3016,13 @@ def _copy_from_file(h5group, memgroup, selections=None): distributed_axis = item.attrs.get("__memh5_distributed_axis", 0) # Read from file into MPIArray - pdata = mpiarray.MPIArray.from_hdf5( - h5group, key, axis=distributed_axis, comm=comm, sel=selection + pdata = mpiarray.MPIArray.from_file( + h5group, + key, + axis=distributed_axis, + comm=comm, + sel=selection, + file_format=file_format, ) # Create dataset from MPIArray @@ -2594,10 +3050,14 @@ def _copy_from_file(h5group, memgroup, selections=None): # Copy attributes over into dataset _copy_attrs_bcast(item, dset, convert_strings=convert_attribute_strings) - # Open file on all ranks - with misc.open_h5py_mpi(fname, "r", comm=comm) as f: + if file_format == fileformats.HDF5: + # Open file on all ranks + with misc.open_h5py_mpi(fname, "r", comm=comm) as f: - # Start recursive file read + # Start recursive file read + _copy_from_file(f, group, selections) + else: + f = file_format.open(fname, "r") _copy_from_file(f, group, selections) # Final synchronisation diff --git a/caput/misc.py b/caput/misc.py index de7c04f3..e0fd2f40 100644 --- a/caput/misc.py +++ b/caput/misc.py @@ -231,7 +231,9 @@ def open_h5py_mpi(f, mode, use_mpi=True, comm=None): fh = f fh.opened = False else: - raise ValueError("Did not receive a h5py.File or filename") + raise ValueError( + f"Can't write to {f} (Expected a h5py.File, h5py.Group or str filename)." + ) fh.is_mpi = fh.file.driver == "mpio" diff --git a/caput/mpiarray.py b/caput/mpiarray.py index ddd3a153..e3e3ece6 100644 --- a/caput/mpiarray.py +++ b/caput/mpiarray.py @@ -96,7 +96,7 @@ import numpy as np -from caput import mpiutil, misc +from caput import fileformats, mpiutil, misc logger = logging.getLogger(__name__) @@ -631,12 +631,61 @@ def from_hdf5(cls, f, dataset, comm=None, axis=0, sel=None): ------- array : MPIArray """ - # Don't bother using MPI where the axis is not zero. It's probably just slower. - # TODO: with tuning this might not be true. Keep an eye on this. - use_mpi = axis > 0 + return cls.from_file(f, dataset, comm, axis, sel, file_format=fileformats.HDF5) - # Read the file. Opening with MPI if requested, and we can - fh = misc.open_h5py_mpi(f, "r", use_mpi=use_mpi, comm=comm) + @classmethod + def from_file( + cls, f, dataset, comm=None, axis=0, sel=None, file_format=fileformats.HDF5 + ): + """Read MPIArray from an HDF5 dataset or Zarr array on disk in parallel. + + Parameters + ---------- + f : filename, or `h5py.File` object + File to read dataset from. + dataset : string + Name of dataset to read from. Must exist. + comm : MPI.Comm, optional + MPI communicator to distribute over. If `None` optional, use + `MPI.COMM_WORLD`. + axis : int, optional + Axis over which the read should be distributed. This can be used + to select the most efficient axis for the reading. + sel : tuple, optional + A tuple of slice objects used to make a selection from the array + *before* reading. The output will be this selection from the dataset + distributed over the given axis. + file_format : `fileformats.HDF5` or `fileformats.Zarr` + File format to use. Default `fileformats.HDF5`. + + Returns + ------- + array : MPIArray + """ + if file_format == fileformats.HDF5: + # Don't bother using MPI where the axis is not zero. It's probably just slower. + # TODO: with tuning this might not be true. Keep an eye on this. + use_mpi = axis > 0 + + # Read the file. 
Opening with MPI if requested, and we can + fh = misc.open_h5py_mpi(f, "r", use_mpi=use_mpi, comm=comm) + elif file_format == fileformats.Zarr: + # Blosc may share incorrect global state amongst processes causing programs to hang. + # See https://zarr.readthedocs.io/en/stable/tutorial.html#parallel-computing-and-synchronization + try: + import numcodecs + except ImportError: + raise RuntimeError("Install numcodecs to read from zarr files.") + numcodecs.blosc.use_threads = False + + if isinstance(f, str): + fh = file_format.open(f, "r") + elif isinstance(f, file_format.module.Group): + fh = f + else: + raise ValueError( + f"Can't write to {f} (Expected a {file_format.module.__name__}.Group or str filename)." + ) dset = fh[dataset] dshape = dset.shape # Shape of the underlying dataset @@ -669,34 +718,43 @@ def from_hdf5(cls, f, dataset, comm=None, axis=0, sel=None): sel = tuple(sel) # Split the axis to get the IO size under ~2GB (only if MPI-IO) - split_axis, partitions = dist_arr._partition_io(skip=(not fh.is_mpi)) - - # Check that there are no null slices, otherwise we need to turn off - # collective IO to work around an h5py issue (#965) - no_null_slices = dist_arr.global_shape[axis] >= dist_arr.comm.size - - # Only use collective IO if: - # - there are no null slices (h5py bug) - # - we are not distributed over axis=0 as there is no advantage for - # collective IO which is usually slow - # TODO: change if h5py bug fixed - # TODO: better would be a test on contiguous IO size - # TODO: do we need collective IO to read chunked data? - use_collective = fh.is_mpi and no_null_slices and axis > 0 + split_axis, partitions = dist_arr._partition_io( + skip=file_format == fileformats.HDF5 and not fh.is_mpi + ) - # Read using collective MPI-IO if specified - with dset.collective if use_collective else DummyContext(): + if file_format == fileformats.HDF5: + # Check that there are no null slices, otherwise we need to turn off + # collective IO to work around an h5py issue (#965) + no_null_slices = dist_arr.global_shape[axis] >= dist_arr.comm.size + + # Only use collective IO if: + # - there are no null slices (h5py bug) + # - we are not distributed over axis=0 as there is no advantage for + # collective IO which is usually slow + # TODO: change if h5py bug fixed + # TODO: better would be a test on contiguous IO size + # TODO: do we need collective IO to read chunked data? + use_collective = fh.is_mpi and no_null_slices and axis > 0 + + # Read using collective MPI-IO if specified + with dset.collective if use_collective else DummyContext(): + + # Loop over partitions of the IO and perform them + for part in partitions: + islice, fslice = _partition_sel( + sel, split_axis, dshape[split_axis], part + ) + dist_arr[fslice] = dset[islice] - # Loop over partitions of the IO and perform them + if fh.opened: + fh.close() + else: for part in partitions: islice, fslice = _partition_sel( sel, split_axis, dshape[split_axis], part ) dist_arr[fslice] = dset[islice] - if fh.opened: - fh.close() - return dist_arr def to_hdf5( @@ -712,12 +770,17 @@ def to_hdf5( Parameters ---------- - filename : str, h5py.File or h5py.Group + f : str, h5py.File or h5py.Group File to write dataset into. dataset : string Name of dataset to write into. Should not exist. + chunks + compression : str or int + Name or identifier of HDF5 compression filter. + compression_opts + See HDF5 documentation for compression filters. + Compression options for the dataset. 
""" - import h5py if not h5py.get_config().mpi: @@ -730,39 +793,35 @@ def to_hdf5( ) mode = "a" if create else "r+" - fh = misc.open_h5py_mpi(f, mode, self.comm) - start = self.local_offset[self.axis] - end = start + self.local_shape[self.axis] - - # Construct slices for axis - sel = ([slice(None, None)] * self.axis) + [slice(start, end)] - sel = _expand_sel(sel, self.ndim) - # Check that there are no null slices, otherwise we need to turn off # collective IO to work around an h5py issue (#965) no_null_slices = self.global_shape[self.axis] >= self.comm.size - # Split the axis to get the IO size under ~2GB (only if MPI-IO) - split_axis, partitions = self._partition_io(skip=(not fh.is_mpi)) - # Only use collective IO if: # - there are no null slices (h5py bug) # - we are not distributed over axis=0 as there is no advantage for # collective IO which is usually slow # - unless we want to use compression/chunking # TODO: change if h5py bug fixed + # https://github.com/h5py/h5py/issues/965 # TODO: better would be a test on contiguous IO size use_collective = ( fh.is_mpi and no_null_slices and (self.axis > 0 or compression is not None) ) - if fh.is_mpi and not use_collective: + if fh.is_mpi and (not use_collective): # Need to disable compression if we can't use collective IO + logger.error("Cannot use collective IO, disabling compression") chunks, compression, compression_opts = None, None, None - dset = fh.create_dataset( + sel = self._make_selections() + + # Split the axis to get the IO size under ~2GB (only if MPI-IO) + split_axis, partitions = self._partition_io(skip=(not fh.is_mpi)) + + fh.create_dataset( dataset, shape=self.global_shape, dtype=self.dtype, @@ -772,18 +831,161 @@ def to_hdf5( ) # Read using collective MPI-IO if specified - with dset.collective if use_collective else DummyContext(): + with fh[dataset].collective if use_collective else DummyContext(): # Loop over partitions of the IO and perform them for part in partitions: islice, fslice = _partition_sel( sel, split_axis, self.global_shape[split_axis], part ) - dset[islice] = self[fslice] + fh[dataset][islice] = self[fslice] if fh.opened: fh.close() + def to_zarr( + self, + f, + dataset, + create, + chunks, + compression, + compression_opts, + ): + """Parallel write into a contiguous Zarr dataset. + + Parameters + ---------- + f : str zarr.Group + File to write dataset into. + dataset : string + Name of dataset to write into. Should not exist. + chunks + compression : str or int + Name or identifier of HDF5 compression filter. + compression_opts + See HDF5 documentation for compression filters. + Compression options for the dataset. + """ + try: + import zarr + import numcodecs + except ImportError as err: + raise RuntimeError( + f"Can't write to zarr file. Please install zarr and numcodecs: {err}" + ) + + # Blosc may share incorrect global state amongst processes causing programs to hang. 
+ # See https://zarr.readthedocs.io/en/stable/tutorial.html#parallel-computing-and-synchronization + numcodecs.blosc.use_threads = False + + mode = "a" if create else "r+" + extra_args = fileformats.Zarr.compression_kwargs( + compression=compression, + compression_opts=compression_opts, + ) + + lockfile = None + + if isinstance(f, str): + if self.comm.rank == 0 and create: + zarr.open(store=f, mode=mode) + lockfile = f".{f}.sync" + self.comm.Barrier() + group = zarr.open_group( + store=f, + mode="r+", + synchronizer=zarr.ProcessSynchronizer(lockfile), + ) + elif isinstance(f, zarr.Group): + if f.synchronizer is None: + raise ValueError( + "Got zarr.Group without synchronizer, can't perform parallel write." + ) + group = f + else: + raise ValueError( + f"Can't write to {f} (Expected a zarr.Group or str filename)." + ) + + sel = self._make_selections() + + # Split the axis + split_axis, partitions = self._partition_io(skip=True) + + if self.comm.rank == 0: + group.create_dataset( + dataset, + shape=self.global_shape, + dtype=self.dtype, + chunks=chunks, + **extra_args, + ) + self.comm.Barrier() + + for part in partitions: + islice, fslice = _partition_sel( + sel, split_axis, self.global_shape[split_axis], part + ) + group[dataset][islice] = self.local_array[fslice] + self.comm.Barrier() + if self.comm.rank == 0 and lockfile is not None: + fileformats.remove_file_or_dir(lockfile) + + def to_file( + self, + f, + dataset, + create=False, + chunks=None, + compression=None, + compression_opts=None, + file_format=fileformats.HDF5, + ): + """Parallel write into a contiguous HDF5/Zarr dataset. + + Parameters + ---------- + f : str, h5py.File, h5py.Group or zarr.Group + File to write dataset into. + dataset : string + Name of dataset to write into. Should not exist. + chunks + compression : str or int + Name or identifier of HDF5 compression filter. + compression_opts + See HDF5 documentation for compression filters. + Compression options for the dataset. + """ + if chunks is None and hasattr(self, "chunks"): + logger.error(f"getting chunking opts from mpiarray: {self.chunks}") + chunks = self.chunks + if compression is None and hasattr(self, "compression"): + logger.error(f"getting compression opts from mpiarray: {self.compression}") + + compression = self.compression + if compression_opts is None and hasattr(self, "compression_opts"): + logger.error( + f"getting compression_opts opts from mpiarray: {self.compression_opts}" + ) + + compression_opts = self.compression_opts + if file_format == fileformats.HDF5: + self.to_hdf5(f, dataset, create, chunks, compression, compression_opts) + elif file_format == fileformats.Zarr: + self.to_zarr(f, dataset, create, chunks, compression, compression_opts) + else: + raise ValueError(f"Unknown file format: {file_format}") + + def _make_selections(self): + """Make selections for writing local data to distributed file.""" + start = self.local_offset[self.axis] + end = start + self.local_shape[self.axis] + + # Construct slices for axis + sel = ([slice(None, None)] * self.axis) + [slice(start, end)] + return _expand_sel(sel, self.ndim) + def transpose(self, *axes): """Transpose the array axes. diff --git a/caput/pipeline.py b/caput/pipeline.py index 259e6c4e..5767e849 100644 --- a/caput/pipeline.py +++ b/caput/pipeline.py @@ -351,7 +351,7 @@ import yaml -from . import config, misc +from . import config, fileformats, misc # Set the module logger. 
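
The `config.file_format` property added in `caput/config.py` backs the new
task options in the next hunk: it maps the YAML strings "hdf5"/"zarr" onto
the corresponding `fileformats` classes. A minimal sketch, assuming the usual
`config.Reader` machinery (the `Output` class here is hypothetical)::

    from caput import config, fileformats

    class Output(config.Reader):
        output_format = config.file_format(default="hdf5")

    out = Output.from_config({"output_format": "zarr"})
    assert out.output_format is fileformats.Zarr
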
@@ -1158,6 +1158,9 @@ class _OneAndOne(TaskBase): input_root = config.Property(default="None", proptype=str) output_root = config.Property(default="None", proptype=str) + output_format = config.file_format() + output_compression = config.Property(default=None, proptype=str) + output_compression_opts = config.Property(default=None) def __init__(self): # Inspect the `process` method to see how many arguments it takes. @@ -1263,7 +1266,13 @@ def read_process_write(self, input, input_filename, output_filename): output_dirname = os.path.dirname(output_filename) if not os.path.isdir(output_dirname): os.makedirs(output_dirname) - self.write_output(output_filename, output) + self.write_output( + output_filename, + output, + file_format=self.output_format, + compression=self.output_compression, + compression_opts=self.output_compression_opts, + ) return output def read_input(self, filename): @@ -1285,7 +1294,8 @@ def read_output(self, filename): raise NotImplementedError() - def write_output(self, filename, output): + @staticmethod + def write_output(filename, output, file_format=None, **kwargs): """Override to implement reading inputs from disk.""" raise NotImplementedError() @@ -1330,6 +1340,9 @@ class SingleBase(_OneAndOne): input_filename = config.Property(default="", proptype=str) output_filename = config.Property(default="", proptype=str) + output_format = config.file_format() + output_compression = config.Property(default=None, proptype=str) + output_compression_opts = config.Property(default=None) def next(self, input=None): """Should not need to override.""" @@ -1423,7 +1436,7 @@ def next(self, input=None): class H5IOMixin: - """Provides hdf5 IO for pipeline tasks. + """Provides hdf5/zarr IO for pipeline tasks. As a mixin, this must be combined (using multiple inheritance) with a subclass of `TaskBase`, providing the full task API. @@ -1436,14 +1449,16 @@ class H5IOMixin: # TODO, implement reading on disk (i.e. no copy to memory). # ondisk = config.Property(default=False, proptype=bool) - def read_input(self, filename): + @staticmethod + def read_input(filename): """Method for reading hdf5 input.""" from caput import memh5 return memh5.MemGroup.from_hdf5(filename, mode="r") - def read_output(self, filename): + @staticmethod + def read_output(filename): """Method for reading hdf5 output (from caches).""" # Replicate code from read_input in case read_input is overridden. @@ -1451,19 +1466,34 @@ def read_output(self, filename): return memh5.MemGroup.from_hdf5(filename, mode="r") - def write_output(self, filename, output): - """Method for writing hdf5 output. - - `output` to be written must be either a `memh5.MemGroup` or an - `h5py.Group` (which include `hdf5.File` objects). In the latter case - the buffer is flushed if `filename` points to the same file and a copy - is made otherwise. + @staticmethod + def write_output(filename, output, file_format=None, **kwargs): + """ + Method for writing hdf5/zarr output. + Parameters + ---------- + filename : str + File name + output : memh5.Group, zarr.Group or h5py.Group + `output` to be written. If this is a `h5py.Group` (which include `hdf5.File` objects) + the buffer is flushed if `filename` points to the same file and a copy is made otherwise. + file_format : fileformats.Zarr, fileformats.HDF5 or None + File format to use. If this is not specified, the file format is guessed based on the type of + `output` or the `filename`. If guessing is not successful, HDF5 is used. 
""" from caput import memh5 import h5py + file_format = fileformats.check_file_format(filename, file_format, output) + + try: + import zarr + except ImportError: + if file_format == fileformats.Zarr: + raise RuntimeError("Can't write to zarr file. Please install zarr.") + # Ensure parent directory is present. dirname = os.path.dirname(filename) if not os.path.isdir(dirname): @@ -1479,14 +1509,16 @@ def write_output(self, filename, output): # Lock file with misc.lock_file(filename, comm=output.comm) as fn: - output.to_hdf5(fn, mode="w") + output.to_file(fn, mode="w", file_format=file_format, **kwargs) + return - elif isinstance(output, h5py.Group): + if isinstance(output, h5py.Group): if os.path.isfile(filename) and os.path.samefile( output.file.filename, filename ): # `output` already lives in this file. output.flush() + else: # Copy to memory then to disk # XXX This can be made much more efficient using a direct copy. @@ -1495,6 +1527,24 @@ def write_output(self, filename, output): # Lock file as we write with misc.lock_file(filename, comm=out_copy.comm) as fn: out_copy.to_hdf5(fn, mode="w") + elif isinstance(output, zarr.Group): + if os.path.isdir(filename) and os.path.samefile( + output.store.path, filename + ): + pass + else: + logger.debug(f"Copying {output.store}:{output.path} to {filename}.") + from . import mpiutil + + if mpiutil.rank == 0: + n_copied, n_skipped, n_bytes_copied = zarr.copy_store( + output.store, + zarr.DirectoryStore(filename), + source_path=output.path, + ) + logger.debug( + f"Copied {n_copied} items ({n_bytes_copied} bytes), skipped {n_skipped} items." + ) class BasicContMixin: @@ -1534,14 +1584,29 @@ def read_output(self, filename): filename, distributed=self._distributed, comm=self._comm ) - def write_output(self, filename, output): - """Method for writing hdf5 output. + @staticmethod + def write_output(filename, output, file_format=None, **kwargs): + """ + Method for writing output to disk. + + Parameters + ---------- + filename : str + File name. + output : :class:`memh5.BasicCont` + Data to be written. + file_format : `fileformats.FileFormat` + File format to use. Default `fileformats.HDF5`. + + Returns + ------- - `output` to be written must be either a :class:`memh5.BasicCont` object. """ from caput import memh5 + file_format = fileformats.check_file_format(filename, file_format, output) + # Ensure parent directory is present. dirname = os.path.dirname(filename) if dirname != "" and not os.path.isdir(dirname): @@ -1558,7 +1623,7 @@ def write_output(self, filename, output): ) # Already in memory. - output.save(filename) + output.save(filename, file_format=file_format, **kwargs) class SingleH5Base(H5IOMixin, SingleBase): diff --git a/caput/tests/conftest.py b/caput/tests/conftest.py index 89e57bb9..ef536433 100644 --- a/caput/tests/conftest.py +++ b/caput/tests/conftest.py @@ -1,13 +1,14 @@ """Pytest fixtures and simple tasks that can be used by all unit tests.""" +import glob import tempfile import numpy as np import pytest -from caput.pipeline import PipelineStopIteration, TaskBase, IterBase -from caput.scripts.runner import cli -from caput import config +from ..pipeline import PipelineStopIteration, TaskBase, IterBase +from ..scripts.runner import cli +from .. 
import config, fileformats, mpiutil @pytest.fixture(scope="session") @@ -89,7 +90,7 @@ def read_input(self, filename): def read_output(self, filename): raise NotImplementedError() - def write_output(self, filename, output): + def write_output(self, filename, output, file_format=None, **kwargs): raise NotImplementedError() @@ -129,6 +130,7 @@ def run_pipeline(parameters=None, configstr=eggs_pipeline_conf): Holds the captured result. Try accessing e.g. `result.exit_code`, `result.output`. """ + with tempfile.NamedTemporaryFile("w+") as configfile: configfile.write(configstr) configfile.flush() @@ -139,3 +141,44 @@ def run_pipeline(parameters=None, configstr=eggs_pipeline_conf): return runner.invoke(cli, ["run", configfile.name]) else: return runner.invoke(cli, ["run", *parameters, configfile.name]) + + +@pytest.fixture +def h5_file(): + """Provides a file name and removes all files/dirs with the same prefix later.""" + fname = "tmp_test_memh5.h5" + yield fname + rm_all_files(fname) + + +@pytest.fixture +def zarr_file(): + """Provides a directory name and removes all files/dirs with the same prefix later.""" + fname = "tmp_test_memh5.zarr" + yield fname + rm_all_files(fname) + + +@pytest.fixture +def h5_file_distributed(): + """Provides a file name and removes all files/dirs with the same prefix later.""" + fname = "tmp_test_memh5_distributed.h5" + yield fname + if mpiutil.rank == 0: + rm_all_files(fname) + + +@pytest.fixture +def zarr_file_distributed(): + """Provides a directory name and removes all files/dirs with the same prefix later.""" + fname = "tmp_test_memh5.zarr" + yield fname + if mpiutil.rank == 0: + rm_all_files(fname) + + +def rm_all_files(file_name): + """Remove all files and directories starting with `file_name`.""" + file_names = glob.glob(file_name + "*") + for fname in file_names: + fileformats.remove_file_or_dir(fname) diff --git a/caput/tests/test_lint.py b/caput/tests/test_lint.py index d9e7edf0..0fc27768 100644 --- a/caput/tests/test_lint.py +++ b/caput/tests/test_lint.py @@ -46,10 +46,10 @@ def simple_config(): def write_to_file(config_json): - temp = tempfile.NamedTemporaryFile(mode="w+t", delete=False) - yaml.safe_dump(config_json, temp, encoding="utf-8") - temp.flush() - return temp.name + with tempfile.NamedTemporaryFile(mode="w+t", delete=False) as temp: + yaml.safe_dump(config_json, temp, encoding="utf-8") + temp.flush() + return temp.name def test_load_yaml(simple_config): diff --git a/caput/tests/test_memh5.py b/caput/tests/test_memh5.py index 48110487..a329fd64 100644 --- a/caput/tests/test_memh5.py +++ b/caput/tests/test_memh5.py @@ -1,357 +1,511 @@ """Unit tests for the memh5 module.""" - import datetime -import unittest -import os -import glob import gc import json import warnings -import numpy as np import h5py +import numpy as np +import pytest +from pytest_lazyfixture import lazy_fixture +import zarr + +from caput import memh5, fileformats + + +def test_ro_dict(): + """Test memh5.ro_dict.""" + a = {"a": 5} + a = memh5.ro_dict(a) + assert a["a"] == 5 + assert list(a.keys()) == ["a"] + # Convoluded test to make sure you can't write to it. + with pytest.raises(TypeError): + # pylint: disable=unsupported-assignment-operation + a["b"] = 6 + + +# Unit tests for MemGroup. 
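# A hedged sketch (not part of this patch) of the parametrization pattern the
# rewritten tests below rely on: pytest-lazy-fixture resolves the h5_file and
# zarr_file fixtures from conftest.py so that a single test body runs against
# both the HDF5 and Zarr backends.  The test name and dataset contents are
# illustrative only; the imports at the top of this module are assumed.
@pytest.mark.parametrize(
    "test_file,file_format",
    [
        (lazy_fixture("h5_file"), fileformats.HDF5),
        (lazy_fixture("zarr_file"), fileformats.Zarr),
    ],
)
def test_roundtrip_sketch(test_file, file_format):
    """Illustrative only: round trip a MemGroup through either backend."""
    g = memh5.MemGroup()
    g.create_dataset("data", data=np.arange(10))
    g.to_file(test_file, file_format=file_format)
    g2 = memh5.MemGroup.from_file(test_file, file_format=file_format)
    assert (g2["data"][:] == g["data"][:]).all()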
+ + +def test_memgroup_nested(): + """Test nested groups in MemGroup.""" + root = memh5.MemGroup() + l1 = root.create_group("level1") + l2 = l1.require_group("level2") + assert root["level1"] == l1 + assert root["level1/level2"] == l2 + assert root["level1/level2"].name == "/level1/level2" + + +def test_memgroup_create_dataset(): + """Test creating datasets in MemGroup.""" + g = memh5.MemGroup() + data = np.arange(100, dtype=np.float32) + g.create_dataset("data", data=data) + assert np.allclose(data, g["data"]) + + +def test_memgroup_recursive_create(): + """Test creating nested groups at once in MemGroup.""" + g = memh5.MemGroup() + with pytest.raises(ValueError): + g.create_group("") + g2 = g.create_group("level2/") + with pytest.raises(ValueError): + g2.create_group("/") + g2.create_group("/level22") + assert set(g.keys()) == {"level22", "level2"} + g.create_group("/a/b/c/d/") + gd = g["/a/b/c/d/"] + assert gd.name == "/a/b/c/d" + + +def test_memgroup_recursive_create_dataset(): + """Test creating nested datasets in MemGroup.""" + g = memh5.MemGroup() + data = np.arange(10) + g.create_dataset("a/ra", data=data) + assert memh5.is_group(g["a"]) + assert np.all(g["a/ra"][:] == data) + g["a"].create_dataset("/ra", data=data) + assert np.all(g["ra"][:] == data) + assert isinstance(g["a/ra"].parent, memh5.MemGroup) + + # Check that d keeps g in scope. + d = g["a/ra"] + del g + gc.collect() + assert np.all(d.file["ra"][:] == data) + + +def fill_test_file(f): + """Fill a file with some groups, datasets and attrs for testing.""" + l1 = f.create_group("level1") + l1.create_group("level2") + d1 = l1.create_dataset("large", data=np.arange(100)) + f.attrs["a"] = 5 + d1.attrs["b"] = 6 + + +@pytest.fixture +def filled_h5_file(h5_file): + """Provides an H5 file with some content.""" + with h5py.File(h5_file, "w") as f: + fill_test_file(f) + f["level1"]["level2"].attrs["small"] = np.arange(3) + f["level1"]["level2"].attrs["ndarray"] = np.ndarray([1, 2, 3]) + yield h5_file + + +@pytest.fixture +def filled_zarr_file(zarr_file): + """Provides an H5 file with some content.""" + with zarr.open_group(zarr_file, "w") as f: + fill_test_file(f) + f["level1"]["level2"].attrs["small"] = [0, 1, 2] + yield zarr_file + + +def assertGroupsEqual(a, b): + """Compare two groups.""" + assert list(a.keys()) == list(b.keys()) + assertAttrsEqual(a.attrs, b.attrs) + for key in a.keys(): + this_a = a[key] + this_b = b[key] + if not memh5.is_group(a[key]): + assertAttrsEqual(this_a.attrs, this_b.attrs) + assert np.allclose(this_a, this_b) + else: + assertGroupsEqual(this_a, this_b) -from caput import memh5 - - -class TestRODict(unittest.TestCase): - """Unit tests for ro_dict.""" - def test_everything(self): - a = {"a": 5} - a = memh5.ro_dict(a) - self.assertEqual(a["a"], 5) - self.assertEqual(list(a.keys()), ["a"]) - # Convoluded test to make sure you can't write to it. 
- try: - a["b"] = 6 - except TypeError: - correct = True +def assertAttrsEqual(a, b): + """Compare two attributes.""" + assert list(a.keys()) == list(b.keys()) + for key in a.keys(): + this_a = a[key] + this_b = b[key] + if hasattr(this_a, "shape"): + assert np.allclose(this_a, this_b) else: - correct = False - self.assertTrue(correct) - - -class TestGroup(unittest.TestCase): - """Unit tests for MemGroup.""" - - def test_nested(self): - root = memh5.MemGroup() - l1 = root.create_group("level1") - l2 = l1.require_group("level2") - self.assertTrue(root["level1"] == l1) - self.assertTrue(root["level1/level2"] == l2) - self.assertEqual(root["level1/level2"].name, "/level1/level2") - - def test_create_dataset(self): - g = memh5.MemGroup() - data = np.arange(100, dtype=np.float32) - g.create_dataset("data", data=data) - self.assertTrue(np.allclose(data, g["data"])) - - def test_recursive_create(self): - g = memh5.MemGroup() - self.assertRaises(ValueError, g.create_group, "") - g2 = g.create_group("level2/") - self.assertRaises(ValueError, g2.create_group, "/") - g2.create_group("/level22") - self.assertEqual(set(g.keys()), {"level22", "level2"}) - g.create_group("/a/b/c/d/") - gd = g["/a/b/c/d/"] - self.assertEqual(gd.name, "/a/b/c/d") - - def test_recursive_create_dataset(self): - g = memh5.MemGroup() - data = np.arange(10) - g.create_dataset("a/ra", data=data) - self.assertTrue(memh5.is_group(g["a"])) - self.assertTrue(np.all(g["a/ra"][:] == data)) - g["a"].create_dataset("/ra", data=data) - self.assertTrue(np.all(g["ra"][:] == data)) - self.assertIsInstance(g["a/ra"].parent, memh5.MemGroup) - - # Check that d keeps g in scope. - d = g["a/ra"] - del g - gc.collect() - self.assertTrue(np.all(d.file["ra"][:] == data)) - - -class TestH5Files(unittest.TestCase): - """Tests that make hdf5 objects, convert to mem and back.""" - - fname = "tmp_test_memh5.h5" - - def setUp(self): - with h5py.File(self.fname, "w") as f: - l1 = f.create_group("level1") - l2 = l1.create_group("level2") - d1 = l1.create_dataset("large", data=np.arange(100)) - f.attrs["a"] = 5 - d1.attrs["b"] = 6 - l2.attrs["small"] = np.arange(3) - - def assertGroupsEqual(self, a, b): - self.assertEqual(list(a.keys()), list(b.keys())) - self.assertAttrsEqual(a.attrs, b.attrs) - for key in a.keys(): - this_a = a[key] - this_b = b[key] - if not memh5.is_group(a[key]): - self.assertAttrsEqual(this_a.attrs, this_b.attrs) - self.assertTrue(np.allclose(this_a, this_b)) - else: - self.assertGroupsEqual(this_a, this_b) - - def assertAttrsEqual(self, a, b): - self.assertEqual(list(a.keys()), list(b.keys())) - for key in a.keys(): - this_a = a[key] - this_b = b[key] - if hasattr(this_a, "shape"): - self.assertTrue(np.allclose(this_a, this_b)) + assert this_a == this_b + + +@pytest.mark.parametrize( + "test_file,file_open_function", + [ + (lazy_fixture("filled_h5_file"), h5py.File), + (lazy_fixture("filled_zarr_file"), zarr.open_group), + ], +) +def test_file_sanity(test_file, file_open_function): + """Compare a file with itself.""" + with file_open_function(test_file, "r") as f: + assertGroupsEqual(f, f) + + +@pytest.mark.parametrize( + "test_file,file_open_function,file_format", + [ + (lazy_fixture("filled_h5_file"), h5py.File, fileformats.HDF5), + (lazy_fixture("filled_zarr_file"), zarr.open_group, fileformats.Zarr), + ], +) +def test_to_from_file(test_file, file_open_function, file_format): + """Tests that makes hdf5 objects, convert to mem and back.""" + m = memh5.MemGroup.from_file(test_file, file_format=file_format) + + # Check that read in 
file has same structure + with file_open_function(test_file, "r") as f: + assertGroupsEqual(f, m) + + m.to_file( + test_file + ".new", + file_format=file_format, + ) + + # Check that written file has same structure + with file_open_function(test_file + ".new", "r") as f: + assertGroupsEqual(f, m) + + +@pytest.mark.parametrize( + "test_file,file_format", + [ + (lazy_fixture("filled_h5_file"), fileformats.HDF5), + (lazy_fixture("filled_zarr_file"), fileformats.Zarr), + ], +) +def test_memdisk(test_file, file_format): + """Test MemDiskGroup.""" + f = memh5.MemDiskGroup(test_file, file_format=file_format) + assert set(f.keys()) == set(f._data.keys()) + m = memh5.MemDiskGroup(memh5.MemGroup.from_file(test_file, file_format=file_format)) + assert set(m.keys()) == set(f.keys()) + # Recursive indexing. + assert set(f["/level1/"].keys()) == set(m["/level1/"].keys()) + assert set(f.keys()) == set(m["/level1"]["/"].keys()) + assert np.all(f["/level1/large"][:] == m["/level1/large"]) + gf = f.create_group("/level1/level2/level3/") + gf.create_dataset("new", data=np.arange(5)) + gm = m.create_group("/level1/level2/level3/") + gm.create_dataset("new", data=np.arange(5)) + assert np.all( + f["/level1/level2/level3/new"][:] == m["/level1/level2/level3/new"][:] + ) + + +@pytest.mark.parametrize( + "compression,compression_opts,chunks", + [(None, None, None), ("bitshuffle", (None, "lz4"), (2, 3))], +) +@pytest.mark.parametrize( + "test_file,file_format", + [ + (lazy_fixture("filled_h5_file"), fileformats.HDF5), + (lazy_fixture("filled_zarr_file"), fileformats.Zarr), + ], +) +def test_compression(test_file, file_format, compression, compression_opts, chunks): + # add a new compressed dataset + f = memh5.MemDiskGroup.from_file(test_file, file_format=file_format) + rng = np.random.default_rng(12345) + f.create_dataset( + "new", + data=rng.random((5, 7)), + chunks=chunks, + compression=compression, + compression_opts=compression_opts, + ) + # f.flush() + f.save( + test_file + ".cmp", + convert_attribute_strings=True, + convert_dataset_strings=True, + file_format=file_format, + ) + # f.close() + + # read back compression parameters from file + with file_format.open(test_file + ".cmp") as fh: + if file_format is fileformats.HDF5: + if compression is not None: + # for some reason .compression doesn't get set... + assert str(fileformats.H5FILTER) in fh["new"]._filters + assert fh["new"].chunks == chunks + else: + if compression is None: + assert fh["new"].compressor is None + assert fh["new"].chunks == fh["new"].shape else: - self.assertEqual(this_a, this_b) - - def test_h5_sanity(self): - with h5py.File(self.fname, "r") as f: - self.assertGroupsEqual(f, f) - - def test_to_from_hdf5(self): - m = memh5.MemGroup.from_hdf5(self.fname) - - # Check that read in file has same structure - with h5py.File(self.fname, "r") as f: - self.assertGroupsEqual(f, m) - - m.to_hdf5(self.fname + ".new") - - # Check that written file has same structure - with h5py.File(self.fname + ".new", "r") as f: - self.assertGroupsEqual(f, m) - - def test_memdisk(self): - f = memh5.MemDiskGroup(self.fname) - self.assertEqual(set(f.keys()), set(f._data.keys())) - m = memh5.MemDiskGroup(memh5.MemGroup.from_hdf5(self.fname)) - self.assertEqual(set(m.keys()), set(f.keys())) - # Recursive indexing. 
- self.assertEqual(set(f["/level1/"].keys()), set(m["/level1/"].keys())) - self.assertEqual(set(f.keys()), set(m["/level1"]["/"].keys())) - self.assertTrue(np.all(f["/level1/large"][:] == m["/level1/large"])) - gf = f.create_group("/level1/level2/level3/") - gf.create_dataset("new", data=np.arange(5)) - gm = m.create_group("/level1/level2/level3/") - gm.create_dataset("new", data=np.arange(5)) - self.assertTrue( - np.all( - f["/level1/level2/level3/new"][:] == m["/level1/level2/level3/new"][:] - ) - ) - - def tearDown(self): - file_names = glob.glob(self.fname + "*") - for fname in file_names: - os.remove(fname) + assert fh["new"].compressor is not None + assert fh["new"].chunks == chunks class TempSubClass(memh5.MemDiskGroup): - pass - - -class TestMemDiskGroup(unittest.TestCase): - - fname = "temp_mdg.h5" + """A subclass of MemDiskGroup for testing.""" - def test_io(self): - - # Save a subclass of MemDiskGroup - tsc = TempSubClass() - tsc.create_dataset("dset", data=np.arange(10)) - tsc.save(self.fname) - - # Load it from disk - tsc2 = memh5.MemDiskGroup.from_file(self.fname) - tsc3 = memh5.MemDiskGroup.from_file(self.fname, ondisk=True) - - # Check that is is recreated with the correct type - self.assertIsInstance(tsc2, TempSubClass) - self.assertIsInstance(tsc3, TempSubClass) + pass - # Check that parent/etc is properly implemented. - # Turns out this is very hard so give up for now. - # self.assertIsInstance(tsc2['dset'].parent, TempSubClass) - # self.assertIsInstance(tsc3['dset'].parent, TempSubClass) - tsc3.close() - with memh5.MemDiskGroup.from_file(self.fname, mode="r", ondisk=True): - self.assertRaises(IOError, h5py.File, self.fname, "w") +@pytest.mark.parametrize( + "test_file,file_format", + [ + (lazy_fixture("h5_file"), fileformats.HDF5), + (lazy_fixture("zarr_file"), fileformats.Zarr), + ], +) +def test_io(test_file, file_format): + """Test I/O of MemDiskGroup.""" + # Save a subclass of MemDiskGroup + tsc = TempSubClass() + tsc.create_dataset("dset", data=np.arange(10)) + tsc.save(test_file, file_format=file_format) + + # Load it from disk + tsc2 = memh5.MemDiskGroup.from_file(test_file, file_format=file_format) + tsc3 = memh5.MemDiskGroup.from_file(test_file, ondisk=True, file_format=file_format) + + # Check that is is recreated with the correct type + assert isinstance(tsc2, TempSubClass) + assert isinstance(tsc3, TempSubClass) + + # Check that parent/etc is properly implemented. + # Turns out this is very hard so give up for now. 
+ # self.assertIsInstance(tsc2['dset'].parent, TempSubClass) + # self.assertIsInstance(tsc3['dset'].parent, TempSubClass) + tsc3.close() + + with memh5.MemDiskGroup.from_file( + test_file, mode="r", ondisk=True, file_format=file_format + ): + # h5py will error if file already open + if file_format == fileformats.HDF5: + with pytest.raises(IOError): + file_format.open(test_file, "w") + # ...zarr will not + else: + file_format.open(test_file, "w") - with memh5.MemDiskGroup.from_file(self.fname, mode="r", ondisk=False): - f = h5py.File(self.fname, "w") + with memh5.MemDiskGroup.from_file( + test_file, mode="r", ondisk=False, file_format=file_format + ): + f = file_format.open(test_file, "w") + if file_format == fileformats.HDF5: f.close() - def tearDown(self): - file_names = glob.glob(self.fname + "*") - for fname in file_names: - os.remove(fname) - -class TestBasicCont(unittest.TestCase): - fname = "test_bc.h5" - history_dict = {"foo": {"bar": {"f": 23}, "foo": "bar"}, "bar": 0} +@pytest.fixture(name="history_dict") +def fixture_history_dict(): + """Provides dict with some content for testing.""" + return {"foo": {"bar": {"f": 23}, "foo": "bar"}, "bar": 0} + + +@pytest.fixture +def h5_basiccont_file(h5_file, history_dict): + """Provides a BasicCont file written to HDF5.""" + d = memh5.BasicCont() + d.create_dataset("a", data=np.arange(5)) + d.add_history("test", history_dict) + d.to_disk(h5_file) + yield h5_file, history_dict + + +@pytest.fixture +def zarr_basiccont_file(zarr_file, history_dict): + """Provides a BasicCont file written to Zarr.""" + d = memh5.BasicCont() + d.create_dataset("a", data=np.arange(5)) + d.add_history("test", history_dict) + d.to_disk(zarr_file, file_format=fileformats.Zarr) + yield zarr_file, history_dict + + +@pytest.mark.parametrize( + "test_file,file_format", + [ + (lazy_fixture("h5_basiccont_file"), fileformats.HDF5), + (lazy_fixture("zarr_basiccont_file"), fileformats.Zarr), + ], +) +def test_access(test_file, file_format): + """Test access to BasicCont content.""" + test_file = test_file[0] + d = memh5.BasicCont.from_file(test_file, file_format=file_format) + assert "history" in d._data + assert "index_map" in d._data + with pytest.raises(KeyError): + d.__getitem__("history") + with pytest.raises(KeyError): + d.__getitem__("index_map") + + with pytest.raises(ValueError): + d.create_group("a") + with pytest.raises(ValueError): + d.create_dataset("index_map/stuff", data=np.arange(5)) + + +@pytest.mark.parametrize( + "test_file,file_format", + [ + (lazy_fixture("h5_basiccont_file"), fileformats.HDF5), + (lazy_fixture("zarr_basiccont_file"), fileformats.Zarr), + ], +) +def test_history(test_file, file_format): + """Test history of BasicCont.""" + basic_cont, history_dict = test_file json_prefix = "!!_memh5_json:" - def setUp(self): - d = memh5.BasicCont() - d.create_dataset("a", data=np.arange(5)) - d.add_history("test", self.history_dict) - d.to_disk(self.fname) - - def test_access(self): - d = memh5.BasicCont.from_file(self.fname) - self.assertTrue("history" in d._data) - self.assertTrue("index_map" in d._data) - self.assertRaises(KeyError, d.__getitem__, "history") - self.assertRaises(KeyError, d.__getitem__, "index_map") - - self.assertRaises(ValueError, d.create_group, "a") - self.assertRaises( - ValueError, d.create_dataset, "index_map/stuff", data=np.arange(5) - ) - - def test_history(self): - # Check HDF5 file for config- and versiondump - with h5py.File(self.fname, "r") as f: - history = f["history"].attrs["test"] - assert history == self.json_prefix + 
json.dumps(self.history_dict) - - # add old format history - with h5py.File(self.fname, "r+") as f: - f["history"].create_group("old_history_format") - f["history/old_history_format"].attrs["foo"] = "bar" - - with memh5.BasicCont.from_file(self.fname) as m: - with warnings.catch_warnings(record=True) as w: - # Cause all warnings to always be triggered. - warnings.simplefilter("always") - old_history_format = m.history["old_history_format"] - - # Expect exactly one warning about deprecated history format - assert len(w) == 1 - assert issubclass(w[-1].category, DeprecationWarning) - assert "deprecated" in str(w[-1].message) - - assert old_history_format == {"foo": "bar"} - - -class TestUnicodeDataset(unittest.TestCase): + # Check file for config- and versiondump + with file_format.open(basic_cont, "r") as f: + history = f["history"].attrs["test"] + # if file_format == fileformats.HDF5: + assert history == json_prefix + json.dumps(history_dict) + # else: + # assert history == history_dict + + # add old format history + with file_format.open(basic_cont, "r+") as f: + f["history"].create_group("old_history_format") + f["history/old_history_format"].attrs["foo"] = "bar" + + with memh5.BasicCont.from_file(basic_cont, file_format=file_format) as m: + with warnings.catch_warnings(record=True) as w: + # Cause all warnings to always be triggered. + warnings.simplefilter("always") + old_history_format = m.history["old_history_format"] + + # Expect exactly one warning about deprecated history format + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "deprecated" in str(w[-1].message) + + assert old_history_format == {"foo": "bar"} + + +@pytest.mark.parametrize( + "test_file,file_format", + [ + (lazy_fixture("h5_file"), fileformats.HDF5), + (lazy_fixture("zarr_file"), fileformats.Zarr), + ], +) +def test_to_from__file_unicode(test_file, file_format): """Test that a unicode memh5 dataset is round tripped correctly.""" - - fname = "tmp_test_unicode.h5" - - def test_to_from_hdf5(self): - - udata = np.array(["Test", "this", "works"]) - sdata = udata.astype("S") - self.assertEqual(udata.dtype.kind, "U") - self.assertEqual(sdata.dtype.kind, "S") - - m = memh5.MemGroup() - udset = m.create_dataset("udata", data=udata) - sdset = m.create_dataset("sdata", data=sdata) - self.assertEqual(udset.dtype.kind, "U") - self.assertEqual(sdset.dtype.kind, "S") - - # Test a write without conversion. 
This should throw an exception - with self.assertRaises(TypeError): - m.to_hdf5(self.fname) - - # Write with conversion - m.to_hdf5( - self.fname, convert_attribute_strings=True, convert_dataset_strings=True - ) - - with h5py.File(self.fname, "r") as fh: - # pylint warns here that "Instance of 'Group' has no 'dtype' member" - # pylint: disable=E1101 - self.assertEqual(fh["udata"].dtype.kind, "S") - self.assertEqual(fh["sdata"].dtype.kind, "S") - - # Test a load without conversion, types should be bytestrings - m2 = memh5.MemGroup.from_hdf5(self.fname) - self.assertEqual(m2["udata"].dtype.kind, "S") - self.assertEqual(m2["sdata"].dtype.kind, "S") - # Check the dtype here, for some reason Python 2 thinks the arrays are equal - # and Python 3 does not even though both agree that the datatypes are different - self.assertTrue(m["udata"].dtype != m2["udata"].dtype) - self.assertTrue((m["sdata"].data == m2["sdata"].data).all()) - - # Test a load *with* conversion, types should be unicode - m3 = memh5.MemGroup.from_hdf5( - self.fname, convert_attribute_strings=True, convert_dataset_strings=True - ) - self.assertEqual(m3["udata"].dtype.kind, "U") - self.assertEqual(m3["sdata"].dtype.kind, "U") - self.assertTrue((m["udata"].data == m3["udata"].data).all()) - self.assertTrue((m["udata"].data == m3["sdata"].data).all()) - - def test_failure(self): - # Test that we fail when trying to write a non ASCII character - - udata = np.array(["\u03B2"]) - - m = memh5.MemGroup() - m.create_dataset("udata", data=udata) - - with self.assertRaises(TypeError): - m.to_hdf5(self.fname) - - def tearDown(self): - file_names = glob.glob(self.fname + "*") - for fname in file_names: - os.remove(fname) - - -class TestMapJSON(unittest.TestCase): + udata = np.array(["Test", "this", "works"]) + sdata = udata.astype("S") + assert udata.dtype.kind == "U" + assert sdata.dtype.kind == "S" + + m = memh5.MemGroup() + udset = m.create_dataset("udata", data=udata) + sdset = m.create_dataset("sdata", data=sdata) + assert udset.dtype.kind == "U" + assert sdset.dtype.kind == "S" + + # Test a write without conversion. 
This should throw an exception + with pytest.raises(TypeError): + m.to_file(test_file, file_format=file_format) + + # Write with conversion + m.to_file( + test_file, + convert_attribute_strings=True, + convert_dataset_strings=True, + file_format=file_format, + ) + + with file_format.open(test_file, "r") as fh: + # pylint warns here that "Instance of 'Group' has no 'dtype' member" + # pylint: disable=E1101 + assert fh["udata"].dtype.kind == "S" + assert fh["sdata"].dtype.kind == "S" + + # Test a load without conversion, types should be bytestrings + m2 = memh5.MemGroup.from_file(test_file, file_format=file_format) + assert m2["udata"].dtype.kind == "S" + assert m2["sdata"].dtype.kind == "S" + # Check the dtype here, for some reason Python 2 thinks the arrays are equal + # and Python 3 does not even though both agree that the datatypes are different + assert m["udata"].dtype != m2["udata"].dtype + assert (m["sdata"].data == m2["sdata"].data).all() + + # Test a load *with* conversion, types should be unicode + m3 = memh5.MemGroup.from_file( + test_file, + convert_attribute_strings=True, + convert_dataset_strings=True, + file_format=file_format, + ) + assert m3["udata"].dtype.kind == "U" + assert m3["sdata"].dtype.kind == "U" + assert (m["udata"].data == m3["udata"].data).all() + assert (m["udata"].data == m3["sdata"].data).all() + + +@pytest.mark.parametrize( + "test_file,file_format", + [ + (lazy_fixture("h5_file"), fileformats.HDF5), + (lazy_fixture("zarr_file"), fileformats.Zarr), + ], +) +def test_failure(test_file, file_format): + """Test that we fail when trying to write a non ASCII character.""" + udata = np.array(["\u03B2"]) + + m = memh5.MemGroup() + m.create_dataset("udata", data=udata) + + with pytest.raises(TypeError): + m.to_file(test_file, file_format=file_format) + + +@pytest.mark.parametrize( + "test_file,file_format", + [ + (lazy_fixture("h5_file"), fileformats.HDF5), + (lazy_fixture("zarr_file"), fileformats.Zarr), + ], +) +def test_to_from_hdf5(test_file, file_format): """Test that a memh5 dataset JSON serialization is done correctly.""" + json_prefix = "!!_memh5_json:" + data = {"foo": {"bar": [1, 2, 3], "fu": "1"}} + time = datetime.datetime.now() + + m = memh5.MemGroup() + m.attrs["data"] = data + m.attrs["datetime"] = {"datetime": time} + m.attrs["ndarray"] = np.ndarray([1, 2, 3]) + + m.to_file(test_file, file_format=file_format) + with file_format.open(test_file, "r") as f: + assert f.attrs["data"] == json_prefix + json.dumps(data) + assert f.attrs["datetime"] == json_prefix + json.dumps( + {"datetime": time.isoformat()} + ) - fname = "tmp_test_json.h5" - - def test_to_from_hdf5(self): - json_prefix = "!!_memh5_json:" - data = {"foo": {"bar": [1, 2, 3], "fu": "1"}} - time = datetime.datetime.now() - - m = memh5.MemGroup() - m.attrs["data"] = data - m.attrs["datetime"] = {"datetime": time} - - m.to_hdf5(self.fname) - with h5py.File(self.fname, "r") as f: - assert f.attrs["data"] == json_prefix + json.dumps(data) - assert f.attrs["datetime"] == json_prefix + json.dumps( - {"datetime": time.isoformat()} - ) - - m2 = memh5.MemGroup.from_hdf5(self.fname) - assert m2.attrs["data"] == data - assert m2.attrs["datetime"] == {"datetime": time.isoformat()} - - def test_failure(self): - """Test that we get a TypeError if we try to serialize something else""" - m = memh5.MemGroup() - m.attrs["non_serializable"] = {"datetime": self} - - with self.assertRaises(TypeError): - m.to_hdf5(self.fname) - - def tearDown(self): - file_names = glob.glob(self.fname + "*") - for fname in 
file_names: - os.remove(fname) - - -if __name__ == "__main__": - unittest.main() + m2 = memh5.MemGroup.from_file(test_file, file_format=file_format) + assert m2.attrs["data"] == data + assert m2.attrs["datetime"] == {"datetime": time.isoformat()} + + +@pytest.mark.parametrize( + "test_file,file_format", + [ + (lazy_fixture("h5_file"), fileformats.HDF5), + (lazy_fixture("zarr_file"), fileformats.Zarr), + ], +) +def test_json_failure(test_file, file_format): + """Test that we get a TypeError if we try to serialize something else.""" + m = memh5.MemGroup() + m.attrs["non_serializable"] = {"datetime": object} + + with pytest.raises(TypeError): + m.to_file(test_file, file_format=file_format) diff --git a/caput/tests/test_memh5_parallel.py b/caput/tests/test_memh5_parallel.py index 4fcffbec..abbbe3b5 100644 --- a/caput/tests/test_memh5_parallel.py +++ b/caput/tests/test_memh5_parallel.py @@ -1,201 +1,231 @@ """Unit tests for the parallel features of the memh5 module.""" - -import unittest -import os -import glob - +import pytest +from pytest_lazyfixture import lazy_fixture import numpy as np import h5py +import zarr -from caput import memh5, mpiarray, mpiutil +from caput import fileformats, memh5, mpiarray, mpiutil comm = mpiutil.world rank, size = mpiutil.rank, mpiutil.size -class TestMemGroupDistributed(unittest.TestCase): - """Unit tests for MemGroup.""" - - fname = "tmp_test_memh5_distributed.h5" - - def test_create_dataset(self): - - global_data = np.arange(size * 5 * 10, dtype=np.float32) - local_data = global_data.reshape(size, -1, 10)[rank] - d_array = mpiarray.MPIArray.wrap(local_data, axis=0) - d_array_T = d_array.redistribute(axis=1) - - # Check that we must specify in advance if the dataset is distributed - g = memh5.MemGroup() - if comm is not None: - self.assertRaises(RuntimeError, g.create_dataset, "data", data=d_array) - - g = memh5.MemGroup(distributed=True) - - # Create an array from data - g.create_dataset("data", data=d_array, distributed=True) - - # Create an array from data with a different distribution - g.create_dataset("data_T", data=d_array, distributed=True, distributed_axis=1) - - # Create an empty array with a specified shape - g.create_dataset( - "data2", - shape=(size * 5, 10), - dtype=np.float64, - distributed=True, - distributed_axis=1, - ) - self.assertTrue(np.allclose(d_array, g["data"][:])) - self.assertTrue(np.allclose(d_array_T, g["data_T"][:])) - if comm is not None: - self.assertEqual(d_array_T.local_shape, g["data2"].local_shape) - - # Test global indexing - self.assertTrue((g["data"][rank * 5] == local_data[0]).all()) - - def test_io(self): - - # Create distributed memh5 object - g = memh5.MemGroup(distributed=True) - g.attrs["rank"] = rank - - # Create an empty array with a specified shape - pdset = g.create_dataset( - "parallel_data", - shape=(size, 10), - dtype=np.float64, - distributed=True, - distributed_axis=0, - ) - pdset[:] = rank - pdset.attrs["const"] = 17 - - # Create an empty array with a specified shape - sdset = g.create_dataset("serial_data", shape=(size * 5, 10), dtype=np.float64) - sdset[:] = rank - sdset.attrs["const"] = 18 - - # Create nested groups - g.create_group("hello/world") - - # Test round tripping unicode data - g.create_dataset("unicode_data", data=np.array(["hello"])) - - g.to_hdf5( - self.fname, convert_attribute_strings=True, convert_dataset_strings=True - ) - - # Test that the HDF5 file has the correct structure - with h5py.File(self.fname, "r") as f: - - # Test that the file attributes are correct - 
self.assertTrue(f["parallel_data"].attrs["const"] == 17) - - # Test that the parallel dataset has been written correctly - self.assertTrue((f["parallel_data"][:, 0] == np.arange(size)).all()) - self.assertTrue(f["parallel_data"].attrs["const"] == 17) - - # Test that the common dataset has been written correctly (i.e. by rank=0) - self.assertTrue((f["serial_data"][:] == 0).all()) - self.assertTrue(f["serial_data"].attrs["const"] == 18) - - # Check group structure is correct - self.assertIn("hello", f) - self.assertIn("world", f["hello"]) - - # Test that the read in group has the same structure as the original - g2 = memh5.MemGroup.from_hdf5( - self.fname, - distributed=True, - convert_attribute_strings=True, - convert_dataset_strings=True, - ) - - # Check that the parallel data is still the same - self.assertTrue((g2["parallel_data"][:] == g["parallel_data"][:]).all()) - - # Check that the serial data is all zeros (should not be the same as before) - self.assertTrue((g2["serial_data"][:] == np.zeros_like(sdset[:])).all()) +def test_create_dataset(): + """Test for creating datasets in MemGroup.""" + global_data = np.arange(size * 5 * 10, dtype=np.float32) + local_data = global_data.reshape(size, -1, 10)[rank] + d_array = mpiarray.MPIArray.wrap(local_data, axis=0) + d_array_T = d_array.redistribute(axis=1) + + # Check that we must specify in advance if the dataset is distributed + g = memh5.MemGroup() + if comm is not None: + with pytest.raises(RuntimeError): + g.create_dataset("data", data=d_array) + + g = memh5.MemGroup(distributed=True) + + # Create an array from data + g.create_dataset("data", data=d_array, distributed=True) + + # Create an array from data with a different distribution + g.create_dataset("data_T", data=d_array, distributed=True, distributed_axis=1) + + # Create an empty array with a specified shape + g.create_dataset( + "data2", + shape=(size * 5, 10), + dtype=np.float64, + distributed=True, + distributed_axis=1, + ) + assert np.allclose(d_array, g["data"][:]) + assert np.allclose(d_array_T, g["data_T"][:]) + if comm is not None: + assert d_array_T.local_shape == g["data2"].local_shape + + # Test global indexing + assert (g["data"][rank * 5] == local_data[0]).all() + + +@pytest.mark.parametrize( + "compression,compression_opts,chunks", + [ + (None, None, None), + ("bitshuffle", (None, "lz4"), (size // 2 + ((size // 2) == 0), 3)), + ], +) +@pytest.mark.parametrize( + "test_file,file_open_function,file_format", + [ + (lazy_fixture("h5_file_distributed"), h5py.File, fileformats.HDF5), + ( + lazy_fixture("zarr_file_distributed"), + zarr.open_group, + fileformats.Zarr, + ), + ], +) +def test_io( + test_file, file_open_function, file_format, compression, compression_opts, chunks +): + """Test for I/O in MemGroup.""" + + # Create distributed memh5 object + g = memh5.MemGroup(distributed=True) + g.attrs["rank"] = rank + + # Create an empty array with a specified shape + pdset = g.create_dataset( + "parallel_data", + shape=(size, 10), + dtype=np.float64, + distributed=True, + distributed_axis=0, + compression=compression, + compression_opts=compression_opts, + chunks=chunks, + ) + pdset[:] = rank + pdset.attrs["const"] = 17 + + # Create an empty array with a specified shape + sdset = g.create_dataset("serial_data", shape=(size * 5, 10), dtype=np.float64) + sdset[:] = rank + sdset.attrs["const"] = 18 + + # Create nested groups + g.create_group("hello/world") + + # Test round tripping unicode data + g.create_dataset("unicode_data", data=np.array(["hello"])) + + g.to_file( + 
test_file, + convert_attribute_strings=True, + convert_dataset_strings=True, + file_format=file_format, + ) + + # Test that the HDF5 file has the correct structure + with file_open_function(test_file, "r") as f: + + # Test that the file attributes are correct + assert f["parallel_data"].attrs["const"] == 17 + + # Test that the parallel dataset has been written correctly + assert (f["parallel_data"][:, 0] == np.arange(size)).all() + assert f["parallel_data"].attrs["const"] == 17 + + # Test that the common dataset has been written correctly (i.e. by rank=0) + assert (f["serial_data"][:] == 0).all() + assert f["serial_data"].attrs["const"] == 18 # Check group structure is correct - self.assertIn("hello", g2) - self.assertIn("world", g2["hello"]) - - # Check the unicode dataset - self.assertEqual(g2["unicode_data"].dtype.kind, "U") - self.assertEqual(g2["unicode_data"][0], "hello") - - # Check the attributes - self.assertTrue(g2["parallel_data"].attrs["const"] == 17) - self.assertTrue(g2["serial_data"].attrs["const"] == 18) - - def tearDown(self): - if rank == 0: - file_names = glob.glob(self.fname + "*") - for fname in file_names: - os.remove(fname) - - -class TestMemDiskGroupDistributed(unittest.TestCase): - - fname = "tmp_parallel_dg.h5" - - def test_misc(self): - - dg = memh5.MemDiskGroup(distributed=True) - - pdset = dg.create_dataset( - "parallel_data", - shape=(10,), - dtype=np.float64, - distributed=True, - distributed_axis=0, - ) - # pdset[:] = dg._data.comm.rank - pdset[:] = rank - # Test successfully added - self.assertIn("parallel_data", dg) - - dg.save(self.fname) - - dg2 = memh5.MemDiskGroup.from_file(self.fname, distributed=True) - - # Test successful load - self.assertIn("parallel_data", dg2) - self.assertTrue((dg["parallel_data"][:] == dg2["parallel_data"][:]).all()) - - # self.assertRaises(NotImplementedError, dg.to_disk, self.fname) - - # Test refusal to base off a h5py object when distributed - with h5py.File(self.fname, "r") as f: - if comm is not None: - self.assertRaises( - ValueError, memh5.MemDiskGroup, data_group=f, distributed=True - ) - mpiutil.barrier() - - def tearDown(self): - - if rank == 0: - file_names = glob.glob(self.fname + "*") - for fname in file_names: - os.remove(fname) - - -class TestBasicCont(unittest.TestCase): - def test_redistribute(self): + assert "hello" in f + assert "world" in f["hello"] + + # Check compression/chunks + if file_format is fileformats.Zarr: + if chunks is None: + assert f["parallel_data"].chunks == f["parallel_data"].shape + assert f["parallel_data"].compressor is None + else: + assert f["parallel_data"].chunks == chunks + assert f["parallel_data"].compressor is not None + elif file_format is fileformats.HDF5: + # compression should be disabled + # (for some reason .compression is not set...) 
+ assert str(fileformats.H5FILTER) not in f["parallel_data"]._filters + assert f["parallel_data"].chunks is None + + # Test that the read in group has the same structure as the original + g2 = memh5.MemGroup.from_file( + test_file, + distributed=True, + convert_attribute_strings=True, + convert_dataset_strings=True, + file_format=file_format, + ) + + # Check that the parallel data is still the same + assert (g2["parallel_data"][:] == g["parallel_data"][:]).all() + + # Check that the serial data is all zeros (should not be the same as before) + assert (g2["serial_data"][:] == np.zeros_like(sdset[:])).all() + + # Check group structure is correct + assert "hello" in g2 + assert "world" in g2["hello"] + + # Check the unicode dataset + assert g2["unicode_data"].dtype.kind == "U" + assert g2["unicode_data"][0] == "hello" + + # Check the attributes + assert g2["parallel_data"].attrs["const"] == 17 + assert g2["serial_data"].attrs["const"] == 18 + + +@pytest.mark.parametrize( + "test_file,file_open_function,file_format", + [ + (lazy_fixture("h5_file_distributed"), h5py.File, fileformats.HDF5), + ( + lazy_fixture("zarr_file_distributed"), + zarr.open_group, + fileformats.Zarr, + ), + ], +) +def test_misc(test_file, file_open_function, file_format): + """Misc tests for MemDiskGroupDistributed""" + + dg = memh5.MemDiskGroup(distributed=True) + + pdset = dg.create_dataset( + "parallel_data", + shape=(10,), + dtype=np.float64, + distributed=True, + distributed_axis=0, + ) + # pdset[:] = dg._data.comm.rank + pdset[:] = rank + # Test successfully added + assert "parallel_data" in dg + + dg.save(test_file, file_format=file_format) + + dg2 = memh5.MemDiskGroup.from_file( + test_file, distributed=True, file_format=file_format + ) + + # Test successful load + assert "parallel_data" in dg2 + assert (dg["parallel_data"][:] == dg2["parallel_data"][:]).all() + + # self.assertRaises(NotImplementedError, dg.to_disk, self.fname) + + # Test refusal to base off a h5py object when distributed + with file_open_function(test_file, "r") as f: + if comm is not None: + with pytest.raises(ValueError): + # MemDiskGroup will guess the file format + memh5.MemDiskGroup(data_group=f, distributed=True) + mpiutil.barrier() - g = memh5.BasicCont(distributed=True) - # Create an array from data - g.create_dataset("data", shape=(10, 10), distributed=True, distributed_axis=0) - self.assertEqual(g["data"].distributed_axis, 0) - g.redistribute(1) - self.assertEqual(g["data"].distributed_axis, 1) +def test_redistribute(): + """Test redistribute in BasicCont.""" + g = memh5.BasicCont(distributed=True) -if __name__ == "__main__": - unittest.main() + # Create an array from data + g.create_dataset("data", shape=(10, 10), distributed=True, distributed_axis=0) + assert g["data"].distributed_axis == 0 + g.redistribute(1) + assert g["data"].distributed_axis == 1 diff --git a/caput/tests/test_mpiarray.py b/caput/tests/test_mpiarray.py index 50d1c886..97a7db22 100644 --- a/caput/tests/test_mpiarray.py +++ b/caput/tests/test_mpiarray.py @@ -5,394 +5,415 @@ $ mpirun -np 4 python test_mpiarray.py """ -import os -import unittest - +import pytest +from pytest_lazyfixture import lazy_fixture +import h5py import numpy as np +import zarr -from caput import mpiutil, mpiarray - - -class TestMPIArray(unittest.TestCase): - def test_construction(self): - - arr = mpiarray.MPIArray((10, 11), axis=1) - - l, s, _ = mpiutil.split_local(11) +from caput import mpiutil, mpiarray, fileformats - # Check that global shape is set correctly - assert arr.global_shape == (10, 
11) - assert arr.shape == (10, l) +def test_construction(): + """Test local/global shape construction of MPIArray.""" + arr = mpiarray.MPIArray((10, 11), axis=1) - assert arr.local_offset == (0, s) + l, s, _ = mpiutil.split_local(11) - assert arr.local_shape == (10, l) + # Check that global shape is set correctly + assert arr.global_shape == (10, 11) - def test_redistribution(self): + assert arr.shape == (10, l) - gshape = (1, 11, 2, 14, 3, 4) - nelem = np.prod(gshape) - garr = np.arange(nelem).reshape(gshape) + assert arr.local_offset == (0, s) - _, s0, e0 = mpiutil.split_local(11) - _, s1, e1 = mpiutil.split_local(14) - _, s2, e2 = mpiutil.split_local(4) + assert arr.local_shape == (10, l) - arr = mpiarray.MPIArray(gshape, axis=1, dtype=np.int64) - arr[:] = garr[:, s0:e0] - arr2 = arr.redistribute(axis=3) - assert (arr2 == garr[:, :, :, s1:e1]).view(np.ndarray).all() +def test_redistribution(): + """Test redistributing an MPIArray.""" + gshape = (1, 11, 2, 14, 3, 4) + nelem = np.prod(gshape) + garr = np.arange(nelem).reshape(gshape) - arr3 = arr.redistribute(axis=5) - assert (arr3 == garr[:, :, :, :, :, s2:e2]).view(np.ndarray).all() + _, s0, e0 = mpiutil.split_local(11) + _, s1, e1 = mpiutil.split_local(14) + _, s2, e2 = mpiutil.split_local(4) - def test_gather(self): + arr = mpiarray.MPIArray(gshape, axis=1, dtype=np.int64) + arr[:] = garr[:, s0:e0] - rank = mpiutil.rank - size = mpiutil.size - block = 2 + arr2 = arr.redistribute(axis=3) + assert (arr2 == garr[:, :, :, s1:e1]).view(np.ndarray).all() - global_shape = (2, 3, size * block) - global_array = np.zeros(global_shape, dtype=np.float64) - global_array[..., :] = np.arange(size * block) + arr3 = arr.redistribute(axis=5) + assert (arr3 == garr[:, :, :, :, :, s2:e2]).view(np.ndarray).all() - arr = mpiarray.MPIArray(global_shape, dtype=np.float64, axis=2) - arr[:] = global_array[..., (rank * block) : ((rank + 1) * block)] - assert (arr.allgather() == global_array).all() +def test_gather(): + """Test MPIArray.gather().""" + rank = mpiutil.rank + size = mpiutil.size + block = 2 - gather_rank = 1 if size > 1 else 0 - ga = arr.gather(rank=gather_rank) + global_shape = (2, 3, size * block) + global_array = np.zeros(global_shape, dtype=np.float64) + global_array[..., :] = np.arange(size * block) - if rank == gather_rank: - assert (ga == global_array).all() - else: - assert ga is None + arr = mpiarray.MPIArray(global_shape, dtype=np.float64, axis=2) + arr[:] = global_array[..., (rank * block) : ((rank + 1) * block)] - def test_wrap(self): + assert (arr.allgather() == global_array).all() - ds = mpiarray.MPIArray((10, 17)) + gather_rank = 1 if size > 1 else 0 + ga = arr.gather(rank=gather_rank) - df = np.fft.rfft(ds, axis=1) + if rank == gather_rank: + assert (ga == global_array).all() + else: + assert ga is None - assert isinstance(df, np.ndarray) - da = mpiarray.MPIArray.wrap(df, axis=0) +def test_wrap(): + """Test MPIArray.wrap().""" + ds = mpiarray.MPIArray((10, 17)) - assert isinstance(da, mpiarray.MPIArray) - assert da.global_shape == (10, 9) + df = np.fft.rfft(ds, axis=1) - l0, _, _ = mpiutil.split_local(10) + assert isinstance(df, np.ndarray) - assert da.local_shape == (l0, 9) + da = mpiarray.MPIArray.wrap(df, axis=0) - if mpiutil.rank0: - df = df[:-1] + assert isinstance(da, mpiarray.MPIArray) + assert da.global_shape == (10, 9) - if mpiutil.size > 1: - with self.assertRaises(Exception): - mpiarray.MPIArray.wrap(df, axis=0) + l0, _, _ = mpiutil.split_local(10) - def test_io(self): + assert da.local_shape == (l0, 9) - import h5py + if 
mpiutil.rank0: + df = df[:-1] - # Cleanup directories - fname = "testdset.hdf5" + if mpiutil.size > 1: + with pytest.raises(Exception): + mpiarray.MPIArray.wrap(df, axis=0) - if mpiutil.rank0 and os.path.exists(fname): - os.remove(fname) - mpiutil.barrier() +@pytest.mark.parametrize( + "filename, file_open_function, file_format", + [ + (lazy_fixture("h5_file_distributed"), h5py.File, fileformats.HDF5), + ( + lazy_fixture("zarr_file_distributed"), + zarr.open_group, + fileformats.Zarr, + ), + ], +) +def test_io(filename, file_open_function, file_format): + """Test I/O of MPIArray.""" + gshape = (19, 17) - gshape = (19, 17) + ds = mpiarray.MPIArray(gshape, dtype=np.int64) - ds = mpiarray.MPIArray(gshape, dtype=np.int64) + ga = np.arange(np.prod(gshape)).reshape(gshape) - ga = np.arange(np.prod(gshape)).reshape(gshape) + _, s0, e0 = mpiutil.split_local(gshape[0]) + ds[:] = ga[s0:e0] - _, s0, e0 = mpiutil.split_local(gshape[0]) - ds[:] = ga[s0:e0] + ds.redistribute(axis=1).to_file( + filename, "testds", create=True, file_format=file_format + ) - ds.redistribute(axis=1).to_hdf5(fname, "testds", create=True) + if mpiutil.rank0: - if mpiutil.rank0: + with file_open_function(filename, "r") as f: + h5ds = f["testds"][:] - with h5py.File(fname, "r") as f: + assert (h5ds == ga).all() - h5ds = f["testds"][:] + ds2 = mpiarray.MPIArray.from_file(filename, "testds", file_format=file_format) - assert (h5ds == ga).all() + assert (ds2 == ds).all() - ds2 = mpiarray.MPIArray.from_hdf5(fname, "testds") + mpiutil.barrier() - assert (ds2 == ds).all() + # Check that reading over another distributed axis works + ds3 = mpiarray.MPIArray.from_file( + filename, "testds", axis=1, file_format=file_format + ) + assert ds3.shape[0] == gshape[0] + assert ds3.shape[1] == mpiutil.split_local(gshape[1])[0] + ds3 = ds3.redistribute(axis=0) + assert (ds3 == ds).all() + mpiutil.barrier() - mpiutil.barrier() + # Check a read with an arbitrary slice in there. This only checks the shape is correct. + ds4 = mpiarray.MPIArray.from_file( + filename, + "testds", + axis=1, + sel=(np.s_[3:10:2], np.s_[1:16:3]), + file_format=file_format, + ) + assert ds4.shape[0] == 4 + assert ds4.shape[1] == mpiutil.split_local(5)[0] + mpiutil.barrier() - # Check that reading over another distributed axis works - ds3 = mpiarray.MPIArray.from_hdf5(fname, "testds", axis=1) - assert ds3.shape[0] == gshape[0] - assert ds3.shape[1] == mpiutil.split_local(gshape[1])[0] - ds3 = ds3.redistribute(axis=0) - assert (ds3 == ds).all() - mpiutil.barrier() + # Check the read with a slice along the axis being read + ds5 = mpiarray.MPIArray.from_file( + filename, + "testds", + axis=1, + sel=(np.s_[:], np.s_[3:15:2]), + file_format=file_format, + ) + assert ds5.shape[0] == gshape[0] + assert ds5.shape[1] == mpiutil.split_local(6)[0] + ds5 = ds5.redistribute(axis=0) + assert (ds5 == ds[:, 3:15:2]).all() + mpiutil.barrier() - # Check a read with an arbitrary slice in there. This only checks the shape is correct. 
- ds4 = mpiarray.MPIArray.from_hdf5( - fname, "testds", axis=1, sel=(np.s_[3:10:2], np.s_[1:16:3]) - ) - assert ds4.shape[0] == 4 - assert ds4.shape[1] == mpiutil.split_local(5)[0] - mpiutil.barrier() + # Check the read with a slice along the axis being read + ds6 = mpiarray.MPIArray.from_file( + filename, + "testds", + axis=0, + sel=(np.s_[:], np.s_[3:15:2]), + file_format=file_format, + ) + ds6 = ds6.redistribute(axis=0) + assert (ds6 == ds[:, 3:15:2]).all() + mpiutil.barrier() - # Check the read with a slice along the axis being read - ds5 = mpiarray.MPIArray.from_hdf5( - fname, "testds", axis=1, sel=(np.s_[:], np.s_[3:15:2]) - ) - assert ds5.shape[0] == gshape[0] - assert ds5.shape[1] == mpiutil.split_local(6)[0] - ds5 = ds5.redistribute(axis=0) - assert (ds5 == ds[:, 3:15:2]).all() - mpiutil.barrier() - # Check the read with a slice along the axis being read - ds6 = mpiarray.MPIArray.from_hdf5( - fname, "testds", axis=0, sel=(np.s_[:], np.s_[3:15:2]) - ) - ds6 = ds6.redistribute(axis=0) - assert (ds6 == ds[:, 3:15:2]).all() - mpiutil.barrier() +def test_transpose(): + """Test MPIArray.transpose().""" + gshape = (1, 11, 2, 14) - if mpiutil.rank0 and os.path.exists(fname): - os.remove(fname) + l0, s0, _ = mpiutil.split_local(11) - def test_transpose(self): + arr = mpiarray.MPIArray(gshape, axis=1, dtype=np.int64) - gshape = (1, 11, 2, 14) + arr2 = arr.transpose(1, 3, 0, 2) - l0, s0, _ = mpiutil.split_local(11) + # Check type + assert isinstance(arr2, mpiarray.MPIArray) - arr = mpiarray.MPIArray(gshape, axis=1, dtype=np.int64) + # Check global shape + assert arr2.global_shape == (11, 14, 1, 2) - arr2 = arr.transpose(1, 3, 0, 2) + # Check local shape + assert arr2.local_shape == (l0, 14, 1, 2) - # Check type - assert isinstance(arr2, mpiarray.MPIArray) + # Check local offset + assert arr2.local_offset == (s0, 0, 0, 0) - # Check global shape - assert arr2.global_shape == (11, 14, 1, 2) + # Check axis + assert arr2.axis == 0 - # Check local shape - assert arr2.local_shape == (l0, 14, 1, 2) + # Do the same test with a tuple as argument to transpose + arr3 = arr.transpose((1, 3, 0, 2)) - # Check local offset - assert arr2.local_offset == (s0, 0, 0, 0) + # Check type + assert isinstance(arr3, mpiarray.MPIArray) - # Check axis - assert arr2.axis == 0 + # Check global shape + assert arr3.global_shape == (11, 14, 1, 2) - # Do the same test with a tuple as argument to transpose - arr3 = arr.transpose((1, 3, 0, 2)) + # Check local shape + assert arr3.local_shape == (l0, 14, 1, 2) - # Check type - assert isinstance(arr3, mpiarray.MPIArray) + # Check local offset + assert arr3.local_offset == (s0, 0, 0, 0) - # Check global shape - assert arr3.global_shape == (11, 14, 1, 2) + # Check axis + assert arr3.axis == 0 - # Check local shape - assert arr3.local_shape == (l0, 14, 1, 2) + # Do the same test with None as argument to transpose + arr4 = arr.transpose() - # Check local offset - assert arr3.local_offset == (s0, 0, 0, 0) + # Check type + assert isinstance(arr4, mpiarray.MPIArray) - # Check axis - assert arr3.axis == 0 + # Check global shape + assert arr4.global_shape == (14, 2, 11, 1) - # Do the same test with None as argument to transpose - arr4 = arr.transpose() + # Check local shape + assert arr4.local_shape == (14, 2, l0, 1) - # Check type - assert isinstance(arr4, mpiarray.MPIArray) + # Check local offset + assert arr4.local_offset == (0, 0, s0, 0) - # Check global shape - assert arr4.global_shape == (14, 2, 11, 1) + # Check axis + assert arr4.axis == 2 - # Check local shape - assert 
arr4.local_shape == (14, 2, l0, 1) - # Check local offset - assert arr4.local_offset == (0, 0, s0, 0) +def test_reshape(): + """Test MPIArray.reshape().""" + gshape = (1, 11, 2, 14) - # Check axis - assert arr4.axis == 2 + l0, s0, _ = mpiutil.split_local(11) - def test_reshape(self): + arr = mpiarray.MPIArray(gshape, axis=1, dtype=np.int64) - gshape = (1, 11, 2, 14) + arr2 = arr.reshape((None, 28)) - l0, s0, _ = mpiutil.split_local(11) + # Check type + assert isinstance(arr2, mpiarray.MPIArray) - arr = mpiarray.MPIArray(gshape, axis=1, dtype=np.int64) + # Check global shape + assert arr2.global_shape == (11, 28) - arr2 = arr.reshape((None, 28)) + # Check local shape + assert arr2.local_shape == (l0, 28) - # Check type - assert isinstance(arr2, mpiarray.MPIArray) + # Check local offset + assert arr2.local_offset == (s0, 0) - # Check global shape - assert arr2.global_shape == (11, 28) + # Check axis + assert arr2.axis == 0 - # Check local shape - assert arr2.local_shape == (l0, 28) - # Check local offset - assert arr2.local_offset == (s0, 0) +def test_global_getslice(): + """Test MPIArray.global_slice.""" + rank = mpiutil.rank + size = mpiutil.size - # Check axis - assert arr2.axis == 0 + darr = mpiarray.MPIArray((size * 5, 20), axis=0) - def test_global_getslice(self): + # Initialise the distributed array + for li, _ in darr.enumerate(axis=0): + darr[li] = 10 * (10 * rank + li) + np.arange(20) - rank = mpiutil.rank - size = mpiutil.size + # Construct numpy array which should be equivalent to the global array + whole_array = ( + 10 + * ( + 10 * np.arange(4.0)[:, np.newaxis] + np.arange(5.0)[np.newaxis, :] + ).flatten()[:, np.newaxis] + + np.arange(20)[np.newaxis, :] + ) - darr = mpiarray.MPIArray((size * 5, 20), axis=0) + # Extract the section for each rank distributed along axis=0 + local_array = whole_array[(rank * 5) : ((rank + 1) * 5)] - # Initialise the distributed array - for li, _ in darr.enumerate(axis=0): - darr[li] = 10 * (10 * rank + li) + np.arange(20) + # Extract the correct section for each rank distributed along axis=0 + local_array_T = whole_array[:, (rank * 5) : ((rank + 1) * 5)] - # Construct numpy array which should be equivalent to the global array - whole_array = ( - 10 - * ( - 10 * np.arange(4.0)[:, np.newaxis] + np.arange(5.0)[np.newaxis, :] - ).flatten()[:, np.newaxis] - + np.arange(20)[np.newaxis, :] - ) + # Check that these are the same + assert (local_array == darr).all() - # Extract the section for each rank distributed along axis=0 - local_array = whole_array[(rank * 5) : ((rank + 1) * 5)] + # Check a simple slice on the non-parallel axis + arr = darr.global_slice[:, 3:5] + res = local_array[:, 3:5] - # Extract the correct section for each rank distributed along axis=0 - local_array_T = whole_array[:, (rank * 5) : ((rank + 1) * 5)] + assert isinstance(arr, mpiarray.MPIArray) + assert (arr == res).all() - # Check that these are the same - assert (local_array == darr).all() + # Check a single element extracted from the non-parallel axis + arr = darr.global_slice[:, 3] + res = local_array[:, 3] + assert (arr == res).all() - # Check a simple slice on the non-parallel axis - arr = darr.global_slice[:, 3:5] - res = local_array[:, 3:5] + # These tests denpend on the size being at least 2. 
+ if size > 1: + # Check a slice on the parallel axis + arr = darr.global_slice[:7, 3:5] - assert isinstance(arr, mpiarray.MPIArray) - assert (arr == res).all() + res = {0: local_array[:, 3:5], 1: local_array[:2, 3:5], 2: None, 3: None} - # Check a single element extracted from the non-parallel axis - arr = darr.global_slice[:, 3] - res = local_array[:, 3] - assert (arr == res).all() + assert arr == res[rank] if arr is None else (arr == res[rank]).all() - # These tests denpend on the size being at least 2. - if size > 1: - # Check a slice on the parallel axis - arr = darr.global_slice[:7, 3:5] + # Check a single element from the parallel axis + arr = darr.global_slice[7, 3:5] - res = {0: local_array[:, 3:5], 1: local_array[:2, 3:5], 2: None, 3: None} + res = {0: None, 1: local_array[2, 3:5], 2: None, 3: None} - assert arr == res[rank] if arr is None else (arr == res[rank]).all() + assert arr == res[rank] if arr is None else (arr == res[rank]).all() - # Check a single element from the parallel axis - arr = darr.global_slice[7, 3:5] + # Check a slice on the redistributed parallel axis + darr_T = darr.redistribute(axis=1) + arr = darr_T.global_slice[3:5, :7] - res = {0: None, 1: local_array[2, 3:5], 2: None, 3: None} + res = { + 0: local_array_T[3:5, :], + 1: local_array_T[3:5, :2], + 2: None, + 3: None, + } - assert arr == res[rank] if arr is None else (arr == res[rank]).all() + assert arr == res[rank] if arr is None else (arr == res[rank]).all() - # Check a slice on the redistributed parallel axis - darr_T = darr.redistribute(axis=1) - arr = darr_T.global_slice[3:5, :7] + # Check a slice that removes an axis + darr = mpiarray.MPIArray((10, 20, size * 5), axis=2) + dslice = darr.global_slice[:, 0, :] - res = { - 0: local_array_T[3:5, :], - 1: local_array_T[3:5, :2], - 2: None, - 3: None, - } + assert dslice.global_shape == (10, size * 5) + assert dslice.local_shape == (10, 5) - assert arr == res[rank] if arr is None else (arr == res[rank]).all() + # Check ellipsis and slice at the end + darr = mpiarray.MPIArray((size * 5, 20, 10), axis=0) + dslice = darr.global_slice[..., 4:9] - # Check a slice that removes an axis - darr = mpiarray.MPIArray((10, 20, size * 5), axis=2) - dslice = darr.global_slice[:, 0, :] + assert dslice.global_shape == (size * 5, 20, 5) + assert dslice.local_shape == (5, 20, 5) - assert dslice.global_shape == (10, size * 5) - assert dslice.local_shape == (10, 5) + # Check slice that goes off the end of the axis + darr = mpiarray.MPIArray((size, 136, 2048), axis=0) + dslice = darr.global_slice[..., 2007:2087] - # Check ellipsis and slice at the end - darr = mpiarray.MPIArray((size * 5, 20, 10), axis=0) - dslice = darr.global_slice[..., 4:9] + assert dslice.global_shape == (size, 136, 41) + assert dslice.local_shape == (1, 136, 41) - assert dslice.global_shape == (size * 5, 20, 5) - assert dslice.local_shape == (5, 20, 5) - # Check slice that goes off the end of the axis - darr = mpiarray.MPIArray((size, 136, 2048), axis=0) - dslice = darr.global_slice[..., 2007:2087] +def test_global_setslice(): + """Test setting MPIArray.global_slice.""" + rank = mpiutil.rank + size = mpiutil.size - assert dslice.global_shape == (size, 136, 41) - assert dslice.local_shape == (1, 136, 41) + darr = mpiarray.MPIArray((size * 5, 20), axis=0) - def test_global_setslice(self): + # Initialise the distributed array + for li, _ in darr.enumerate(axis=0): + darr[li] = 10 * (10 * rank + li) + np.arange(20) - rank = mpiutil.rank - size = mpiutil.size + # Construct numpy array which should be 
equivalent to the global array + whole_array = ( + 10 + * ( + 10 * np.arange(4.0)[:, np.newaxis] + np.arange(5.0)[np.newaxis, :] + ).flatten()[:, np.newaxis] + + np.arange(20)[np.newaxis, :] + ) - darr = mpiarray.MPIArray((size * 5, 20), axis=0) + # Extract the section for each rank distributed along axis=0 + local_array = whole_array[(rank * 5) : ((rank + 1) * 5)] + # Set slice - # Initialise the distributed array - for li, _ in darr.enumerate(axis=0): - darr[li] = 10 * (10 * rank + li) + np.arange(20) + # Check a simple assignment to a slice along the non-parallel axis + darr.global_slice[:, 6] = -2.0 + local_array[:, 6] = -2.0 - # Construct numpy array which should be equivalent to the global array - whole_array = ( - 10 - * ( - 10 * np.arange(4.0)[:, np.newaxis] + np.arange(5.0)[np.newaxis, :] - ).flatten()[:, np.newaxis] - + np.arange(20)[np.newaxis, :] - ) + assert (darr == local_array).all() - # Extract the section for each rank distributed along axis=0 - local_array = whole_array[(rank * 5) : ((rank + 1) * 5)] - # Set slice + # Check a partial assignment along the parallel axis + darr.global_slice[7:, 7:9] = -3.0 + whole_array[7:, 7:9] = -3.0 - # Check a simple assignment to a slice along the non-parallel axis - darr.global_slice[:, 6] = -2.0 - local_array[:, 6] = -2.0 + assert (darr == local_array).all() - assert (darr == local_array).all() + # Check assignment of a single index on the parallel axis + darr.global_slice[6] = np.arange(20.0) + whole_array[6] = np.arange(20.0) - # Check a partial assignment along the parallel axis - darr.global_slice[7:, 7:9] = -3.0 - whole_array[7:, 7:9] = -3.0 + assert (darr == local_array).all() - assert (darr == local_array).all() + # Check copy of one column into the other + darr.global_slice[:, 8] = darr.global_slice[:, 9] + whole_array[:, 8] = whole_array[:, 9] - # Check assignment of a single index on the parallel axis - darr.global_slice[6] = np.arange(20.0) - whole_array[6] = np.arange(20.0) - - assert (darr == local_array).all() - - # Check copy of one column into the other - darr.global_slice[:, 8] = darr.global_slice[:, 9] - whole_array[:, 8] = whole_array[:, 9] - - assert (darr == local_array).all() + assert (darr == local_array).all() # @@ -428,7 +449,3 @@ def test_global_setslice(self): # assert (td1['a'] == td2['a']).all() # assert (td1['b'] == td2['b']).all() # assert (td1.attrs['message'] == td2.attrs['message']) - - -if __name__ == "__main__": - unittest.main() diff --git a/caput/tests/test_selection.py b/caput/tests/test_selection.py index e3d97cdd..66db5294 100644 --- a/caput/tests/test_selection.py +++ b/caput/tests/test_selection.py @@ -1,11 +1,11 @@ """Serial version of the selection tests.""" -import glob -import os import pytest - +from pytest_lazyfixture import lazy_fixture import numpy as np from caput.memh5 import MemGroup +from caput import fileformats +from caput.tests.conftest import rm_all_files fsel = slice(1, 8, 2) isel = slice(1, 4) @@ -15,28 +15,61 @@ @pytest.fixture -def container_on_disk(datasets): - fname = "tmp_test_memh5_select.h5" +def h5_file_select(datasets, h5_file): + """Provides an HDF5 file with some content for testing.""" container = MemGroup() container.create_dataset("dset1", data=datasets[0].view()) container.create_dataset("dset2", data=datasets[1].view()) - container.to_hdf5(fname) - yield fname, datasets + container.to_hdf5(h5_file) + yield h5_file, datasets + rm_all_files(h5_file) - # tear down - file_names = glob.glob(fname + "*") - for fname in file_names: - os.remove(fname) 
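Both the HDF5 fixture above and the Zarr fixture below follow the same write-then-read
pattern that the selection tests exercise. A rough, self-contained sketch of that
pattern (the file name and array shape here are illustrative only, not part of the
test suite):

    import numpy as np

    from caput import fileformats
    from caput.memh5 import MemGroup

    g = MemGroup()
    g.create_dataset("dset1", data=np.arange(100).reshape(10, 10))
    g.to_file("example.zarr", file_format=fileformats.Zarr)

    # Read back only rows 1, 3, 5 and 7 of dset1.
    m = MemGroup.from_file(
        "example.zarr",
        selections={"dset1": (slice(1, 8, 2), slice(None))},
        file_format=fileformats.Zarr,
    )
    assert m["dset1"][:].shape == (4, 10)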
+@pytest.fixture +def zarr_file_select(datasets, zarr_file): + """Provides a Zarr file with some content for testing.""" + container = MemGroup() + container.create_dataset("dset1", data=datasets[0].view()) + container.create_dataset("dset2", data=datasets[1].view()) + container.to_file(zarr_file, file_format=fileformats.Zarr) + yield zarr_file, datasets + rm_all_files(zarr_file) -def test_H5FileSelect(container_on_disk): + +@pytest.mark.parametrize( + "container_on_disk, file_format", + [ + (lazy_fixture("h5_file_select"), fileformats.HDF5), + (lazy_fixture("zarr_file_select"), fileformats.Zarr), + ], +) +def test_file_select(container_on_disk, file_format): """Tests that makes hdf5 objects and tests selecting on their axes.""" - m = MemGroup.from_hdf5(container_on_disk[0], selections=sel) + m = MemGroup.from_file( + container_on_disk[0], selections=sel, file_format=file_format + ) assert np.all(m["dset1"][:] == container_on_disk[1][0][(fsel, isel, slice(None))]) assert np.all(m["dset2"][:] == container_on_disk[1][1][(fsel, slice(None))]) + +@pytest.mark.parametrize( + "container_on_disk, file_format", + [ + (lazy_fixture("h5_file_select"), fileformats.HDF5), + pytest.param( + lazy_fixture("zarr_file_select"), + fileformats.Zarr, + marks=pytest.mark.xfail(reason="Zarr doesn't support index selections."), + ), + ], +) +def test_file_select_index(container_on_disk, file_format): + """Tests that makes hdf5 objects and tests selecting on their axes.""" + # now test index selection - m = MemGroup.from_hdf5(container_on_disk[0], selections=index_sel) + m = MemGroup.from_file( + container_on_disk[0], selections=index_sel, file_format=file_format + ) assert np.all(m["dset1"][:] == container_on_disk[1][0][index_sel["dset1"]]) assert np.all(m["dset2"][:] == container_on_disk[1][1][index_sel["dset2"]]) diff --git a/caput/tests/test_selection_parallel.py b/caput/tests/test_selection_parallel.py index f0444680..e44535bc 100644 --- a/caput/tests/test_selection_parallel.py +++ b/caput/tests/test_selection_parallel.py @@ -2,52 +2,65 @@ Needs to be run on 1, 2 or 4 MPI processes. 
""" -import glob -import os from mpi4py import MPI import numpy as np import pytest +from pytest_lazyfixture import lazy_fixture -from caput import mpiutil, mpiarray +from caput import mpiutil, mpiarray, fileformats from caput.memh5 import MemGroup +from caput.tests.conftest import rm_all_files comm = MPI.COMM_WORLD -@pytest.fixture(scope="module") -def container_on_disk(datasets): - - fname = "tmp_test_memh5_select_parallel.h5" - +@pytest.fixture +def container_on_disk(datasets, file_name, file_format): + """Prepare a file for the select_parallel tests.""" if comm.rank == 0: m1 = mpiarray.MPIArray.wrap(datasets[0], axis=0, comm=MPI.COMM_SELF) m2 = mpiarray.MPIArray.wrap(datasets[1], axis=0, comm=MPI.COMM_SELF) container = MemGroup(distributed=True, comm=MPI.COMM_SELF) container.create_dataset("dset1", data=m1, distributed=True) container.create_dataset("dset2", data=m2, distributed=True) - container.to_hdf5(fname) + container.to_file(file_name, file_format=file_format) comm.Barrier() - yield fname, datasets + yield file_name, datasets comm.Barrier() - # tear down - if comm.rank == 0: - file_names = glob.glob(fname + "*") - for fname in file_names: - os.remove(fname) + rm_all_files(file_name) + + +@pytest.fixture +def xfail_zarr_listsel(request): + file_format = request.getfixturevalue("file_format") + ind = request.getfixturevalue("ind") + + if file_format == fileformats.Zarr and isinstance(ind, (list, tuple)): + request.node.add_marker( + pytest.mark.xfail(reason="Zarr doesn't support list based indexing.") + ) +@pytest.mark.parametrize( + "file_name, file_format", + [ + (lazy_fixture("h5_file"), fileformats.HDF5), + (lazy_fixture("zarr_file"), fileformats.Zarr), + ], +) @pytest.mark.parametrize("fsel", [slice(1, 8, 2), slice(5, 8, 2)]) @pytest.mark.parametrize("isel", [slice(1, 4), slice(5, 8, 2)]) @pytest.mark.parametrize("ind", [slice(None), [0, 2, 7]]) -def test_H5FileSelect_distributed(container_on_disk, fsel, isel, ind): - """Load H5 into parallel container while down-selecting axes.""" +@pytest.mark.usefixtures("xfail_zarr_listsel") +def test_FileSelect_distributed(container_on_disk, fsel, isel, file_format, ind): + """Load H5/Zarr file into parallel container while down-selecting axes.""" if ind == slice(None): sel = {"dset1": (fsel, isel, slice(None)), "dset2": (fsel, slice(None))} @@ -57,8 +70,12 @@ def test_H5FileSelect_distributed(container_on_disk, fsel, isel, ind): # Tests are designed to run for 1, 2 or 4 processes assert 4 % comm.size == 0 - m = MemGroup.from_hdf5( - container_on_disk[0], selections=sel, distributed=True, comm=comm + m = MemGroup.from_file( + container_on_disk[0], + selections=sel, + distributed=True, + comm=comm, + file_format=file_format, ) d1 = container_on_disk[1][0][sel["dset1"]] diff --git a/caput/tod.py b/caput/tod.py index 45ad9419..c56d5fc7 100644 --- a/caput/tod.py +++ b/caput/tod.py @@ -15,6 +15,7 @@ from . import memh5 from . import mpiarray +from . import fileformats class TOData(memh5.BasicCont): @@ -101,13 +102,15 @@ class Reader: files : filename, `h5py.File` or list there-of or filename pattern Files containing data. Filename patterns with wild cards (e.g. "foo*.h5") are supported. + file_format : `fileformats.FileFormat` + File format to use. Default `None` (format will be guessed). """ # Controls the association between Reader classes and data classes. # Override with subclass of TOData. 
data_class = TOData - def __init__(self, files): + def __init__(self, files, file_format=None): # If files is a filename, or pattern, turn into list of files. if isinstance(files, str): @@ -118,9 +121,10 @@ def __init__(self, files): # Fetch all meta data. time = np.copy(data_empty.time) - first_file, toclose = memh5.get_h5py_File(files[0]) + first_file, toclose = memh5.get_file(files[0], file_format=file_format) datasets = _copy_non_time_data(first_file) - if toclose: + # Zarr arrays are flushed automatically flushed and closed + if toclose and (file_format == fileformats.HDF5): first_file.close() # Set the metadata attributes. diff --git a/doc/conf.py b/doc/conf.py index 5ac66aef..e04fe3db 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -30,7 +30,7 @@ def __getattr__(cls, name): return Mock() -# Do not mock up mpi4py. This is an "extra", and docs bbuild without it. +# Do not mock up mpi4py. This is an "extra", and docs build without it. # MOCK_MODULES = ['h5py', 'mpi4py'] MOCK_MODULES = ["h5py"] if on_rtd: diff --git a/requirements.txt b/requirements.txt index 5977204c..ec0d1acb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,9 @@ -numpy>=1.16 -scipy +cachetools +click +cython h5py +numpy>=1.16 +psutil PyYAML +scipy skyfield>=1.31 -cython -click -cachetools -psutil \ No newline at end of file diff --git a/setup.py b/setup.py index f93ac1e3..d5ef0c44 100644 --- a/setup.py +++ b/setup.py @@ -61,7 +61,15 @@ """, python_requires=">=3.6", install_requires=requires, - extras_require={"mpi": ["mpi4py>=1.3"], "profiling": ["psutil", "pyinstrument"]}, + extras_require={ + "mpi": ["mpi4py>=1.3"], + "compression": [ + "bitshuffle @ git+https://github.com/kiyo-masui/bitshuffle.git", + "numcodecs==0.7.3", + "zarr==2.8.1", + ], + "profiling": ["psutil", "pyinstrument"], + }, setup_requires=["cython"], # metadata for upload to PyPI author="Kiyo Masui, J. Richard Shaw", From ff1e14e25361c277afd8c318c6635bcfd02ee58f Mon Sep 17 00:00:00 2001 From: Richard Shaw Date: Wed, 9 Mar 2022 21:17:25 -0800 Subject: [PATCH 2/7] refactor(memh5): create a single routine to handle all distributed writes This significantly cleans up the handling of distributed file writes and condenses the logic into a single flow for all file types. --- caput/memh5.py | 676 ++++++++++++-------------------------- caput/mpiarray.py | 49 ++- caput/tests/test_memh5.py | 24 +- 3 files changed, 269 insertions(+), 480 deletions(-) diff --git a/caput/memh5.py b/caput/memh5.py index e864fc9b..72c5ad0b 100644 --- a/caput/memh5.py +++ b/caput/memh5.py @@ -479,12 +479,13 @@ def from_file( selections=None, convert_dataset_strings=False, convert_attribute_strings=True, - file_format=fileformats.HDF5, + file_format=None, **kwargs, ): """Create a new instance by copying from a file group. - Any keyword arguments are passed on to the constructor for `h5py.File` or `zarr.File`. + Any keyword arguments are passed on to the constructor for `h5py.File` or + `zarr.File`. Parameters ---------- @@ -505,8 +506,8 @@ def from_file( Try and convert attribute string types to unicode. Default is `True`. convert_dataset_strings : bool, optional Try and convert dataset string types to unicode. Default is `False`. - file_format : `fileformats.FileFormat` - File format to use. Default `fileformats.HDF5`. + file_format : `fileformats.FileFormat`, optional + File format to use. Default is `None`, i.e. guess from the name. 
Returns ------- @@ -517,6 +518,9 @@ def from_file( if comm is None: comm = mpiutil.world + if file_format is None: + file_format = fileformats.guess_file_format(filename) + if comm is None: if distributed: warnings.warn( @@ -595,7 +599,7 @@ def to_file( hints=True, convert_attribute_strings=True, convert_dataset_strings=False, - file_format=fileformats.HDF5, + file_format=None, **kwargs, ): """Replicate object on disk in an hdf5 or zarr file. @@ -614,9 +618,12 @@ def to_file( understands. Default is `True`. convert_dataset_strings : bool, optional Try and convert dataset string types to bytestrings. Default is `False`. - file_format : `fileformats.FileFormat` - File format to use. Default `fileformats.HDF5`. + file_format : `fileformats.FileFormat`, optional + File format to use. Default is `None`, i.e. guess from the name. """ + if file_format is None: + file_format = fileformats.guess_file_format(filename) + if not self.distributed: with file_format.open(filename, mode, **kwargs) as f: deep_group_copy( @@ -626,30 +633,14 @@ def to_file( convert_dataset_strings=convert_dataset_strings, file_format=file_format, ) - elif file_format == fileformats.HDF5: - if h5py.get_config().mpi: - _distributed_group_to_hdf5_parallel( - self, - filename, - mode, - convert_attribute_strings=convert_attribute_strings, - convert_dataset_strings=convert_dataset_strings, - ) - else: - _distributed_group_to_hdf5_serial( - self, - filename, - mode, - convert_attribute_strings=convert_attribute_strings, - convert_dataset_strings=convert_dataset_strings, - ) else: - _distributed_group_to_zarr( + _distributed_group_to_file( self, filename, mode, convert_attribute_strings=convert_attribute_strings, convert_dataset_strings=convert_dataset_strings, + file_format=file_format, ) def create_group(self, name): @@ -1678,7 +1669,7 @@ def from_file( detect_subclass=True, convert_attribute_strings=None, convert_dataset_strings=None, - file_format=fileformats.HDF5, + file_format=None, **kwargs, ): """Create data object from analysis hdf5 file, store in memory or on disk. @@ -1717,11 +1708,14 @@ def from_file( Axis selections can be given to only read a subset of the containers. A slice can be given, or a list of specific array indices for that axis. file_format : `fileformats.FileFormat` - File format to use. Default `fileformats.HDF5`. + File format to use. Default is `None`, i.e. guess from file name. **kwargs : any other arguments Any additional keyword arguments are passed to :class:`h5py.File`'s constructor if *file_* is a filename and silently ignored otherwise. """ + if file_format is None and not is_group(file_): + file_format = fileformats.guess_file_format(file_) + if file_format == fileformats.Zarr and not zarr_available: raise RuntimeError("Unable to read zarr file, please install zarr.") @@ -2432,12 +2426,15 @@ def deep_group_copy( convert_dataset_strings=False, convert_attribute_strings=True, file_format=fileformats.HDF5, + skip_distributed=False, + postprocess=None, ): """ Copy full data tree from one group to another. Copies from g1 to g2. An axis downselection can be specified by supplying the - parameter 'selections'. For example to select the first two indexes in g1["foo"]["bar"], do + parameter 'selections'. For example to select the first two indexes in + g1["foo"]["bar"], do >>> g1 = MemGroup() >>> foo = g1.create_group("foo") @@ -2463,84 +2460,142 @@ def deep_group_copy( convert_dataset_strings : bool, optional Convert strings within datasets to ensure that they are unicode. 
file_format : `fileformats.FileFormat` - File format to use. Default `fileformats.HDF5`. + File format to use. Default `fileformats.HDF5`. + skip_distributed : bool, optional + If `True` skip the write for any distributed dataset, and return a list of the + names of all datasets that were skipped. If `False` (default) throw a + `ValueError` if any distributed datasets are encountered. + postprocess : function, optional + A function that takes is called on each node, with the source and destination + entries, and can modify either. + + Returns + ------- + distributed_dataset_names : list + Names of the distributed datasets if `skip_distributed` is True. Otherwise + `None` is returned. """ - copyattrs(g1.attrs, g2.attrs, convert_strings=convert_attribute_strings) + distributed_dset_names = [] - # Sort to ensure consistent insertion order - for key in sorted(g1): - entry = g1[key] - if is_group(entry): - g2.create_group(key) - deep_group_copy( - entry, - g2[key], - selections, - convert_dataset_strings=convert_dataset_strings, - convert_attribute_strings=convert_attribute_strings, - file_format=file_format, + # only the case if zarr is not installed + if file_format.module is None: + raise RuntimeError("Can't deep_group_copy zarr file. Please install zarr.") + to_file = isinstance(g2, file_format.module.Group) + + # Prepare a dataset for writing out, applying selections and transforming any + # datatypes + # Returns: (dtype, shape, data_to_write) + def _prepare_dataset(dset): + + # Look for a selection for this dataset (also try without the leading "/") + try: + selection = selections.get( + dset.name, selections.get(dset.name[1:], slice(None)) ) - else: - # look for selection for this dataset (also try withouth the leading "/") - try: - selection = selections.get( - entry.name, selections.get(entry.name[1:], slice(None)) - ) - except AttributeError: - selection = slice(None) + except AttributeError: + selection = slice(None) - # only the case if zarr is not installed - if file_format.module is None: - raise RuntimeError( - "Can't deep_group_copy zarr file. Please install zarr." + # Check if this is a distributed dataset and figure out if we can make this work + # out + if to_file and isinstance(dset, MemDatasetDistributed): + if not skip_distributed: + raise ValueError( + f"Cannot write out a distributed dataset ({dset.name}) " + "via this method." + ) + elif selection != slice(None): + raise ValueError( + "Cannot apply a slice when writing out a distributed dataset " + f"({dset.name}) via this method." ) + else: + # If we get here, we should create the dataset, but not write out any data into it (i.e. return None) + distributed_dset_names.append(dset.name) + return dset.dtype, dset.shape, None - if convert_dataset_strings: - # Convert unicode strings back into ascii byte strings. 
This will break - # if there are characters outside of the ascii range - if isinstance(g2, file_format.module.Group): - data = ensure_bytestring(entry[selection]) + # Extract the data for the selection + data = entry[selection] - # Convert strings in an HDF5 dataset into unicode - else: - data = ensure_unicode(entry[selection]) - elif isinstance(g2, file_format.module.Group): - data = check_unicode(entry) - data = data[selection] - else: - data = entry[selection] - - # get compression options/chunking for this dataset - chunks = getattr(entry, "chunks", None) - compression = getattr(entry, "compression", None) - compression_opts = getattr(entry, "compression_opts", None) - - # TODO: Am I missing something or is this branch not necessary? - # I guess I'm still confused as to why a file_format is - # required even for the in-memory case - if isinstance(g2, file_format.module.Group): - compression_kwargs = file_format.compression_kwargs( - compression=compression, - compression_opts=compression_opts, - compressor=getattr(entry, "compressor", None), - ) + if convert_dataset_strings: + # Convert unicode strings back into ascii byte strings. This will break + # if there are characters outside of the ascii range + if to_file: + data = ensure_bytestring(data) + + # Convert strings in an HDF5 dataset into unicode else: - # in-memory case; use HDF5 compression args format for this case - compression_kwargs = fileformats.HDF5.compression_kwargs( - compression=compression, compression_opts=compression_opts - ) + data = ensure_unicode(data) + elif to_file: + # If we shouldn't convert we at least need to ensure there aren't any + # Unicode characters before writing + data = check_unicode(entry) + + return data.dtype, data.shape, data + + # get compression options/chunking for this dataset + # Returns dict of compression and chunking arguments for create_dataset + def _prepare_compression_args(dset): + compression = getattr(dset, "compression", None) + compression_opts = getattr(dset, "compression_opts", None) + + if to_file: + # massage args according to file format + compression_kwargs = file_format.compression_kwargs( + compression=compression, + compression_opts=compression_opts, + compressor=getattr(dset, "compressor", None), + ) + else: + # in-memory case; use HDF5 compression args format for this case + compression_kwargs = fileformats.HDF5.compression_kwargs( + compression=compression, compression_opts=compression_opts + ) + compression_kwargs["chunks"] = getattr(dset, "chunks", None) + + # disable compression if not enabled for HDF5 files + # https://github.com/chime-experiment/Pipeline/issues/33 + if ( + to_file + and file_format == fileformats.HDF5 + and not fileformats.HDF5.compression_enabled() + and isinstance(dset, MemDatasetDistributed) + ): + compression_kwargs = {} + + return compression_kwargs + + # Do a non-recursive traversal of the tree, recreating the structure and attributes, + # and copying over any non-distributed datasets + stack = [g1] + while stack: + + entry = stack.pop() + key = entry.name + + if is_group(entry): + if key != g1.name: + # Only create group if we are above the starting level + g2.create_group(key) + stack += [entry[k] for k in sorted(entry, reverse=True)] + else: # Is a dataset + dtype, shape, data = _prepare_dataset(entry) + compression_kwargs = _prepare_compression_args(entry) g2.create_dataset( key, - shape=data.shape, - dtype=data.dtype, + shape=shape, + dtype=dtype, data=data, - chunks=chunks, **compression_kwargs, ) - copyattrs( - entry.attrs, 
g2[key].attrs, convert_strings=convert_attribute_strings - ) + + target = g2[key] + copyattrs(entry.attrs, target.attrs, convert_strings=convert_attribute_strings) + + if postprocess: + postprocess(entry, target) + + return distributed_dset_names if skip_distributed else None def format_abs_path(path): @@ -2559,398 +2614,99 @@ def format_abs_path(path): return out -def _distributed_group_to_hdf5_serial( +def _distributed_group_to_file( group, fname, mode, hints=True, convert_dataset_strings=False, convert_attribute_strings=True, + file_format=None, + serial=False, **kwargs, ): - """Private routine to copy full data tree from distributed memh5 object - into an HDF5 file. + """Copy full data tree from distributed memh5 object into the destination file. - This version explicitly serialises all IO. - """ + This routine works in two stages: - if not group.distributed: - raise RuntimeError( - "This should only run on distributed datasets [%s]." % group.name - ) + - First rank=0 copies all of the groups, attributes and non-distributed datasets + into the target file. The distributed datasets are identified and created in this + step, but their contents are not written. This is done by `deep_group_copy` to try + and centralize as much of the copying code. + - In the second step, the distributed datasets are written to disk. This is mostly + offloaded to `MPIArray.to_file`, but some code around this needs to change depending + on the file type, and if the data can be written in parallel. + """ comm = group.comm - # Create group (or file) - if comm.rank == 0: - - # If this is the root group, create the file and copy the file level attrs - if group.name == "/": - with h5py.File(fname, mode, **kwargs) as f: - copyattrs( - group.attrs, f.attrs, convert_strings=convert_attribute_strings - ) - - if hints: - f.attrs["__memh5_distributed_file"] = True + def apply_hints(source, dest): + if dest.name == "/": + dest.attrs["__memh5_distributed_file"] = True + elif isinstance(source, MemDatasetCommon): + dest.attrs["__memh5_distributed_dset"] = False + elif isinstance(source, MemDatasetDistributed): + dest.attrs["__memh5_distributed_dset"] = True - # Create this group and copy attrs - else: - with h5py.File(fname, "r+", **kwargs) as f: - g = f.create_group(group.name) - copyattrs( - group.attrs, g.attrs, convert_strings=convert_attribute_strings - ) - - comm.Barrier() - - # Write out groups and distributed datasets, these operations must be done - # collectively - # Sort to ensure insertion order is identical - for key in sorted(group): - - entry = group[key] - - # Groups are written out by recursing - if is_group(entry): - _distributed_group_to_hdf5_serial( - entry, - fname, - mode, + # Walk the full structure and separate out what we need to write + if comm.rank == 0: + with file_format.open(fname, mode) as fh: + distributed_dataset_names = deep_group_copy( + group, + fh, convert_dataset_strings=convert_dataset_strings, convert_attribute_strings=convert_attribute_strings, - **kwargs, + skip_distributed=True, + file_format=file_format, + postprocess=(apply_hints if hints else None), ) + else: + distributed_dataset_names = None - # Write out distributed datasets (only the data, the attributes are written below) - elif isinstance(entry, MemDatasetDistributed): + distributed_dataset_names = comm.bcast(distributed_dataset_names) - arr = check_unicode(entry) + def _write_distributed_datasets(dest): + for name in distributed_dataset_names: + dset = group[name] + data = check_unicode(dset) - if 
fileformats.HDF5.compression_enabled(): - ( - chunks, - compression_kwargs, - ) = entry.chunks, fileformats.HDF5.compression_kwargs( - compression=entry.compression, - compression_opts=entry.compression_opts, - ) - else: - # disable compression if not enabled for HDF5 files - # https://github.com/chime-experiment/Pipeline/issues/33 - chunks, compression_kwargs = None, { - "compression": None, - "compression_opts": None, - } - - arr.to_hdf5( - fname, - entry.name, - chunks=chunks, - **compression_kwargs, + data.to_file( + dest, + name, + chunks=dset.chunks, + compression=dset.compression, + compression_opts=dset.compression_opts, + file_format=file_format, ) - comm.Barrier() - # Write out common datasets, and the attributes on distributed datasets - if comm.rank == 0: - - with h5py.File(fname, "r+", **kwargs) as f: - - for key, entry in group.items(): - - # Write out common datasets and copy their attrs - if isinstance(entry, MemDatasetCommon): - - # Deal with unicode numpy datasets that aren't supported by HDF5 - if convert_dataset_strings: - # Attempt to coerce to a type that HDF5 supports - data = ensure_bytestring(entry.data) - else: - data = check_unicode(entry) - - # allow chunks and compression bc serialised IO - dset = f.create_dataset( - entry.name, - data=data, - chunks=entry.chunks, - **fileformats.HDF5.compression_kwargs( - compression=entry.compression, - compression_opts=entry.compression_opts, - ), - ) - copyattrs( - entry.attrs, - dset.attrs, - convert_strings=convert_attribute_strings, - ) - - if hints: - dset.attrs["__memh5_distributed_dset"] = False - - # Copy the attributes over for a distributed dataset - elif isinstance(entry, MemDatasetDistributed): - - if entry.name not in f: - raise RuntimeError( - "Distributed dataset should already have been created." - ) - - copyattrs( - entry.attrs, - f[entry.name].attrs, - convert_strings=convert_attribute_strings, - ) - - if hints: - f[entry.name].attrs["__memh5_distributed_dset"] = True - f[entry.name].attrs[ - "__memh5_distributed_axis" - ] = entry.distributed_axis - - comm.Barrier() - - -def _distributed_group_to_hdf5_parallel( - group, - fname, - mode, - hints=True, - convert_dataset_strings=False, - convert_attribute_strings=True, - **_, -): - """Private routine to copy full data tree from distributed memh5 object - into an HDF5 file. 
- This version paralellizes all IO.""" - - # == Create some internal functions for doing the read == - # Function to perform a recursive clone of the tree structure - def _copy_to_file(memgroup, h5group): - - # Copy over attributes - copyattrs( - memgroup.attrs, h5group.attrs, convert_strings=convert_attribute_strings - ) - - # Sort the items to ensure we insert in a consistent order across ranks - for key in sorted(memgroup): - - item = memgroup[key] - - # If group, create the entry and the recurse into it - if is_group(item): - new_group = h5group.create_group(key) - _copy_to_file(item, new_group) - - # If dataset, create dataset - else: - - # Check if we are in a distributed dataset - if isinstance(item, MemDatasetDistributed): - - data = check_unicode(item) - - # Write to file from MPIArray - if fileformats.HDF5.compression_enabled(): - chunks, compression, compression_opts = ( - item.chunks, - item.compression, - item.compression_opts, - ) - else: - # disable compression if not enabled for HDF5 files - # https://github.com/chime-experiment/Pipeline/issues/33 - chunks, compression, compression_opts = None, None, None - - data.to_hdf5( - h5group, - key, - chunks=chunks, - compression=compression, - compression_opts=compression_opts, - ) - - dset = h5group[key] - - if hints: - dset.attrs["__memh5_distributed_dset"] = True - dset.attrs["__memh5_distributed_axis"] = item.distributed_axis - - # Create common dataset (collective) - else: - - # Convert from unicode to bytestring - if convert_dataset_strings: - data = ensure_bytestring(item.data) - else: - data = check_unicode(item) - - if fileformats.HDF5.compression_enabled(): - ( - chunks, - compression_kwargs, - ) = item.chunks, fileformats.HDF5.compression_kwargs( - item.compression, item.compression_opts - ) - else: - # disable compression if not enabled for HDF5 files - # https://github.com/chime-experiment/Pipeline/issues/33 - chunks, compression_kwargs = None, { - "compression": None, - "compression_opts": None, - } - - dset = h5group.create_dataset( - key, - shape=data.shape, - dtype=data.dtype, - chunks=chunks, - **compression_kwargs, - ) - - # Write common data from rank 0 - if memgroup.comm.rank == 0: - dset[:] = data - - if hints: - dset.attrs["__memh5_distributed_dset"] = False - - # Copy attributes over into dataset - copyattrs( - item.attrs, dset.attrs, convert_strings=convert_attribute_strings - ) - - # Open file on all ranks - with misc.open_h5py_mpi(fname, mode, comm=group.comm) as f: - if not f.is_mpi: - raise RuntimeError("Could not create file %s in MPI mode" % fname) - - # Start recursive file write - _copy_to_file(group, f) - - if hints: - f.attrs["__memh5_distributed_file"] = True - - # Final synchronisation - group.comm.Barrier() - - -def _distributed_group_to_zarr( - group, - fname, - mode, - hints=True, - convert_dataset_strings=False, - convert_attribute_strings=True, - **_, -): - """Private routine to copy full data tree from distributed memh5 object into a Zarr file. - - This paralellizes all IO.""" - - if not zarr_available: - raise RuntimeError("Can't write to zarr file. 
Please install zarr.") - - # == Create some internal functions for doing the read == - # Function to perform a recursive clone of the tree structure - def _copy_to_file(memgroup, group): - - # Copy over attributes - if memgroup.comm.rank == 0: - copyattrs( - memgroup.attrs, group.attrs, convert_strings=convert_attribute_strings - ) - - # Sort the items to ensure we insert in a consistent order across ranks - for key in sorted(memgroup): - - item = memgroup[key] - - # If group, create the entry and the recurse into it - if is_group(item): - if memgroup.comm.rank == 0: - group.create_group(key) - memgroup.comm.Barrier() - _copy_to_file(item, group[key]) - - # If dataset, create dataset - else: - # Check if we are in a distributed dataset - if isinstance(item, MemDatasetDistributed): - - data = check_unicode(item) - - logger.error(f"chunk settings: {item.chunks}") - - # Write to file from MPIArray - data.to_file( - group, - key, - chunks=item.chunks, - compression=item.compression, - compression_opts=item.compression_opts, - file_format=fileformats.Zarr, - ) - dset = group[key] - - if memgroup.comm.rank == 0 and hints: - dset.attrs["__memh5_distributed_dset"] = True - - # Create common dataset (collective) - else: - - # Convert from unicode to bytestring - if convert_dataset_strings: - data = ensure_bytestring(item.data) - else: - data = check_unicode(item) - - # Write common data from rank 0 - if memgroup.comm.rank == 0: - dset = group.create_dataset( - key, - shape=data.shape, - dtype=data.dtype, - chunks=item.chunks, - **fileformats.Zarr.compression_kwargs( - item.compression, item.compression_opts - ), - ) - - dset[:] = data - - if hints: - dset.attrs["__memh5_distributed_dset"] = False - - # Copy attributes over into dataset - if memgroup.comm.rank == 0: - copyattrs( - item.attrs, - dset.attrs, - convert_strings=convert_attribute_strings, - ) - - # Make sure file exists - if group.comm.rank == 0: - zarr.open_group(store=fname, mode=mode) - group.comm.Barrier() - - # Open file on all ranks - - with fileformats.ZarrProcessSynchronizer( - f".{fname}.sync", group.comm - ) as synchronizer, zarr.open_group( - store=fname, mode="r+", synchronizer=synchronizer - ) as f: - # Start recursive file write - _copy_to_file(group, f) - - if hints and group.comm.rank == 0: - f.attrs["__memh5_distributed_file"] = True + # Write out the distributed parts of the file, this needs to be done slightly + # differently depending on the actual format we want to use (and if HDF5+MPI is + # available) + # NOTE: need to use mode r+ as the file should already exist + if file_format == fileformats.Zarr: + + with fileformats.ZarrProcessSynchronizer( + f".{fname}.sync", group.comm + ) as synchronizer, zarr.open_group( + store=fname, mode="r+", synchronizer=synchronizer + ) as f: + _write_distributed_datasets(f) + + elif file_format == fileformats.HDF5: + + # Use MPI IO if possible, else revert to serialising + if h5py.get_config().mpi: + # Open file on all ranks + with misc.open_h5py_mpi(fname, "r+", comm=group.comm) as f: + if not f.is_mpi: + raise RuntimeError("Could not create file %s in MPI mode" % fname) + _write_distributed_datasets(f) + else: + _write_distributed_datasets(fname) - # Final synchronisation - group.comm.Barrier() + else: + raise ValueError(f"Unknown format={file_format}") def _distributed_group_from_file( diff --git a/caput/mpiarray.py b/caput/mpiarray.py index e3e3ece6..77aaaf8c 100644 --- a/caput/mpiarray.py +++ b/caput/mpiarray.py @@ -821,7 +821,8 @@ def to_hdf5( # Split the axis to get the IO 
size under ~2GB (only if MPI-IO) split_axis, partitions = self._partition_io(skip=(not fh.is_mpi)) - fh.create_dataset( + dset = _create_or_get_dset( + fh, dataset, shape=self.global_shape, dtype=self.dtype, @@ -831,14 +832,14 @@ def to_hdf5( ) # Read using collective MPI-IO if specified - with fh[dataset].collective if use_collective else DummyContext(): + with dset.collective if use_collective else DummyContext(): # Loop over partitions of the IO and perform them for part in partitions: islice, fslice = _partition_sel( sel, split_axis, self.global_shape[split_axis], part ) - fh[dataset][islice] = self[fslice] + dset[islice] = self[fslice] if fh.opened: fh.close() @@ -914,7 +915,8 @@ def to_zarr( split_axis, partitions = self._partition_io(skip=True) if self.comm.rank == 0: - group.create_dataset( + _create_or_get_dset( + group, dataset, shape=self.global_shape, dtype=self.dtype, @@ -1173,7 +1175,7 @@ def _to_hdf5_serial(self, filename, dataset, create=False): filename : str File to write dataset into. dataset : string - Name of dataset to write into. Should not exist. + Name of dataset to write into. """ ## Naive non-parallel implementation to start @@ -1191,11 +1193,13 @@ def _to_hdf5_serial(self, filename, dataset, create=False): if self.comm is None or self.comm.rank == 0: with h5py.File(filename, "a" if create else "r+") as fh: - if dataset in fh: - raise Exception("Dataset should not exist.") - - fh.create_dataset(dataset, self.global_shape, dtype=self.dtype) - fh[dataset][:] = np.array(0.0).astype(self.dtype) + dset = _create_or_get_dset( + fh, + dataset, + self.global_shape, + dtype=self.dtype, + ) + dset[:] = np.array(0.0).astype(self.dtype) # wait until all processes see the created file while not os.path.exists(filename): @@ -1346,6 +1350,31 @@ def _expand_sel(sel, naxis): return list(sel) +def _create_or_get_dset(group, name, shape, dtype, **kwargs): + # Create a dataset if it doesn't exist, or test the existing one for compatibility + # and return + if name in group: + dset = group[name] + if dset.shape != shape: + raise RuntimeError( + "Dataset exists already but with incompatible shape." + f"Requested shape={shape}, but on disk shape={dset.shape}." + ) + if dset.dtype != dtype: + raise RuntimeError( + "Dataset exists already but with incompatible dtype. " + f"Requested dtype={dtype}, on disk dtype={dset.dtype}." 
+ ) + else: + dset = group.create_dataset( + name, + shape=shape, + dtype=dtype, + **kwargs, + ) + return dset + + class DummyContext: """A completely dummy context manager.""" diff --git a/caput/tests/test_memh5.py b/caput/tests/test_memh5.py index a329fd64..bcdaa46c 100644 --- a/caput/tests/test_memh5.py +++ b/caput/tests/test_memh5.py @@ -150,6 +150,8 @@ def test_file_sanity(test_file, file_open_function): @pytest.mark.parametrize( "test_file,file_open_function,file_format", [ + (lazy_fixture("filled_h5_file"), h5py.File, None), + (lazy_fixture("filled_zarr_file"), zarr.open_group, None), (lazy_fixture("filled_h5_file"), h5py.File, fileformats.HDF5), (lazy_fixture("filled_zarr_file"), zarr.open_group, fileformats.Zarr), ], @@ -162,13 +164,11 @@ def test_to_from_file(test_file, file_open_function, file_format): with file_open_function(test_file, "r") as f: assertGroupsEqual(f, m) - m.to_file( - test_file + ".new", - file_format=file_format, - ) + new_name = f"new.{test_file}" + m.to_file(new_name, file_format=file_format) # Check that written file has same structure - with file_open_function(test_file + ".new", "r") as f: + with file_open_function(new_name, "r") as f: assertGroupsEqual(f, m) @@ -256,6 +256,8 @@ class TempSubClass(memh5.MemDiskGroup): [ (lazy_fixture("h5_file"), fileformats.HDF5), (lazy_fixture("zarr_file"), fileformats.Zarr), + (lazy_fixture("h5_file"), None), + (lazy_fixture("zarr_file"), None), ], ) def test_io(test_file, file_format): @@ -265,6 +267,8 @@ def test_io(test_file, file_format): tsc.create_dataset("dset", data=np.arange(10)) tsc.save(test_file, file_format=file_format) + actual_file_format = fileformats.guess_file_format(test_file) + # Load it from disk tsc2 = memh5.MemDiskGroup.from_file(test_file, file_format=file_format) tsc3 = memh5.MemDiskGroup.from_file(test_file, ondisk=True, file_format=file_format) @@ -283,18 +287,18 @@ def test_io(test_file, file_format): test_file, mode="r", ondisk=True, file_format=file_format ): # h5py will error if file already open - if file_format == fileformats.HDF5: + if actual_file_format == fileformats.HDF5: with pytest.raises(IOError): - file_format.open(test_file, "w") + actual_file_format.open(test_file, "w") # ...zarr will not else: - file_format.open(test_file, "w") + actual_file_format.open(test_file, "w") with memh5.MemDiskGroup.from_file( test_file, mode="r", ondisk=False, file_format=file_format ): - f = file_format.open(test_file, "w") - if file_format == fileformats.HDF5: + f = actual_file_format.open(test_file, "w") + if actual_file_format == fileformats.HDF5: f.close() From 0562e08a5f8fde3143c948d9ff76a5771bcc5afe Mon Sep 17 00:00:00 2001 From: Rick Nitsche Date: Fri, 4 Jun 2021 15:52:18 -0700 Subject: [PATCH 3/7] feat(truncate): add truncate for double values Co-authored-by: Tristan Pinsonneault-Marotte --- caput/tests/test_truncate.py | 77 +++++++++++++++ caput/truncate.hpp | 180 +++++++++++++++++++++++++++++++++-- caput/truncate.pyx | 171 ++++++++++++++++++++++++++++++--- 3 files changed, 407 insertions(+), 21 deletions(-) create mode 100644 caput/tests/test_truncate.py diff --git a/caput/tests/test_truncate.py b/caput/tests/test_truncate.py new file mode 100644 index 00000000..007f1aae --- /dev/null +++ b/caput/tests/test_truncate.py @@ -0,0 +1,77 @@ +import numpy as np + +from caput import truncate + + +def test_bit_truncate(): + assert truncate.bit_truncate_int(129, 1) == 128 + + assert truncate.bit_truncate_long(129, 1) == 128 + assert truncate.bit_truncate_long(576460752303423489, 1) == 
576460752303423488
+    assert (
+        truncate.bit_truncate_long(4520628863461491, 140737488355328)
+        == 4503599627370496
+    )
+
+    assert truncate.bit_truncate_int(54321, 0) == 54321
+
+    assert truncate.bit_truncate_long(576460752303423489, 0) == 576460752303423489
+
+
+def test_truncate_float():
+    assert truncate.bit_truncate_float(32.121, 1) == 32
+    # fails assert truncate.bit_truncate_float(float(0.010101), 0) == float(0.010101)
+
+    assert truncate.bit_truncate_double(32.121, 1) == 32
+    assert truncate.bit_truncate_double(0.9191919191, 0) == 0.9191919191
+
+
+def test_truncate_array():
+    assert (
+        truncate.bit_truncate_relative(
+            np.asarray([32.121, 32.5], dtype=np.float32), 1 / 32
+        )
+        == np.asarray([32, 32], dtype=np.float32)
+    ).all()
+    assert (
+        truncate.bit_truncate_relative_double(
+            np.asarray([32.121, 32.5], dtype=np.float64), 1 / 32
+        )
+        == np.asarray([32, 32], dtype=np.float64)
+    ).all()
+
+
+def test_truncate_weights():
+    assert (
+        truncate.bit_truncate_weights(
+            np.asarray([32.121, 32.5], dtype=np.float32),
+            np.asarray([1 / 32, 1 / 32], dtype=np.float32),
+            0.001,
+        )
+        == np.asarray([32, 32], dtype=np.float32)
+    ).all()
+    assert (
+        truncate.bit_truncate_weights(
+            np.asarray([32.121, 32.5], dtype=np.float64),
+            np.asarray([1 / 32, 1 / 32], dtype=np.float64),
+            0.001,
+        )
+        == np.asarray([32, 32], dtype=np.float64)
+    ).all()
+
+
+def test_truncate_relative():
+    assert (
+        truncate.bit_truncate_relative(
+            np.asarray([32.121, 32.5], dtype=np.float32),
+            0.1,
+        )
+        == np.asarray([32, 32], dtype=np.float32)
+    ).all()
+    assert (
+        truncate.bit_truncate_relative(
+            np.asarray([32.121, 32.5], dtype=np.float64),
+            0.1,
+        )
+        == np.asarray([32, 32], dtype=np.float64)
+    ).all()
diff --git a/caput/truncate.hpp b/caput/truncate.hpp
index 8f1a4d23..b728e0dc 100644
--- a/caput/truncate.hpp
+++ b/caput/truncate.hpp
@@ -1,8 +1,45 @@
-//#include 
-
 // 2**31 + 2**30 will be used to check for overflow
 const uint32_t HIGH_BITS = 3221225472;
 
+// 2**63 + 2**62 will be used to check for overflow
+const uint64_t HIGH_BITS_DOUBLE = 13835058055282163712UL;
+
+// The length of the part in a float that represents the exponent
+const int32_t LEN_EXPONENT_FLOAT = 8;
+
+// The length of the part in a double that represents the exponent
+const int64_t LEN_EXPONENT_DOUBLE = 11;
+
+// Starting bit (offset) of the part in a float that represents the exponent
+const int32_t POS_EXPONENT_FLOAT = 23;
+
+// Starting bit (offset) of the part in a double that represents the exponent
+const int64_t POS_EXPONENT_DOUBLE = 52;
+
+// A mask to apply on the exponent representation of a float, to get rid of the sign part
+const int32_t MASK_EXPONENT_W_O_SIGN_FLOAT = 255;
+
+// A mask to apply on the exponent representation of a double, to get rid of the sign part
+const int64_t MASK_EXPONENT_W_O_SIGN_DOUBLE = 2047;
+
+// A mask to apply on a float to get only the mantissa (2**23 - 1)
+const int32_t MASK_MANTISSA_FLOAT = 8388607;
+
+// A mask to apply on a double to get only the mantissa (2**52 - 1)
+const int64_t MASK_MANTISSA_DOUBLE = 4503599627370495L;
+
+// Implicit 24th bit of the mantissa in a float (2**23)
+const int32_t IMPLICIT_BIT_FLOAT = 8388608;
+
+// Implicit 53rd bit of the mantissa in a double (2**52)
+const int64_t IMPLICIT_BIT_DOUBLE = 4503599627370496L;
+
+// The maximum error we can have for the mantissa in a float (less than 2**30)
+const int32_t ERR_MAX_FLOAT = 1073741823;
+
+// The maximum error we can have for the mantissa in a double (less than 2**62)
+const int64_t ERR_MAX_DOUBLE = 
4611686018427387903L; + /** * @brief Truncate the precision of *val* by rounding to a multiple of a power of * two, keeping error less than or equal to *err*. @@ -40,6 +77,44 @@ inline int32_t bit_truncate(int32_t val, int32_t err) { } +/** + * @brief Truncate the precision of *val* by rounding to a multiple of a power of + * two, keeping error less than or equal to *err*. + * + * @warning Undefined results for err < 0 and err > 2**62. + */ +inline int64_t bit_truncate_64(int64_t val, int64_t err) { + // *gran* is the granularity. It is the power of 2 that is *larger than* the + // maximum error *err*. + int64_t gran = err; + gran |= gran >> 1; + gran |= gran >> 2; + gran |= gran >> 4; + gran |= gran >> 8; + gran |= gran >> 16; + gran |= gran >> 32; + gran += 1; + + // Bitmask selects bits to be rounded. + int64_t bitmask = gran - 1; + + // Determine if there is a round-up/round-down tie. + // This operation gets the `gran = 1` case correct (non tie). + int64_t tie = ((val & bitmask) << 1) == gran; + + // The acctual rounding. + int64_t val_t = (val - (gran >> 1)) | bitmask; + val_t += 1; + // There is a bit of extra bit twiddling for the err == 0. + val_t -= (err == 0); + + // Break any tie by rounding to even. + val_t -= val_t & (tie * gran); + + return val_t; +} + + /** * @brief Count the number of leading zeros in a binary number. * Taken from https://stackoverflow.com/a/23857066 @@ -54,6 +129,21 @@ inline int32_t count_zeros(int32_t x) { } +/** + * @brief Count the number of leading zeros in a binary number. + * Taken from https://stackoverflow.com/a/23857066 + */ +inline int64_t count_zeros_64(int64_t x) { + x = x | (x >> 1); + x = x | (x >> 2); + x = x | (x >> 4); + x = x | (x >> 8); + x = x | (x >> 16); + x = x | (x >> 32); + return __builtin_popcountl(~x); +} + + /** * @brief Fast power of two float. * @@ -73,6 +163,25 @@ inline float fast_pow(int8_t e) { } +/** + * @brief Fast power of two double. + * + * Result is undefined for e < -1022 and e > 1023. + * + * @param e Exponent + * + * @returns The result of 2^e + */ +inline double fast_pow_double(int16_t e) { + double* out_f; + // Construct float bitwise + uint64_t out_i = ((uint64_t)(1023 + e) << 52); + // Cast into float + out_f = (double*)&out_i; + return *out_f; +} + + /** * @brief Truncate precision of a floating point number by applying the algorithm of * `bit_truncate` to the mantissa. @@ -82,22 +191,22 @@ inline float fast_pow(int8_t e) { * happens to remove all of the non-zero bits in the mantissa, a NaN can become inf. * */ -inline float bit_truncate_float(float val, float err) { +inline float _bit_truncate_float(float val, float err) { // cast float memory into an int int32_t* cast_val_ptr = (int32_t*)&val; // extract the exponent and sign - int32_t val_pre = cast_val_ptr[0] >> 23; + int32_t val_pre = cast_val_ptr[0] >> POS_EXPONENT_FLOAT; // strip sign - int32_t val_pow = val_pre & 255; - int32_t val_s = val_pre >> 8; + int32_t val_pow = val_pre & MASK_EXPONENT_W_O_SIGN_FLOAT; + int32_t val_s = val_pre >> LEN_EXPONENT_FLOAT; // extract mantissa. mask is 2**23 - 1. Add back the implicit 24th bit - int32_t val_man = (cast_val_ptr[0] & 8388607) + 8388608; + int32_t val_man = (cast_val_ptr[0] & MASK_MANTISSA_FLOAT) + IMPLICIT_BIT_FLOAT; // scale the error to the integer representation of the mantissa // scale by 2**(23 + 127 - pow) int32_t int_err = (int32_t)(err * fast_pow(150 - val_pow)); // make sure hasn't overflowed. if set to 2**30-1, will surely round to 0. 
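    // Worked example of the scaling above: for val = 32.121f the exponent field is
    // 132, so err = 1.0f becomes int_err = 2**(150 - 132) = 262144; bit_truncate
    // then rounds the 24-bit mantissa to a multiple of 2**19 and the result is
    // exactly 32.0f (cf. bit_truncate_float(32.121, 1) == 32 in the new tests).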
// must keep err < 2**30 for bit_truncate to work - int_err = (int_err & HIGH_BITS) ? 1073741823 : int_err; + int_err = (int_err & HIGH_BITS) ? ERR_MAX_FLOAT : int_err; // truncate int32_t tr_man = bit_truncate(val_man, int_err); @@ -107,7 +216,7 @@ inline float bit_truncate_float(float val, float err) { // adjust power after truncation to account for loss of implicit bit val_pow -= z_count - 8; // shift mantissa by same amount, remove implicit bit - tr_man = (tr_man << (z_count - 8)) & 8388607; + tr_man = (tr_man << (z_count - 8)) & MASK_MANTISSA_FLOAT; // round to zero case val_pow = ((z_count != 32) ? val_pow : 0); // restore sign and exponent @@ -117,3 +226,56 @@ inline float bit_truncate_float(float val, float err) { return tr_val_ptr[0]; } + + +/** + * @brief Truncate precision of a double floating point number by applying the algorithm of + * `bit_truncate` to the mantissa. + * + * Note that NaN and inf are not explicitly checked for. According to the IEEE spec, it is + * impossible for the truncation to turn an inf into a NaN. However, if the truncation + * happens to remove all of the non-zero bits in the mantissa, a NaN can become inf. + * + */ +inline double _bit_truncate_double(double val, double err) { + // Step 1: Extract the sign, exponent and mantissa: + // ------------------------------------------------ + // cast float memory into an int + int64_t* cast_val_ptr = (int64_t*)&val; + // extract the exponent and sign + int64_t val_pre = cast_val_ptr[0] >> POS_EXPONENT_DOUBLE; + // strip sign + int64_t val_pow = val_pre & MASK_EXPONENT_W_O_SIGN_DOUBLE; + int64_t val_s = val_pre >> LEN_EXPONENT_DOUBLE; + // extract mantissa. mask is 2**52 - 1. Add back the implicit 53rd bit + int64_t val_man = (cast_val_ptr[0] & MASK_MANTISSA_DOUBLE) + IMPLICIT_BIT_DOUBLE; + + // Step 2: Scale the error to the integer representation of the mantissa: + // ---------------------------------------------------------------------- + // scale by 2**(52 + 1023 - pow) + int64_t int_err = (int64_t)(err * fast_pow_double(1075 - val_pow)); + // make sure hasn't overflowed. if set to 2**62-1, will surely round to 0. + // must keep err < 2**62 for bit_truncate_double to work + int_err = (int_err & HIGH_BITS_DOUBLE) ? ERR_MAX_DOUBLE : int_err; + + // Step 3: Truncate the mantissa: + // ------------------------------ + int64_t tr_man = bit_truncate_64(val_man, int_err); + + // Step 4: Put it back together: + // ----------------------------- + // count leading zeros + int64_t z_count = count_zeros_64(tr_man); + // adjust power after truncation to account for loss of implicit bit + val_pow -= z_count - 11; + // shift mantissa by same amount, remove implicit bit + tr_man = (tr_man << (z_count - 11)) & MASK_MANTISSA_DOUBLE; + // round to zero case + val_pow = ((z_count != 64) ? 
val_pow : 0); + // restore sign and exponent + int64_t tr_val = tr_man | ((val_pow | (val_s << 11)) << 52); + // cast back to double + double* tr_val_ptr = (double*)&tr_val; + + return tr_val_ptr[0]; +} diff --git a/caput/truncate.pyx b/caput/truncate.pyx index 93424181..6df8d346 100644 --- a/caput/truncate.pyx +++ b/caput/truncate.pyx @@ -1,23 +1,56 @@ """Routines for truncating data to a specified precision.""" +# cython: language_level=3 + cimport cython from cython.parallel import prange import numpy as np cimport numpy as cnp +cdef extern from "truncate.hpp": + inline int bit_truncate(int val, int err) nogil + +cdef extern from "truncate.hpp": + inline long bit_truncate_64(long val, long err) nogil cdef extern from "truncate.hpp": - inline float bit_truncate_float(float val, float err) nogil + inline float _bit_truncate_float(float val, float err) nogil +cdef extern from "truncate.hpp": + inline double _bit_truncate_double(double val, double err) nogil + ctypedef double complex complex128 cdef extern from "complex.h" nogil: double cabs(complex128) -def bit_truncate(float val, float err): +def bit_truncate_int(int val, int err): + """ + Bit truncation of a 32bit integer. + + Truncate the precision of `val` by rounding to a multiple of a power of + two, keeping error less than or equal to `err`. + + Made available for testing. + """ + return bit_truncate(val, err) + +def bit_truncate_long(long val, long err): + """ + Bit truncation of a 64bit integer. + + Truncate the precision of `val` by rounding to a multiple of a power of + two, keeping error less than or equal to `err`. + + Made available for testing. + """ + return bit_truncate_64(val, err) + + +def bit_truncate_float(float val, float err): """Truncate using a fixed error. Parameters @@ -33,16 +66,90 @@ def bit_truncate(float val, float err): The truncated value. """ - return bit_truncate_float(val, err) + return _bit_truncate_float(val, err) + + +def bit_truncate_double(double val, double err): + """Truncate using a fixed error. + + Parameters + ---------- + val + The value to truncate. + err + The absolute precision to allow. + + Returns + ------- + val + The truncated value. + """ + + return _bit_truncate_double(val, err) + + +def bit_truncate_weights(val, inv_var, fallback): + if val.dtype == np.float32 and inv_var.dtype == np.float32: + return bit_truncate_weights_float(val, inv_var, fallback) + if val.dtype == np.float64 and inv_var.dtype == np.float64: + return bit_truncate_weights_double(val, inv_var, fallback) + else: + raise RuntimeError(f"Can't truncate data of type {val.dtype}/{inv_var.dtype} " + f"(expected float32 or float64).") @cython.boundscheck(False) @cython.wraparound(False) -def bit_truncate_weights(float[::1] val, float[::1] inv_var, float fallback): +def bit_truncate_weights_float(float[:] val, float[:] inv_var, float fallback): """Truncate using a set of inverse variance weights. Giving the error as an inverse variance is particularly useful for data analysis. + N.B. non-contiguous arrays are supported in order to allow real and imaginary parts + of numpy arrays to be truncated without making a copy. + + Parameters + ---------- + val + The array of values to truncate the precision of. These values are modified in place. + inv_var + The acceptable precision expressed as an inverse variance. + fallback + A relative precision to use for cases where the inv_var is zero. + + Returns + ------- + val + The modified array. This shares the same underlying memory as the input. 
+ """ + cdef Py_ssize_t n = val.shape[0] + cdef Py_ssize_t i = 0 + + if val.ndim != 1: + raise ValueError("Input array must be 1-d.") + if inv_var.shape[0] != n: + raise ValueError( + f"Weight and value arrays must have same shape ({inv_var.shape[0]} != {n})" + ) + + for i in prange(n, nogil=True): + if inv_var[i] != 0: + val[i] = _bit_truncate_float(val[i], 1.0 / inv_var[i]**0.5) + else: + val[i] = _bit_truncate_float(val[i], fallback * val[i]) + + return np.asarray(val) + +@cython.boundscheck(False) +@cython.wraparound(False) +def bit_truncate_weights_double(double[:] val, double[:] inv_var, double fallback): + """Truncate array of doubles using a set of inverse variance weights. + + Giving the error as an inverse variance is particularly useful for data analysis. + + N.B. non-contiguous arrays are supported in order to allow real and imaginary parts + of numpy arrays to be truncated without making a copy. + Parameters ---------- val @@ -69,18 +176,29 @@ def bit_truncate_weights(float[::1] val, float[::1] inv_var, float fallback): for i in prange(n, nogil=True): if inv_var[i] != 0: - val[i] = bit_truncate_float(val[i], 1.0 / inv_var[i]**0.5) + val[i] = _bit_truncate_double(val[i], 1.0 / inv_var[i]**0.5) else: - val[i] = bit_truncate_float(val[i], fallback * val[i]) + val[i] = _bit_truncate_double(val[i], fallback * val[i]) return np.asarray(val) +def bit_truncate_relative(val, prec): + if val.dtype == np.float32: + return bit_truncate_relative_float(val, prec) + if val.dtype == np.float64: + return bit_truncate_relative_double(val, prec) + else: + raise RuntimeError(f"Can't truncate data of type {val.dtype} (expected float32 or float64).") + @cython.boundscheck(False) @cython.wraparound(False) -def bit_truncate_relative(float[::1] val, float prec): +def bit_truncate_relative_float(float[:] val, float prec): """Truncate using a relative tolerance. + N.B. non-contiguous arrays are supported in order to allow real and imaginary parts + of numpy arrays to be truncated without making a copy. + Parameters ---------- val @@ -97,14 +215,43 @@ def bit_truncate_relative(float[::1] val, float prec): cdef Py_ssize_t i = 0 for i in prange(n, nogil=True): - val[i] = bit_truncate_float(val[i], prec * val[i]) + val[i] = _bit_truncate_float(val[i], prec * val[i]) return np.asarray(val) @cython.boundscheck(False) @cython.wraparound(False) -def bit_truncate_max_complex(complex128[:, ::1] val, float prec, float prec_max_row): +def bit_truncate_relative_double(cnp.float64_t[:] val, cnp.float64_t prec): + """Truncate doubles using a relative tolerance. + + N.B. non-contiguous arrays are supported in order to allow real and imaginary parts + of numpy arrays to be truncated without making a copy. + + Parameters + ---------- + val + The array of double values to truncate the precision of. These values are modified in place. + prec + The fractional precision required. + + Returns + ------- + val + The modified array. This shares the same underlying memory as the input. + """ + cdef Py_ssize_t n = val.shape[0] + cdef Py_ssize_t i = 0 + + for i in prange(n, nogil=True): + val[i] = _bit_truncate_double(val[i], prec * val[i]) + + return np.asarray(val, dtype=np.float64) + + +@cython.boundscheck(False) +@cython.wraparound(False) +def bit_truncate_max_complex(complex128[:, :] val, float prec, float prec_max_row): """Truncate using a relative per element and per the maximum of the last dimension. 
This scheme allows elements to be truncated based on their own value and a @@ -155,7 +302,7 @@ def bit_truncate_max_complex(complex128[:, ::1] val, float prec, float prec_max_ vr = val[i, j].real vi = val[i, j].imag - val[i, j].real = bit_truncate_float(vr, abs_prec) - val[i, j].imag = bit_truncate_float(vi, abs_prec) + val[i, j].real = _bit_truncate_float(vr, abs_prec) + val[i, j].imag = _bit_truncate_float(vi, abs_prec) - return np.asarray(val) \ No newline at end of file + return np.asarray(val) From 0b6210129f642d69e37d162eb6c5879b8ea057fb Mon Sep 17 00:00:00 2001 From: Tristan Pinsonneault-Marotte Date: Mon, 7 Mar 2022 14:46:50 -0800 Subject: [PATCH 4/7] test(truncate): Add special cases. --- caput/tests/test_truncate.py | 50 +++++++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/caput/tests/test_truncate.py b/caput/tests/test_truncate.py index 007f1aae..e1c4f7f1 100644 --- a/caput/tests/test_truncate.py +++ b/caput/tests/test_truncate.py @@ -5,25 +5,73 @@ def test_bit_truncate(): assert truncate.bit_truncate_int(129, 1) == 128 + assert truncate.bit_truncate_int(-129, 1) == -128 + assert truncate.bit_truncate_int(1, 1) == 0 assert truncate.bit_truncate_long(129, 1) == 128 + assert truncate.bit_truncate_long(-129, 1) == -128 assert truncate.bit_truncate_long(576460752303423489, 1) == 576460752303423488 assert ( truncate.bit_truncate_long(4520628863461491, 140737488355328) == 4503599627370496 ) + assert truncate.bit_truncate_long(1, 1) == 0 assert truncate.bit_truncate_int(54321, 0) == 54321 assert truncate.bit_truncate_long(576460752303423489, 0) == 576460752303423489 + # special cases + assert truncate.bit_truncate_int(129, 0) == 129 + assert truncate.bit_truncate_int(0, 1) == 0 + assert truncate.bit_truncate_int(129, -1) == 0 + assert truncate.bit_truncate_long(129, 0) == 129 + assert truncate.bit_truncate_long(0, 1) == 0 + assert truncate.bit_truncate_long(129, -1) == 0 + def test_truncate_float(): assert truncate.bit_truncate_float(32.121, 1) == 32 - # fails assert truncate.bit_truncate_float(float(0.010101), 0) == float(0.010101) + assert truncate.bit_truncate_float(-32.121, 1) == -32 + assert truncate.bit_truncate_float(32.125, 0) == 32.125 + assert truncate.bit_truncate_float(1, 1) == 0 + + assert truncate.bit_truncate_float(1 + 1 / 1024, 1 / 2048) == 1 + 1 / 1024 + assert ( + truncate.bit_truncate_float(1 + 1 / 1024 + 1 / 2048, 1 / 2048) == 1 + 2 / 1024 + ) + assert truncate.bit_truncate_double(1 + 1 / 1024, 1 / 2048) == 1 + 1 / 1024 + assert ( + truncate.bit_truncate_double(1 + 1 / 1024 + 1 / 2048, 1 / 2048) == 1 + 2 / 1024 + ) assert truncate.bit_truncate_double(32.121, 1) == 32 + assert truncate.bit_truncate_double(-32.121, 1) == -32 + assert truncate.bit_truncate_double(32.121, 0) == 32.121 + assert truncate.bit_truncate_double(0.9191919191, 0.001) == 0.919921875 assert truncate.bit_truncate_double(0.9191919191, 0) == 0.9191919191 + assert truncate.bit_truncate_double(0.010101, 0) == 0.010101 + assert truncate.bit_truncate_double(1, 1) == 0 + + # special cases + assert truncate.bit_truncate_float(32.121, -1) == 0 + assert truncate.bit_truncate_double(32.121, -1) == 0 + assert truncate.bit_truncate_float(32.121, np.inf) == 0 + assert truncate.bit_truncate_double(32.121, np.inf) == 0 + assert truncate.bit_truncate_float(32.121, np.nan) == 0 + assert truncate.bit_truncate_double(32.121, np.nan) == 0 + assert truncate.bit_truncate_float(np.inf, 1) == np.inf + assert truncate.bit_truncate_double(np.inf, 1) == np.inf + assert 
np.isnan(truncate.bit_truncate_float(np.nan, 1)) + assert np.isnan(truncate.bit_truncate_double(np.nan, 1)) + assert truncate.bit_truncate_float(np.inf, np.inf) == 0 + assert truncate.bit_truncate_double(np.inf, np.inf) == 0 + assert truncate.bit_truncate_float(np.inf, np.nan) == 0 + assert truncate.bit_truncate_double(np.inf, np.nan) == 0 + assert truncate.bit_truncate_float(np.nan, np.nan) == 0 + assert truncate.bit_truncate_double(np.nan, np.nan) == 0 + assert truncate.bit_truncate_float(np.nan, np.inf) == 0 + assert truncate.bit_truncate_double(np.nan, np.inf) == 0 def test_truncate_array(): From 5ffd5696cccb88affde5ab6a3f21469d3565e216 Mon Sep 17 00:00:00 2001 From: Richard Shaw Date: Fri, 4 Mar 2022 18:20:35 -0800 Subject: [PATCH 5/7] fix(profile): workaround missing metrics on macOS --- caput/profile.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/caput/profile.py b/caput/profile.py index 63276fa2..0026d27e 100644 --- a/caput/profile.py +++ b/caput/profile.py @@ -312,7 +312,7 @@ def start(self): self.cpu_percent() self._start_memory = self.memory_full_info().uss if psutil.MACOS: - self._start_memory = psutil.disk_io_counters() + self._start_disk_io = psutil.disk_io_counters() else: self._start_disk_io = self.io_counters() @@ -426,14 +426,14 @@ def bytes2human(num): cpu_times["system"], cpu_times["children_user"], cpu_times["children_system"], - cpu_times["iowait"], + cpu_times.get("iowait", "-"), cpu_percent, disk_io["read_count"], disk_io["write_count"], disk_io["read_bytes"], disk_io["write_bytes"], - disk_io["read_chars"], - disk_io["write_chars"], + disk_io.get("read_chars", "-"), + disk_io.get("write_chars", "-"), memory, available_memory, used_memory, @@ -443,6 +443,8 @@ def bytes2human(num): @property def cpu_count(self): """Number of cores available to this process.""" + if psutil.MACOS: + return psutil.cpu_count() return len(self.cpu_affinity()) @property From 5d67429192e591e93d269632bfb8391336501cfc Mon Sep 17 00:00:00 2001 From: Richard Shaw Date: Fri, 4 Mar 2022 18:58:27 -0800 Subject: [PATCH 6/7] ci: update workflow and use bitshuffle binary wheels --- .github/workflows/main.yml | 15 +++++++++------ setup.py | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 7cc2a13d..9a44b00e 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -15,17 +15,18 @@ jobs: - uses: actions/checkout@v2 - name: Set up Python 3.9 - uses: actions/setup-python@v1 + uses: actions/setup-python@v2 with: python-version: 3.9 - name: Install apt dependencies run: | + sudo apt-get update sudo apt-get install -y libopenmpi-dev openmpi-bin libhdf5-serial-dev - name: Install pip dependencies run: | - pip install pylint==2.7.0 pylint-ignore flake8 pytest black mpi4py pyinstrument pytest-lazy-fixture + pip install pylint==2.7.0 pylint-ignore flake8 pytest black mpi4py pyinstrument psutil pytest-lazy-fixture pip install -r requirements.txt python setup.py develop pip install .[compression] @@ -48,19 +49,20 @@ jobs: - name: Install apt dependencies run: | + sudo apt-get update sudo apt-get install -y libopenmpi-dev openmpi-bin libhdf5-serial-dev - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - name: Install pip dependencies run: | - pip install --no-binary=h5py h5py + pip install h5py pip install -r requirements.txt pip install zarr==2.8.1 - pip install mpi4py 
numcodecs==0.7.3 bitshuffle@git+https://github.com/kiyo-masui/bitshuffle.git psutil + pip install mpi4py numcodecs==0.7.3 bitshuffle pip install pytest pytest-lazy-fixture python setup.py develop @@ -82,12 +84,13 @@ jobs: - uses: actions/checkout@v2 - name: Set up Python 3.9 - uses: actions/setup-python@v1 + uses: actions/setup-python@v2 with: python-version: 3.9 - name: Install apt dependencies run: | + sudo apt-get update sudo apt-get install -y libhdf5-serial-dev - name: Install pip dependencies diff --git a/setup.py b/setup.py index d5ef0c44..5c507d52 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,7 @@ extras_require={ "mpi": ["mpi4py>=1.3"], "compression": [ - "bitshuffle @ git+https://github.com/kiyo-masui/bitshuffle.git", + "bitshuffle", "numcodecs==0.7.3", "zarr==2.8.1", ], From 75987f102e77c0acca46bd98f24853faebb6ba43 Mon Sep 17 00:00:00 2001 From: Richard Shaw Date: Mon, 30 May 2022 11:17:24 -0700 Subject: [PATCH 7/7] doc: restrict sphinx version to workaround bug in v5.0 --- doc/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/requirements.txt b/doc/requirements.txt index 23ae6ce9..e3de9330 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -1,4 +1,4 @@ -Sphinx >= 4.0 +Sphinx >= 4.0, < 5.0 sphinx_rtd_theme # funcsigs required by mock, which apparently does not have its dependencies # setup correctly.
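
A minimal usage sketch of the truncation helpers touched by this series, assuming the package has been built with its Cython extension. The function names and the dtype-dispatch/in-place behaviour follow the diffs above; the concrete array values and tolerances are illustrative only, and the 32.0 result is the value asserted in the tests added here::

    import numpy as np
    from caput import truncate

    # Scalar truncation with an absolute error bound (new float64 path).
    truncate.bit_truncate_double(32.121, 1.0)   # -> 32.0, as in test_truncate_float

    # Relative truncation dispatches on dtype, so float64 input stays float64.
    # The array is modified in place and also returned.
    x = np.array([1.000244140625, 3.141592653589793], dtype=np.float64)
    truncate.bit_truncate_relative(x, 1e-3)

    # Inverse-variance weighted truncation; the relative fallback precision
    # is applied wherever the weight is zero.
    vals = np.array([10.0, 20.0, 30.0], dtype=np.float64)
    inv_var = np.array([1.0e4, 0.0, 1.0e4], dtype=np.float64)
    truncate.bit_truncate_weights(vals, inv_var, 1e-5)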