Skip to content

Commit

Permalink
feat(MPIArray): add gather and allgather for collecting the full array
Browse files Browse the repository at this point in the history
  • Loading branch information
jrs65 committed Feb 20, 2020
1 parent b32ad1d commit 2da95ae
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 0 deletions.
86 changes: 86 additions & 0 deletions caput/mpiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,92 @@ def copy(self):
self.view(np.ndarray).copy(), axis=self.axis, comm=self.comm
)

def gather(self, rank=0):
"""Gather a full copy onto a specific rank.
Parameters
----------
rank : int, optional
Rank to gather onto. Default is rank=0
Returns
-------
arr : np.ndarray, or None
The full global array on the specified rank.
"""
if self.comm.rank == rank:
arr = np.ndarray(self.global_shape, dtype=self.dtype)
else:
arr = None

splits = mpiutil.split_all(self.global_shape[self.axis], self.comm)

for ri, (n, s, e) in enumerate(zip(*splits)):

if self.comm.rank == rank:

# Construct a temporary array for the data to be received into
tshape = list(self.global_shape)
tshape[self.axis] = n
tbuf = np.ndarray(tshape, dtype=self.dtype)

# Set up the non-blocking receive request
request = self.comm.Irecv(tbuf, source=ri)

# Send the data
if self.comm.rank == ri:
self.comm.Isend(self.view(np.ndarray), dest=rank)

if self.comm.rank == rank:

# Wait until the data has arrived
stat = mpiutil.MPI.Status()
request.Wait(status=stat)

if stat.error != mpiutil.MPI.SUCCESS:
print(
"**** ERROR in MPI RECV (source: %i, dest rank: %i) *****"
% (ri, rank)
)

# Put the data into the correct location
dest_slice = [slice(None)] * len(self.shape)
dest_slice[self.axis] = slice(s, e)
arr[dest_slice] = tbuf

return arr

def allgather(self):
"""Gather a full copy onto each rank.
Returns
-------
arr : np.ndarray
The full global array.
"""
arr = np.ndarray(self.global_shape, dtype=self.dtype)

splits = mpiutil.split_all(self.global_shape[self.axis], self.comm)

for ri, (n, s, e) in enumerate(zip(*splits)):

# Construct a temporary array for the data to be received into
tshape = list(self.global_shape)
tshape[self.axis] = n
tbuf = np.ndarray(tshape, dtype=self.dtype)

if self.comm.rank == ri:
tbuf[:] = self

self.comm.Bcast(tbuf, root=ri)

# Copy the array into the correct place
dest_slice = [slice(None)] * len(self.shape)
dest_slice[self.axis] = slice(s, e)
arr[dest_slice] = tbuf

return arr

def _to_hdf5_serial(self, filename, dataset, create=False):
"""Write into an HDF5 dataset.
Expand Down
23 changes: 23 additions & 0 deletions caput/tests/test_mpiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,29 @@ def test_redistribution(self):
arr3 = arr.redistribute(axis=5)
assert (arr3 == garr[:, :, :, :, :, s2:e2]).view(np.ndarray).all()

def test_gather(self):

rank = mpiutil.rank
size = mpiutil.size
block = 2

global_shape = (2, 3, size * block)
global_array = np.zeros(global_shape, dtype=np.float64)
global_array[..., :] = np.arange(size * block)

arr = mpiarray.MPIArray(global_shape, dtype=np.float64, axis=2)
arr[:] = global_array[..., (rank * block) : ((rank + 1) * block)]

assert (arr.allgather() == global_array).all()

gather_rank = 1 if size > 1 else 0
ga = arr.gather(rank=gather_rank)

if rank == gather_rank:
assert (ga == global_array).all()
else:
assert ga is None

def test_wrap(self):

ds = mpiarray.MPIArray((10, 17))
Expand Down

0 comments on commit 2da95ae

Please sign in to comment.