diff --git a/caput/memh5.py b/caput/memh5.py index d13e5e5e..94f2f8b9 100644 --- a/caput/memh5.py +++ b/caput/memh5.py @@ -461,6 +461,7 @@ def from_hdf5(cls, filename, distributed=False, hints=True, comm=None, **kwargs) distributed = False if not distributed or not hints: + kwargs["mode"] = "r" with h5py.File(filename, **kwargs) as f: self = cls(distributed=distributed, comm=comm) deep_group_copy(f, self) @@ -469,7 +470,7 @@ def from_hdf5(cls, filename, distributed=False, hints=True, comm=None, **kwargs) return self - def to_hdf5(self, filename, hints=True, **kwargs): + def to_hdf5(self, filename, mode="w", hints=True, **kwargs): """Replicate object on disk in an hdf5 file. Any keyword arguments are passed on to the constructor for `h5py.File`. @@ -484,13 +485,13 @@ def to_hdf5(self, filename, hints=True, **kwargs): """ if not self.distributed: - with h5py.File(filename, **kwargs) as f: + with h5py.File(filename, mode, **kwargs) as f: deep_group_copy(self, f) else: if h5py.get_config().mpi: - _distributed_group_to_hdf5_parallel(self, filename, **kwargs) + _distributed_group_to_hdf5_parallel(self, filename, mode, **kwargs) else: - _distributed_group_to_hdf5_serial(self, filename, **kwargs) + _distributed_group_to_hdf5_serial(self, filename, mode, **kwargs) def create_group(self, name): """Create a group within the storage tree.""" @@ -1301,7 +1302,7 @@ def __init__(self, data_group=None, distributed=False, comm=None): # Otherwise, presume it is an HDF5 Group-like object (which includes # MemGroup and h5py.Group). else: - data_group, toclose = get_h5py_File(data_group) + data_group, toclose = get_h5py_File(data_group, mode="a") if distributed and isinstance(data_group, h5py.Group): raise ValueError( @@ -1478,8 +1479,11 @@ def from_file( if isinstance(file_, h5py.Group): file_ = file_.filename + if "mode" in kwargs: + del kwargs["mode"] + data = MemGroup.from_hdf5( - file_, distributed=distributed, comm=comm, mode="r", **kwargs + file_, distributed=distributed, comm=comm, **kwargs ) toclose = False else: @@ -1488,6 +1492,7 @@ def from_file( data = file_ toclose = False else: + kwargs.setdefault("mode", "a") data = h5py.File(file_, **kwargs) toclose = True @@ -2059,7 +2064,7 @@ def format_abs_path(path): return out -def _distributed_group_to_hdf5_serial(group, fname, hints=True, **kwargs): +def _distributed_group_to_hdf5_serial(group, fname, mode, hints=True, **kwargs): """Private routine to copy full data tree from distributed memh5 object into an HDF5 file. @@ -2073,18 +2078,13 @@ def _distributed_group_to_hdf5_serial(group, fname, hints=True, **kwargs): comm = group.comm - # Create a copy of the kwargs with no mode argument so that we can override it - kwargs_nomode = kwargs.copy() - if "mode" in kwargs: - del kwargs_nomode["mode"] - # Create group (or file) if comm.rank == 0: # If this is the root group, create the file and copy the file level # attrs if group.name == "/": - with h5py.File(fname, "w", **kwargs) as f: + with h5py.File(fname, mode, **kwargs) as f: copyattrs(group.attrs, f.attrs) if hints: @@ -2098,7 +2098,8 @@ def _distributed_group_to_hdf5_serial(group, fname, hints=True, **kwargs): comm.Barrier() - # Write out groups and distributed datasets, these operations must be done collectively + # Write out groups and distributed datasets, these operations must be done + # collectively # Sort to ensure insertion order is identical for key in sorted(group): @@ -2106,7 +2107,7 @@ def _distributed_group_to_hdf5_serial(group, fname, hints=True, **kwargs): # Groups are written out by recursing if is_group(entry): - _distributed_group_to_hdf5_serial(entry, fname, **kwargs) + _distributed_group_to_hdf5_serial(entry, fname, mode, **kwargs) # Write out distributed datasets (only the data, the attributes are written below) elif isinstance(entry, MemDatasetDistributed): @@ -2126,7 +2127,7 @@ def _distributed_group_to_hdf5_serial(group, fname, hints=True, **kwargs): # Write out common datasets, and the attributes on distributed datasets if comm.rank == 0: - with h5py.File(fname, "r+", **kwargs_nomode) as f: + with h5py.File(fname, "r+", **kwargs) as f: for key, entry in group.items(): @@ -2174,7 +2175,7 @@ def _distributed_group_to_hdf5_serial(group, fname, hints=True, **kwargs): comm.Barrier() -def _distributed_group_to_hdf5_parallel(group, fname, hints=True, **kwargs): +def _distributed_group_to_hdf5_parallel(group, fname, mode, hints=True, **kwargs): """Private routine to copy full data tree from distributed memh5 object into an HDF5 file. This version paralellizes all IO.""" @@ -2250,7 +2251,6 @@ def _copy_to_file(memgroup, h5group): copyattrs(item.attrs, dset.attrs) # Open file on all ranks - mode = kwargs.get("mode", "w") with misc.open_h5py_mpi(fname, mode, comm=group.comm) as f: if not f.is_mpi: raise RuntimeError("Could not create file %s in MPI mode" % fname) diff --git a/caput/tests/test_memh5.py b/caput/tests/test_memh5.py index 966e0820..8ccb7066 100644 --- a/caput/tests/test_memh5.py +++ b/caput/tests/test_memh5.py @@ -87,7 +87,7 @@ class TestH5Files(unittest.TestCase): fname = "tmp_test_memh5.h5" def setUp(self): - with h5py.File(self.fname) as f: + with h5py.File(self.fname, "w") as f: l1 = f.create_group("level1") l2 = l1.create_group("level2") d1 = l1.create_dataset("large", data=np.arange(100)) @@ -118,9 +118,8 @@ def assertAttrsEqual(self, a, b): self.assertEqual(this_a, this_b) def test_h5_sanity(self): - f = h5py.File(self.fname) - self.assertGroupsEqual(f, f) - f.close() + with h5py.File(self.fname, "r") as f: + self.assertGroupsEqual(f, f) def test_to_from_hdf5(self): m = memh5.MemGroup.from_hdf5(self.fname) @@ -189,10 +188,10 @@ def test_io(self): # self.assertIsInstance(tsc3['dset'].parent, TempSubClass) tsc3.close() - with memh5.MemDiskGroup.from_file(self.fname, ondisk=True) as tsc4: + with memh5.MemDiskGroup.from_file(self.fname, mode="r", ondisk=True) as tsc4: self.assertRaises(IOError, h5py.File, self.fname, "w") - with memh5.MemDiskGroup.from_file(self.fname, ondisk=False) as tsc4: + with memh5.MemDiskGroup.from_file(self.fname, mode="r", ondisk=False) as tsc4: f = h5py.File(self.fname, "w") f.close()