Skip to content

Commit

Permalink
feat(memh5): consistently guess file format from name
Browse files Browse the repository at this point in the history
  • Loading branch information
jrs65 committed Mar 12, 2022
1 parent 2622060 commit 939a926
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 17 deletions.
31 changes: 19 additions & 12 deletions caput/memh5.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,12 +479,13 @@ def from_file(
selections=None,
convert_dataset_strings=False,
convert_attribute_strings=True,
file_format=fileformats.HDF5,
file_format=None,
**kwargs,
):
"""Create a new instance by copying from a file group.
Any keyword arguments are passed on to the constructor for `h5py.File` or `zarr.File`.
Any keyword arguments are passed on to the constructor for `h5py.File` or
`zarr.File`.
Parameters
----------
Expand All @@ -505,8 +506,8 @@ def from_file(
Try and convert attribute string types to unicode. Default is `True`.
convert_dataset_strings : bool, optional
Try and convert dataset string types to unicode. Default is `False`.
file_format : `fileformats.FileFormat`
File format to use. Default `fileformats.HDF5`.
file_format : `fileformats.FileFormat`, optional
File format to use. Default is `None`, i.e. guess from the name.
Returns
-------
Expand All @@ -517,6 +518,9 @@ def from_file(
if comm is None:
comm = mpiutil.world

if file_format is None:
file_format = fileformats.guess_file_format(filename)

if comm is None:
if distributed:
warnings.warn(
Expand Down Expand Up @@ -595,7 +599,7 @@ def to_file(
hints=True,
convert_attribute_strings=True,
convert_dataset_strings=False,
file_format=fileformats.HDF5,
file_format=None,
**kwargs,
):
"""Replicate object on disk in an hdf5 or zarr file.
Expand All @@ -614,9 +618,12 @@ def to_file(
understands. Default is `True`.
convert_dataset_strings : bool, optional
Try and convert dataset string types to bytestrings. Default is `False`.
file_format : `fileformats.FileFormat`
File format to use. Default `fileformats.HDF5`.
file_format : `fileformats.FileFormat`, optional
File format to use. Default is `None`, i.e. guess from the name.
"""
if file_format is None:
file_format = fileformats.guess_file_format(filename)

if not self.distributed:
with file_format.open(filename, mode, **kwargs) as f:
deep_group_copy(
Expand All @@ -627,9 +634,6 @@ def to_file(
file_format=file_format,
)
else:
if file_format is None:
file_format = fileformats.guess_file_format(filename)

_distributed_group_to_file(
self,
filename,
Expand Down Expand Up @@ -1665,7 +1669,7 @@ def from_file(
detect_subclass=True,
convert_attribute_strings=None,
convert_dataset_strings=None,
file_format=fileformats.HDF5,
file_format=None,
**kwargs,
):
"""Create data object from analysis hdf5 file, store in memory or on disk.
Expand Down Expand Up @@ -1704,11 +1708,14 @@ def from_file(
Axis selections can be given to only read a subset of the containers. A
slice can be given, or a list of specific array indices for that axis.
file_format : `fileformats.FileFormat`
File format to use. Default `fileformats.HDF5`.
File format to use. Default is `None`, i.e. guess from file name.
**kwargs : any other arguments
Any additional keyword arguments are passed to :class:`h5py.File`'s
constructor if *file_* is a filename and silently ignored otherwise.
"""
if file_format is None and not is_group(file_):
file_format = fileformats.guess_file_format(file_)

if file_format == fileformats.Zarr and not zarr_available:
raise RuntimeError("Unable to read zarr file, please install zarr.")

Expand Down
8 changes: 3 additions & 5 deletions caput/tests/test_memh5.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,11 @@ def test_to_from_file(test_file, file_open_function, file_format):
with file_open_function(test_file, "r") as f:
assertGroupsEqual(f, m)

m.to_file(
test_file + ".new",
file_format=file_format,
)
new_name = f"new.{test_file}"
m.to_file(new_name, file_format=file_format)

# Check that written file has same structure
with file_open_function(test_file + ".new", "r") as f:
with file_open_function(new_name, "r") as f:
assertGroupsEqual(f, m)


Expand Down

0 comments on commit 939a926

Please sign in to comment.