Skip to content

Commit

Permalink
feat(memh5): add '.copy' method to MemDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
ljgray committed Mar 14, 2023
1 parent dfb48b1 commit 1fb0b6d
Showing 1 changed file with 45 additions and 6 deletions.
51 changes: 45 additions & 6 deletions caput/memh5.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@

import numpy as np
import h5py
from copy import deepcopy

from . import fileformats
from . import mpiutil
Expand Down Expand Up @@ -984,6 +985,49 @@ def __init__(self, **kwargs):
def _group_class(self):
return MemGroup

def copy(self, order: str = "A", shallow: bool = False) -> MemDataset:
"""Create a new MemDataset from an existing one.
This creates a deep copy by default.
Parameters
----------
order
Memory layout of copied data. See
https://numpy.org/doc/stable/reference/generated/numpy.copy.html
shallow
True if this should be a shallow copy
Returns
-------
new_data
deep copy of this dataset
"""
new_data = self.__class__.__new__(self.__class__)
super(MemDataset, new_data).__init__(
name=self.name, storage_root=self._storage_root
)

_copy = deepcopy if not shallow else lambda x: x
# Call the properties rather than the underlying values so that an error
# is properly raised if they are not implemented. Blindly use deepcopy as
# we don't make assumptions about immutability
new_data._chunks = _copy(self.chunks)
new_data._compression = _copy(self.compression)
new_data._compression_opts = _copy(self.compression_opts)
new_data._attrs = _copy(self._attrs)

if shallow:
new_data._data = self._data
else:
new_data._data = deep_copy_dataset(self._data, order=order)

return new_data

def __deepcopy__(self, memo, /) -> MemDataset:
"""Called when copy.deepcopy is called on this class"""
return self.copy()

def view(self):
cls = self.__class__
out = cls.__new__(cls)
Expand Down Expand Up @@ -2269,12 +2313,7 @@ def redistribute(self, dist_axis):
def attrs2dict(attrs):
"""Safely copy an h5py attributes object to a dictionary."""

out = {}
for key, value in attrs.items():
if isinstance(value, np.ndarray):
value = value.copy()
out[key] = value
return out
return {k: deepcopy(v) for k, v in attrs.items()}


def is_group(obj):
Expand Down

0 comments on commit 1fb0b6d

Please sign in to comment.