Speed up CollateProducts by avoiding redistribute when possible. #191

Merged 2 commits on Sep 7, 2022
47 changes: 28 additions & 19 deletions draco/analysis/transform.py
@@ -182,7 +182,6 @@ def process(self, ss):
)

bt_freq = ss.index_map["freq"][freq_ind]

# Determine the input product map and conjugation.
# If the input timestream is already stacked, then attempt to redefine
# its representative products so that they contain only feeds that exist
@@ -226,13 +225,19 @@ def process(self, ss):
**output_kwargs,
)

# Check if frequencies are already ordered
no_redistribute = freq_ind == list(range(len(ss.freq[:])))

# Add gain dataset.
# if 'gain' in ss.datasets:
# sp.add_dataset('gain')

# Ensure all frequencies and products are on each node
ss.redistribute(["ra", "time"])
sp.redistribute(["ra", "time"])
# If frequencies are mapped across ranks, we have to redistribute so all
# frequencies and products are on each rank
raxis = "freq" if no_redistribute else ["ra", "time"]
self.log.info(f"Distributing across '{raxis}' axis")
ss.redistribute(raxis)
sp.redistribute(raxis)

# Initialize datasets in output container
sp.vis[:] = 0.0
@@ -257,21 +262,20 @@
if self.weight == "uniform":
nprod_in_stack = (nprod_in_stack > 0).astype(np.float32)

# Find the local times (necessary because nprod_in_stack is not distributed)
ntt = ss.vis.local_shape[-1]
stt = ss.vis.local_offset[-1]
ett = stt + ntt

# Create counter to increment during the stacking.
# This will be used to normalize at the end.
counter = np.zeros_like(sp.weight[:])

# Dereference the global slices now, there's a hidden MPI call in the [:] operation.
# Dereference the global slices, there's a hidden MPI call in the [:] operation.
spv = sp.vis[:]
ssv = ss.vis[:]
spw = sp.weight[:]
ssw = ss.weight[:]

# Get the local frequency and time slice/mapping
freq_ind = slice(None) if no_redistribute else freq_ind
time_ind = slice(None) if no_redistribute else ssv.local_bounds

# Iterate over products (stacked) in the sidereal stream
for ss_pi, ((ii, ij), conj) in enumerate(zip(ss_prod, ss_conj)):

@@ -291,34 +295,39 @@

# Generate weight
if self.weight == "inverse_variance":
wss = ssw[freq_ind, ss_pi]
wss = ssw.local_array[freq_ind, ss_pi]

else:
wss = (ssw[freq_ind, ss_pi] > 0.0).astype(np.float32)
wss.local_array[:] *= nprod_in_stack[np.newaxis, ss_pi, stt:ett]
wss = (ssw.local_array[freq_ind, ss_pi] > 0.0).astype(np.float32)
wss[:] *= nprod_in_stack[np.newaxis, ss_pi, time_ind]

# Accumulate visibilities, conjugating if required
if feedconj == conj:
spv[:, sp_pi] += wss * ssv[freq_ind, ss_pi]
spv.local_array[:, sp_pi] += wss * ssv.local_array[freq_ind, ss_pi]
else:
spv[:, sp_pi] += wss * ssv[freq_ind, ss_pi].conj()
spv.local_array[:, sp_pi] += (
wss * ssv.local_array[freq_ind, ss_pi].conj()
)

# Accumulate variances in quadrature. Save in the weight dataset.
spw[:, sp_pi] += wss**2 * tools.invert_no_zero(ssw[freq_ind, ss_pi])
spw.local_array[:, sp_pi] += wss**2 * tools.invert_no_zero(
ssw.local_array[freq_ind, ss_pi]
)

# Increment counter
counter[:, sp_pi] += wss
counter.local_array[:, sp_pi] += wss

# Divide through by counter to get properly weighted visibility average
sp.vis[:] *= tools.invert_no_zero(counter)
sp.weight[:] = counter**2 * tools.invert_no_zero(sp.weight[:])

# Copy over any additional datasets that need to be frequency filtered
containers.copy_datasets_filter(
ss, sp, "freq", freq_ind, ["input", "prod", "stack"]
ss, sp, "freq", freq_ind, ["input", "prod", "stack"], allow_distributed=True
)

# Switch back to frequency distribution
# Switch back to frequency distribution. This will have minimal
# cost if we are already distributed in frequency
ss.redistribute("freq")
sp.redistribute("freq")

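The heart of the transform.py change is the no_redistribute check: when the requested frequency selection is the identity mapping over the stream's existing frequencies, the data can stay distributed over "freq" and the expensive all-to-all redistribution onto the RA/time axis is skipped; the stacking loop then works on .local_array views and replaces the global frequency/time selections with local slices, avoiding the hidden MPI call in the [:] dereference noted in the comments. A minimal sketch of that axis choice is below; pick_redistribution_axis and stream are hypothetical stand-ins, while freq_ind and the axis names come from the diff above.

# Illustrative only: `stream` stands for a distributed sidereal stream
# container with a `freq` index map; only `freq_ind` and the axis names
# ("freq", "ra", "time") are taken from the diff.
def pick_redistribution_axis(stream, freq_ind):
    """Return the axis (or axes) to redistribute over before stacking."""
    # If the frequency selection is the identity mapping, every rank already
    # holds exactly the frequencies it needs, so staying distributed over
    # "freq" avoids the all-to-all communication entirely.
    no_redistribute = list(freq_ind) == list(range(len(stream.freq[:])))

    # Otherwise fall back to distributing over RA/time so that all
    # frequencies and products end up local to each rank.
    return "freq" if no_redistribute else ["ra", "time"]

In the fast path the final ss.redistribute("freq") / sp.redistribute("freq") at the end of process() is then close to a no-op, since the containers are already distributed over frequency.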
7 changes: 6 additions & 1 deletion draco/core/containers.py
@@ -2673,6 +2673,7 @@ def copy_datasets_filter(
axis: str,
selection: Union[np.ndarray, list, slice],
exclude_axes: List[str] = None,
allow_distributed: bool = False,
):
"""Copy datasets while filtering a given axis.
@@ -2689,6 +2690,10 @@
exclude_axes
An optional set of axes that if a dataset contains one means it will
not be copied.
allow_distributed, optional
Allow the filtered axis to be the distributed axis. This is ONLY
valid if filtering is occurring on the local rank only, and mainly
exists for compatibility.
"""
exclude_axes_set = set(exclude_axes) if exclude_axes else set()

@@ -2717,7 +2722,7 @@

if isinstance(item, memh5.MemDatasetDistributed):

if item.distributed_axis == axis_ind:
if (item.distributed_axis == axis_ind) and not allow_distributed:
raise RuntimeError(
f"Cannot redistristribute dataset={item.name} along "
f"axis={axis_ind} as it is distributed."
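The new allow_distributed flag is what lets the transform.py call above filter the "freq" axis while the containers remain distributed over it. A hedged usage sketch follows, mirroring the call in the diff; ss, sp and freq_ind stand for the source container, destination container and frequency selection built elsewhere, and only the signature and argument order are taken from the diff.

# Illustrative wrapper around the call made in CollateProducts.process;
# ss, sp and freq_ind are assumed to be constructed by the caller.
from draco.core import containers

def copy_filtered_freq(ss, sp, freq_ind):
    """Copy frequency-filtered datasets from ss into sp (sketch only)."""
    # With allow_distributed=True the filtered axis may also be the
    # distributed axis. This is only safe when the selection touches data
    # already local to each rank, e.g. the identity selection taken in the
    # no_redistribute fast path above.
    containers.copy_datasets_filter(
        ss,                           # source container
        sp,                           # destination container
        "freq",                       # axis being filtered
        freq_ind,                     # selection along that axis
        ["input", "prod", "stack"],   # exclude_axes: datasets with these axes are not copied
        allow_distributed=True,
    )

Without the flag, the pre-existing RuntimeError is raised whenever the filtered axis is the distributed one, which is exactly what previously forced the redistribute in CollateProducts.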