feat(memh5): modify deep_group_copy to optionally make a true deep copy.
Provides deep_group_copy with arguments to actually deep copy datasets,
with an option to shallow copy 'shared' datasets. Also, allow
distributed datasets to be copied only in the memory -> memory case.
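A minimal usage sketch of the new arguments (a hypothetical call with made-up dataset names, assuming two in-memory trees):

```python
import numpy as np
from caput import memh5

src = memh5.MemGroup()
src.create_dataset("vis", data=np.arange(10))
src.create_dataset("gain", data=np.ones(4))

dest = memh5.MemGroup()

# Deep copy every dataset except "gain", which stays shared (shallow copy)
memh5.deep_group_copy(src, dest, deep_copy_dsets=True, shared=["gain"])
```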
ljgray committed Mar 14, 2023
1 parent 1ba2ec7 commit dfb48b1
81 changes: 74 additions & 7 deletions caput/memh5.py
@@ -49,6 +49,9 @@
- :py:meth:`deep_group_copy`
"""

from __future__ import annotations
from typing import Any

from collections.abc import Mapping
import datetime
import warnings
@@ -2425,6 +2428,8 @@ def deep_group_copy(
file_format=fileformats.HDF5,
skip_distributed=False,
postprocess=None,
deep_copy_dsets=False,
shared=[],
):
"""
Copy full data tree from one group to another.
@@ -2465,6 +2470,12 @@
postprocess : function, optional
A function that is called on each node with the source and destination
entries, and can modify either.
deep_copy_dsets : bool, optional
Explicitly deep copy all datasets. This only alters behaviour when copying
from memory to memory. XXX: enabling this in places where it is not currently
enabled could break legacy code, so be very careful.
shared : list, optional
List of dataset names to share (i.e. shallow copy) when `deep_copy_dsets` is
True. Otherwise, this has no effect.

Returns
-------
@@ -2482,7 +2493,7 @@

# Prepare a dataset for writing out, applying selections and transforming any
# datatypes
-# Returns: (dtype, shape, data_to_write)
# Returns: dict(dtype, shape, data_to_write)
def _prepare_dataset(dset):
# Look for a selection for this dataset (also try without the leading "/")
try:
@@ -2508,7 +2519,7 @@ def _prepare_dataset(dset):
else:
# If we get here, we should create the dataset, but not write any data into it (i.e. the returned data entry is None)
distributed_dset_names.append(dset.name)
-return dset.dtype, dset.shape, None
return {"dtype": dset.dtype, "shape": dset.shape, "data": None}

# Extract the data for the selection
data = dset[selection]
@@ -2534,7 +2545,14 @@ def _prepare_dataset(dset):
# needed until fixed: https://github.com/mpi4py/mpi4py/issues/177
data = ensure_native_byteorder(data)

-return data.dtype, data.shape, data
dset_args = {"dtype": data.dtype, "shape": data.shape, "data": data}
# If we're copying memory to memory we can allow distributed datasets
if not to_file and isinstance(dset, MemDatasetDistributed):
dset_args.update(
{"distributed": True, "distributed_axis": dset.distributed_axis}
)

return dset_args

# get compression options/chunking for this dataset
# Returns dict of compression and chunking arguments for create_dataset
@@ -2581,14 +2599,16 @@ def _prepare_compression_args(dset):
g2.create_group(key)
stack += [entry[k] for k in sorted(entry, reverse=True)]
else: # Is a dataset
-dtype, shape, data = _prepare_dataset(entry)
dset_args = _prepare_dataset(entry)
compression_kwargs = _prepare_compression_args(entry)

if deep_copy_dsets and key not in shared:
# Make a deep copy of the dataset
dset_args["data"] = deep_copy_dataset(dset_args["data"])

g2.create_dataset(
key,
-shape=shape,
-dtype=dtype,
-data=data,
**dset_args,
**compression_kwargs,
)
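The prepared dict is splatted straight into `create_dataset`, so optional keys (such as the distributed ones built above) are passed only when present. A small sketch of the pattern, with hypothetical values:

```python
import numpy as np
from caput import memh5

g2 = memh5.MemGroup()

# Hypothetical output of _prepare_dataset for an ordinary in-memory dataset.
# A memory -> memory copy of a distributed dataset would additionally carry
# "distributed": True and "distributed_axis" (and would require MPI).
dset_args = {"dtype": np.float64, "shape": (16, 4), "data": np.zeros((16, 4))}

g2.create_dataset("vis", **dset_args)
```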

@@ -2601,6 +2621,50 @@ def _prepare_compression_args(dset):
return distributed_dset_names if skip_distributed else None


def deep_copy_dataset(dset: Any, order: str = "A") -> Any:
"""Return a deep copy of a dataset.
If the dataset is a ndarray or subclass, the memory
layout can be set.
Parameters
----------
dset
Dataset to deep copy
order
Controls the memory layout of the copy, for any dataset which
takes this parameter (np.ndarray and subclasses)
Returns
-------
dset_copy
A deep copy of the dataset
"""
if isinstance(dset, np.ndarray):
# Set the order
dset_copy = dset.copy(order=order)

_o = np.dtype(object)
_d = np.dtype(dset_copy.dtype)
# `ndarray.copy` copies the array buffer but not the python
# objects it references, so an explicit deep copy is needed if
# the array contains mutable python objects
if _d is _o:
dset_copy = deepcopy(dset_copy)

elif _d.names is not None:
# This is a structured dtype, so check each field
# (`fields[name][0]` is the dtype of that field)
for name in _d.names:
if _d.fields[name][0] is _o:
dset_copy[name] = deepcopy(dset_copy[name])

else:
# Deep copy whatever object was provided
dset_copy = deepcopy(dset)

return dset_copy
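
The object-dtype branches above exist because `ndarray.copy` duplicates the array buffer but not the Python objects it references. A quick illustration of that standard NumPy behaviour:

```python
import numpy as np
from copy import deepcopy

a = np.empty(2, dtype=object)
a[0], a[1] = [1, 2], [3]

b = a.copy()        # copies the array of references only
b[0].append(99)
print(a[0])         # [1, 2, 99] -- the shallow copy shares the inner list

c = deepcopy(a)     # what deep_copy_dataset falls back to for object dtypes
c[0].append(-1)
print(a[0])         # [1, 2, 99] -- unchanged this time
```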


def format_abs_path(path):
"""Return absolute path string, formated without any extra '/'s."""
if not posixpath.isabs(path):
@@ -2822,6 +2886,9 @@ def _copy_from_file(h5group, memgroup, selections=None):
return group


# Some extra functions for types. Should maybe move elsewhere


def bytes_to_unicode(s):
"""Ensure that a string (or collection of) are unicode.
