feat(transform): rework how weights are handled in reduction
ljgray committed May 1, 2023
1 parent 51f8842 commit 9fd3e58
Showing 1 changed file with 148 additions and 101 deletions.
249 changes: 148 additions & 101 deletions draco/analysis/transform.py
@@ -1286,28 +1286,38 @@ class Reduce(task.SingleTask):
Axis names to apply the reduction to
datasets : list
Dataset names to reduce.
selections : dict, optional
Dictionary formatted as {<ax_name>:<selection>} for any axes
selections : list, optional
List formatted as [[<ax_name>,<selection>],] for any axes
that should be downselected before applying the reduction.
Selections can be either slices or a list/tuple of indices.
op : str
Reduction operation to apply. Must be the name of a numpy
function, e.g. "var". Default is "mean".
generate_mask : bool, optional
If true, generate a boolean mask to remove samples that are flagged.
When this is set, reduction can only be applied to datasets with
the same axes as the weights.
apply_mask : bool, optional
Whether to apply the weight mask to the datasets before reduction.
This won't do anything unless `generate_mask` is True.
weight_op : str
Reduction operation to apply to the weights. Must be the name of a numpy
function, e.g. "var". When using a boolean operation, such as "all" or "any",
the operation is applied to the boolean array (weights > 0.0). Default is "any".
apply_weight : list, optional
Datasets in this list are multiplied by the weights prior to the reduction.
The downselected weights must be broadcastable to these datasets.
mask_array : bool, optional
If True, reduce using a numpy masked array, where the mask is the boolean
array ~(weights > 0.0). This can be used in conjunction with `apply_weight`.
Default is False.
"""

axes = config.Property(proptype=list)
datasets = config.Property(proptype=list)
selections = config.Property(proptype=dict, default={})
op = config.Property(proptype=str, default="mean")
generate_mask = config.Property(proptype=bool, default=True)
apply_mask = config.Property(proptype=bool, default=False)
weight_op = config.Property(proptype=str, default="any")
apply_weight = config.Property(proptype=list, default=[])
mask_array = config.Property(proptype=bool, default=False)
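As a quick standalone illustration of how `op` and `weight_op` resolve and behave (a numpy-only sketch, not part of this commit; the array values are invented):

    import numpy as np

    op_func = getattr(np, "mean")        # reduction applied to the data
    weight_op_func = getattr(np, "any")  # reduction applied to the weights

    x = np.arange(6.0).reshape(2, 3)
    w = np.array([[1.0, 0.0, 2.0],
                  [0.0, 0.0, 0.0]])

    # Boolean ops act on (weights > 0.0), so "any" keeps a sample
    # if any weight along the reduced axes is non-zero
    reduced = op_func(x, axis=1, keepdims=True)                 # shape (2, 1)
    w_reduced = weight_op_func(w > 0.0, axis=1, keepdims=True)  # [[True], [False]]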

def setup(self):
"""Get the actual reduction ops and make sure they exist."""
self._op_func = getattr(np, self.op)
self._weight_op_func = getattr(np, self.weight_op)

def process(self, data: containers.ContainerBase) -> containers.ContainerBase:
"""Downselect and apply the reduction operation to the data.
@@ -1321,19 +1331,16 @@ def process(self, data: containers.ContainerBase) -> containers.ContainerBase:
-------
out
Dataset of same type as input with axes reduced. Any datasets
which are not included in the reduction list are set to zero,
except for weights.
which are not included in the reduction list will not be initialized,
except for the weights.
"""
# Get the reduction operation we want to use
op = getattr(np, self.op)

# Order axes based on decreasing length to help decide the best
# redistribution options
ax_priority = sorted(
data.index_map.keys(), key=lambda x: len(data.index_map[x]), reverse=True
)

def _select_axes(options: list, current_ax: str) -> tuple[int, int]:
def _select_dist_axes(options: list | tuple, current_ax: str) -> tuple[int, int]:
"""Select the best distributed axes for selection and reduction."""
sel_ax_id, red_ax_id = None, None

@@ -1350,7 +1357,7 @@ def _select_axes(options: list, current_ax: str) -> tuple[int, int]:
for ax in ax_priority:
if ax not in options:
continue
if ax not in self.selections.keys():
if ax not in self.selections:
# Set the first axis to distribute across
sel_ax_id = options.index(ax)
# If possible, use this axis again
@@ -1384,6 +1391,54 @@ def _apply_sel(arr, sel, axis):
elif type(sel) in {list, tuple}:
return np.take(arr, sel, axis=axis)

def _select_red_axes(ds_axes: list | tuple) -> tuple[int, ...]:
"""Get the indices of the axes that we want to reduce over.

The dataset must contain all the axes that we want to apply
the reduction over.
"""
apply_over = []
for ax in self.axes:
try:
apply_over.append(ds_axes.index(ax))
except ValueError as e:
raise ValueError(f"Axis {ax} not found in dataset {name}.") from e

return tuple(apply_over)
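For example (axis names purely illustrative), with dataset axes ("freq", "ra", "el") and reduction axes ["ra", "el"] the helper returns (1, 2), while a missing axis name raises a ValueError:

    ds_axes = ("freq", "ra", "el")  # hypothetical dataset axes
    red_axes = ["ra", "el"]         # axes to reduce over
    apply_over = tuple(ds_axes.index(ax) for ax in red_axes)  # -> (1, 2)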

def _downselect(ds):
"""Apply downselection to a dataset, redistributing as required.

If this is a distributed dataset, the output array will be distributed
to the best possible axis for the subsequent reduction.
"""
ds_axes = ds.attrs["axis"]
arr = ds[:]

if ds.distributed:
# Select the best axes to redistribute to
# _select_dist_axes takes the name of the current axis, not the index
original_ax_id = ds.distributed_axis
sel_ax_id, red_ax_id = _select_dist_axes(
ds_axes, ds_axes[original_ax_id]
)
# Redistribute to the axis used for downselection
arr = arr.redistribute(sel_ax_id)

# Apply downselections to the dataset
for ax, sel in self.selections.items():
try:
ax_ind = ds_axes.index(ax)
except ValueError:
continue
arr = _apply_sel(arr, sel, ax_ind)

if ds.distributed:
# Distribute to the axis used for reduction
arr = arr.redistribute(red_ax_id)

return arr
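The redistribute-select-redistribute pattern exists because selecting along the currently distributed axis of an MPIArray is not generally possible without communication, so that axis is made local first. A rough sketch of the idea, assuming caput's mpiarray API (shapes and axes invented):

    import numpy as np
    from caput import mpiarray

    darr = mpiarray.MPIArray((16, 8), axis=1, dtype=np.float64)
    darr[:] = 1.0

    darr = darr.redistribute(0)            # make axis 1 fully local
    sub = darr.local_array[:, [0, 2, 4]]   # now safe to select along axis 1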

# Figure out the axes for the new container
# Apply the downselections to each axis index_map
output_axes = {
@@ -1403,115 +1458,107 @@ def _apply_sel(arr, sel, axis):
)
out.attrs["reduced"] = True
out.attrs["reduction_op"] = self.op
out.attrs["weight_reduction_op"] = self.weight_op
out.attrs["reduction_axes"] = np.array(self.axes)
out.attrs["downselections"] = self.selections
out.attrs["reduced_datasets"] = np.array(self.datasets)

# Initialize the weights dataset if possible
# Initialize the weight dataset
if "weight" in data.datasets:
out.add_dataset("weight")
elif "vis_weight" in data.datasets:
out.add_dataset("vis_weight")
elif self.generate_mask or self.apply_mask:
raise ValueError("Weight mask was requested but no weight dataset found.")

# Iterate over the datasets and apply operation
for name, ds in data.datasets.items():
if name not in self.datasets:
# We don't care to preserve this dataset, or if this is the
# weight dataset it will already have been initialized
continue

# Initialize the dataset in the output container
out.add_dataset(name)

# Get the axes for this dataset
ds_axes = list(ds.attrs["axis"])
# Set up the new weights
if hasattr(out, "weight"):
# Get the current distributed axis, if it is distributed
if data.weight.distributed:
original_ax_id = data.weight.distributed_axis

weight_axes = data.weight.attrs["axis"]
# Get the axes to reduce over
apply_over = _select_red_axes(weight_axes)
# Apply downselections and redistribute for reduction
weight = _downselect(data.weight)

if self.weight_op in {"all", "any"}:
# Convert to boolean before applying the op
w = weight > 0.0
else:
w = weight

# Get the indices of the axes that we want to reduce over
# The dataset must contain all the axes that we want to
# apply the reduction over
apply_over = []
for ax in self.axes:
try:
apply_over.append(ds_axes.index(ax))
except ValueError as e:
raise ValueError(f"Axis {ax} not found in dataset {name}.") from e
# Apply the weight reduction
ds_weight = self._weight_op_func(w, axis=apply_over, keepdims=True)

apply_over = tuple(apply_over)
if out.weight.distributed:
# Distribute back to the original axis
ds_weight = ds_weight.redistribute(original_ax_id)

# Get a view of the underlying array
arr = ds[:]
out.weight[:] = ds_weight[:]
else:
# Flat weight that can broadcast to whatever shape
self.log.info("No weights available. Using equal weighting.")
weight = np.ones(1, dtype=bool)
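This single-element boolean array works as a universal stand-in: it broadcasts against any data shape, multiplying by it is a no-op, and every sample counts as unflagged. A standalone numpy sketch (invented shapes):

    import numpy as np

    weight = np.ones(1, dtype=bool)
    data = np.arange(12.0).reshape(3, 4)

    assert np.array_equal(data * weight, data)          # apply_weight is a no-op
    mask = ~(np.broadcast_to(weight, data.shape) > 0.0)
    assert not mask.any()                               # nothing is masked out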

# Iterate over the datasets and reduce
for name in self.datasets:
# Get the dataset
ds = data.datasets[name]
# Initialize the dataset in the output container
out.add_dataset(name)
# Get the current distributed axis of this dataset, if it
# is distributed
if ds.distributed:
original_ax_id = ds.distributed_axis

if self.generate_mask:
# Generate a mask from the weights
mask = ~(data.weight[:] > 0.0)
# Broadcast the mask to the array shape
# Get the axes for this dataset
ds_axes = ds.attrs["axis"]
# Get the axes in this dataset to reduce over
apply_over = _select_red_axes(ds_axes)
# Apply downselections and redistribute
arr = _downselect(ds)
# If arr is distributed, get its axis and a numpy view
if isinstance(arr, mpiarray.MPIArray):
# Get the new distributed axis
new_ax = arr.axis
arr = arr.local_array[:]

if name in self.apply_weight or self.mask_array:
# If we need to use the weights, make sure they're distributed
# to the correct axis and can broadcast to this array
if isinstance(weight, mpiarray.MPIArray):
# Redistribute to the equivalent axis in the weights
ax_name = ds_axes[new_ax]
new_weight_ax = weight_axes.index(ax_name)
weight = weight.redistribute(new_weight_ax)

# Broadcast the weights to the array shape
# MPIArrays aren't preserved properly, so must return a ndarray
try:
mask = np.broadcast_to(mask, arr.shape, subok=False)
mask = np.broadcast_to(weight, arr.shape, subok=False)
except ValueError as e:
raise ValueError(
f"Got dataset {name} with shape {arr.shape} and weight "
f"with shape {mask.shape}. If using a weight mask, all "
"datasets must have the same shape/axes as the weights, "
"or the weights must be able to be broadcasted."
f"Could not broadcast weights. Got dataset {name} with "
f"shape {arr.shape} and weight with shape {weight.shape}."
) from e
else:
mask = np.zeros(arr.shape, dtype=bool)
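In the new broadcast call above, subok=False is what guarantees a plain ndarray comes out even when an ndarray subclass such as MPIArray goes in; a minimal numpy illustration (invented shapes):

    import numpy as np

    w = np.arange(3.0)      # weights with shape (3,)
    a = np.zeros((4, 3))    # dataset with shape (4, 3)

    wb = np.broadcast_to(w, a.shape, subok=False)
    assert type(wb) is np.ndarray and wb.shape == (4, 3)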

if ds.distributed:
# Select the best axes to redistribute to
original_ax_id = ds.distributed_axis
# _select_axes takes the name of the current axis, not the index
sel_ax_id, red_ax_id = _select_axes(ds_axes, ds_axes[original_ax_id])
arr = arr.redistribute(sel_ax_id)
mask = mpiarray.MPIArray.wrap(mask, axis=original_ax_id).redistribute(
sel_ax_id
)

# Apply downselections to the dataset and the mask
for ax, sel in self.selections.items():
try:
ax_ind = ds_axes.index(ax)
except ValueError:
continue
arr = _apply_sel(arr, sel, ax_ind)
mask = _apply_sel(mask, sel, ax_ind)

if ds.distributed:
# Distribute to the axis used for reduction
arr = arr.redistribute(red_ax_id)
mask = mask.redistribute(red_ax_id)
if name in self.apply_weight:
# Apply the weights to the dataset
arr = arr * mask

# Apply the optional mask and reduction op
# Not all reductions play nicely with MPIArrays, and we can't use a
# masked array, so we have to use a ndarray view of the local array.
# Don't use the `MPIArray.local_array` property for this, since we
# could also have a non-distributed dataset
if self.apply_mask:
arr = np.ma.array(arr.view(np.ndarray), mask=mask.view(np.ndarray))
reduced = op(arr, axis=apply_over, keepdims=True).data
if self.mask_array:
# Use a masked array, ignoring values where the weights are zero
arr = np.ma.array(arr, mask=~(mask > 0.0))
reduced = self._op_func(arr, axis=apply_over, keepdims=True).data
else:
arr = arr.view(np.ndarray)
reduced = op(arr, axis=apply_over, keepdims=True)
# The weights provided are just a boolean mask which is zero wherever
# all values along the reduction axes are zero and one elsewhere
weight = np.all(~mask, axis=apply_over, keepdims=True)
reduced = self._op_func(arr, axis=apply_over, keepdims=True)
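The practical difference between the two branches is whether zero-weight samples are dropped from the statistic or folded in as ordinary values; a small contrast using "mean" as the op (invented values):

    import numpy as np

    arr = np.array([[1.0, 2.0, 0.0]])
    w = np.array([[1.0, 1.0, 0.0]])  # third sample is flagged

    masked = np.ma.array(arr, mask=~(w > 0.0))
    np.mean(masked, axis=1, keepdims=True).data  # [[1.5]]: flagged sample ignored
    np.mean(arr, axis=1, keepdims=True)          # [[1.0]]: flagged sample included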

if ds.distributed:
reduced = mpiarray.MPIArray.wrap(reduced, axis=red_ax_id)
# Redistribute back to the original axis
reduced = mpiarray.MPIArray.wrap(reduced, axis=new_ax)
reduced = reduced.redistribute(original_ax_id)
weight = weight.redistribute(original_ax_id)

out[name][:] = reduced[:]

# Add the weights back in. If a weight mask was used, the resulting
# weights array will be the same for each dataset, so we don't have
# to care about overwriting it each time
if self.generate_mask:
out.weight[:] = weight[:]
else:
out.weight[:] = 1

return out
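As a closing note, keepdims=True is what makes the re-wrap above legal: the reduction preserves the array's rank, so the axis indices still line up when the plain ndarray is wrapped back into an MPIArray and redistributed. A sketch of that pattern, assuming caput's mpiarray API (shapes invented):

    import numpy as np
    from caput import mpiarray

    # Stand-in for the keepdims=True output of the reduction on each rank,
    # still split along axis 0
    local_reduced = np.zeros((4, 1, 8))
    reduced = mpiarray.MPIArray.wrap(local_reduced, axis=0)
    reduced = reduced.redistribute(2)  # back to the original distributed axis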
