Skip to content

Commit

Permalink
feat(containers): add flag to allow filtering distributed axis
Browse files Browse the repository at this point in the history
  • Loading branch information
ljgray committed Sep 7, 2022
1 parent 2d3a423 commit 78f6c14
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion draco/core/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2673,6 +2673,7 @@ def copy_datasets_filter(
axis: str,
selection: Union[np.ndarray, list, slice],
exclude_axes: List[str] = None,
allow_distributed: bool = False,
):
"""Copy datasets while filtering a given axis.
Expand All @@ -2689,6 +2690,10 @@ def copy_datasets_filter(
exclude_axes
An optional set of axes that if a dataset contains one means it will
not be copied.
allow_distributed, optional
Allow the filtered axis to be the distributed axis. This is ONLY
valid if filtering is occuring on the local rank only, and mainly
exists for compatibility
"""
exclude_axes_set = set(exclude_axes) if exclude_axes else set()

Expand Down Expand Up @@ -2717,7 +2722,7 @@ def copy_datasets_filter(

if isinstance(item, memh5.MemDatasetDistributed):

if item.distributed_axis == axis_ind:
if (item.distributed_axis == axis_ind) and not allow_distributed:
raise RuntimeError(
f"Cannot redistristribute dataset={item.name} along "
f"axis={axis_ind} as it is distributed."
Expand Down

0 comments on commit 78f6c14

Please sign in to comment.