Skip to content

Commit

Permalink
feat(transform): add task to perform downselection on axes
Browse files Browse the repository at this point in the history
  • Loading branch information
ljgray committed May 2, 2023
1 parent b27677e commit bff7776
Showing 1 changed file with 131 additions and 0 deletions.
131 changes: 131 additions & 0 deletions draco/analysis/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1268,6 +1268,137 @@ def process_finish(self) -> Union[containers.SiderealStream, containers.RingMap]
return data


class Downselect(task.SingleTask):
    """Apply axis downselections to a container.

    Apply slice or `np.take` operations across multiple axes of a container.
    The datasets to apply the selection to can be specified, and any datasets
    not included will not be initialized in the output container.

    If a dataset is distributed, there must be at least one axis not included
    in the downselection, since the selection cannot be applied along the
    axis the dataset is distributed over.

    Attributes
    ----------
    selections : dict
        Dict formatted as {<ax_name>:<selection>,} for any axes
        that should be downselected. Selections can be either slices
        or a list/tuple/array of indices.
    datasets : list, optional
        List of dataset names to select. By default, all datasets will be
        included, and any which include the specified axes will be sliced.
    distribute_over : str, optional
        Optionally prefer to redistribute over a specific axis. This should be
        the name of the axis. It is only honoured if that axis is not itself
        being downselected. By default, the best possible axis will be
        selected based on the selections to be applied.
    """

    selections = config.Property(proptype=dict)
    datasets = config.Property(proptype=list, default=[])
    distribute_over = config.Property(proptype=str, default=None)

    def process(self, data: containers.ContainerBase) -> containers.ContainerBase:
        """Apply downselections to the container.

        Parameters
        ----------
        data
            Container to process

        Returns
        -------
        out
            Container of same type as the input with specific axis selections.
            Any datasets not included in the selections will not be initialized.

        Raises
        ------
        ValueError
            If a distributed dataset has no axis left to distribute across,
            i.e. every one of its axes is being downselected.
        TypeError
            If a selection is neither a slice nor a list/tuple/array of
            indices.
        """
        # Order axes based on decreasing length to help decide the best
        # redistribution options
        ax_priority = sorted(
            data.index_map.keys(), key=lambda x: len(data.index_map[x]), reverse=True
        )

        # An empty `datasets` list means "select everything", as documented.
        # NOTE(review): assumes `data.datasets` is a mapping of name -> dataset,
        # consistent with the `data.datasets[name]` lookup below.
        names = list(self.datasets) if self.datasets else list(data.datasets.keys())

        # Figure out the axes for the new container by applying the
        # downselections to each axis index_map
        output_axes = {
            ax: self._apply_sel(data.index_map[ax], sel, 0)
            for ax, sel in self.selections.items()
        }

        # Create the output container without initializing any datasets.
        # Add some extra metadata about the selections made
        out = data.__class__(
            axes_from=data, attrs_from=data, skip_datasets=True, **output_axes
        )
        out.attrs["downselections"] = self.selections
        out.attrs["sliced_datasets"] = np.array(names)

        for name in names:
            # Get the dataset
            ds = data.datasets[name]
            # Initialize the dataset in the output container
            out.add_dataset(name)

            ds_axes = list(ds.attrs["axis"])
            arr = ds[:]

            # Index of the axis the input was distributed over; stays None
            # for non-distributed datasets.
            original_ax_id = None

            if isinstance(arr, mpiarray.MPIArray):
                # Move the distributed axis off of any axis being downselected.
                # The helper takes the *name* of the current axis, not its index.
                original_ax_id = arr.axis
                new_ax_id = self._select_dist_axis(
                    ds_axes, ds_axes[original_ax_id], ax_priority
                )
                arr = arr.redistribute(new_ax_id)

            # Apply downselections to the dataset, skipping any selected
            # axes which this dataset does not have
            for ax, sel in self.selections.items():
                if ax not in ds_axes:
                    continue
                arr = self._apply_sel(arr, sel, ds_axes.index(ax))

            if original_ax_id is not None and isinstance(arr, mpiarray.MPIArray):
                # Restore the distribution axis the input dataset used
                arr = arr.redistribute(original_ax_id)

            out[name][:] = arr

        return out

    def _select_dist_axis(
        self, options: list, current_ax: str, ax_priority: list
    ) -> int:
        """Select the best axis to distribute over while downselecting.

        Parameters
        ----------
        options
            Names of the axes of the dataset, in dataset order.
        current_ax
            Name of the axis the dataset is currently distributed over.
        ax_priority
            All container axis names ordered by decreasing length.

        Returns
        -------
        ax_id
            Index into `options` of the axis to distribute over.

        Raises
        ------
        ValueError
            If every axis in `options` is being downselected.
        """
        preferred = self.distribute_over

        # If possible, use the user-specified axis. It cannot be used if it
        # is itself being downselected.
        if (
            preferred is not None
            and preferred in options
            and preferred != current_ax
            and preferred not in self.selections
        ):
            return options.index(preferred)

        # Next, check if we can just keep the current axis
        if current_ax not in self.selections:
            return options.index(current_ax)

        # Otherwise take the longest axis which is not being downselected
        for ax in ax_priority:
            if ax in options and ax not in self.selections:
                return options.index(ax)

        # If we couldn't find a valid axis, throw an exception
        raise ValueError("Could not find an axis to distribute across.")

    @staticmethod
    def _apply_sel(arr, sel, axis: int):
        """Apply a selection to a single axis of an array.

        Parameters
        ----------
        arr
            Array to select from.
        sel
            Either a slice, or a list/tuple/array of indices.
        axis
            Index of the axis of `arr` to apply the selection to.

        Returns
        -------
        selected
            The result of slicing `arr` (for a slice) or of `np.take`
            (for a sequence of indices) along `axis`.

        Raises
        ------
        TypeError
            If `sel` is not a supported selection type.
        """
        if isinstance(sel, slice):
            # Pad with full slices so `sel` lands on the requested axis
            return arr[(slice(None),) * axis + (sel,)]

        if isinstance(sel, (list, tuple, np.ndarray)):
            return np.take(arr, sel, axis=axis)

        raise TypeError(
            "Selections must be a slice or a list/tuple of indices. "
            f"Got {type(sel).__name__}."
        )


class Reduce(task.SingleTask):
"""Apply a reduction operation across specific axes and datasets.
Expand Down

0 comments on commit bff7776

Please sign in to comment.