Avoid redistribute copy if using a single mpi process #256

Merged (1 commit, May 10, 2024)
caput/mpiarray.py (126 changes: 60 additions, 66 deletions)
@@ -735,6 +735,13 @@ def redistribute(self, axis: int) -> "MPIArray":
        if self.axis == axis or self.comm is None:
            return self

        # Avoid repeat mpi property calls
        csize = self.comm.size
        crank = self.comm.rank

        if csize == 1:
            return MPIArray.wrap(self.local_array, axis, self.comm)

        # Check to make sure there is enough memory to perform the redistribution.
        # Must be able to allocate the target array and 2 buffers. We allocate
        # slightly more space than needed to be safe
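
For context on the new fast path above: when the communicator holds a single process, the local array already spans the full global array along every axis, so re-wrapping it with the requested axis is sufficient and the buffer allocation and copy further down are skipped entirely. Below is a minimal usage sketch of that path; the shape, dtype, and axis values are illustrative, and the positional MPIArray.wrap(array, axis) call plus the default-communicator behaviour are assumptions read off the calls visible in this diff, not documented API guarantees.

    import numpy as np
    from caput.mpiarray import MPIArray

    # Illustrative single-process run (e.g. mpirun -np 1, or no MPI at all).
    # Wrap a plain ndarray as an MPIArray distributed over axis 0, then ask
    # for a redistribution over axis 1.  With one process (or comm=None)
    # redistribute() now returns a re-wrapped MPIArray over the same local
    # data instead of allocating send/receive buffers and copying through them.
    local = np.arange(8 * 16, dtype=np.float64).reshape(8, 16)
    arr = MPIArray.wrap(local, 0)
    redist = arr.redistribute(1)

    # On a single process the local array is the whole global array, so the
    # "redistributed" result has the same local shape and data.
    assert redist.local_array.shape == local.shape
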
@@ -753,76 +760,63 @@ def redistribute(self, axis: int) -> "MPIArray":
        # Get views into local and target arrays
        arr = self.local_array
        target_arr = dist_arr.local_array
        # Avoid repeat mpi property calls
        csize = self.comm.size
        crank = self.comm.rank

        if csize == 1:
            if arr.shape[self.axis] == self.global_shape[self.axis]:
                # We are working on a single node.
                target_arr[:] = arr
            else:
                raise ValueError(
                    f"Global shape {self.global_shape} is incompatible with local "
                    f"array shape {self.shape}"
                )
        else:
            # Get the start and end of each subrange of interest
            _, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
            _, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
            # Split the source array into properly sized blocks for sending
            blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
            # Create fixed-size contiguous buffers for sending and receiving
            buffer_shape = list(target_arr.shape)
            buffer_shape[self.axis] = max(ear - sar)
            buffer_shape[axis] = max(eac - sac)
            # Pre-allocate buffers and buffer type
            recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
            send_buffer = np.empty_like(recv_buffer)
            buf_type = self._prep_buf(send_buffer)[1]

            # Empty slices for target, send buf, recv buf
            targetsl = [slice(None)] * len(buffer_shape)
            sendsl = [slice(None)] * len(buffer_shape)
            recvsl = [slice(None)] * len(buffer_shape)
            # Send and recv buf have some fixed axis slices per rank
            sendsl[self.axis] = slice(ear[crank] - sar[crank])
            recvsl[axis] = slice(eac[crank] - sac[crank])

            mpistatus = mpiutil.MPI.Status()

            # Cyclically pass messages forward to i adjacent rank
            for i in range(csize):
                send_to = (crank + i) % csize
                recv_from = (crank - i) % csize

                # Write send data into send buffer location
                sendsl[axis] = slice(eac[send_to] - sac[send_to])
                send_buffer[tuple(sendsl)] = blocks[send_to]

                self.comm.Sendrecv(
                    sendbuf=(send_buffer, buf_type),
                    dest=send_to,
                    sendtag=(csize * crank + send_to),
                    recvbuf=(recv_buffer, buf_type),
                    source=recv_from,
                    recvtag=(csize * recv_from + crank),
                    status=mpistatus,
                )
        # Get the start and end of each subrange of interest
        _, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
        _, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
        # Split the source array into properly sized blocks for sending
        blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
        # Create fixed-size contiguous buffers for sending and receiving
        buffer_shape = list(target_arr.shape)
        buffer_shape[self.axis] = max(ear - sar)
        buffer_shape[axis] = max(eac - sac)
        # Pre-allocate buffers and buffer type
        recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
        send_buffer = np.empty_like(recv_buffer)
        buf_type = self._prep_buf(send_buffer)[1]

        # Empty slices for target, send buf, recv buf
        targetsl = [slice(None)] * len(buffer_shape)
        sendsl = [slice(None)] * len(buffer_shape)
        recvsl = [slice(None)] * len(buffer_shape)
        # Send and recv buf have some fixed axis slices per rank
        sendsl[self.axis] = slice(ear[crank] - sar[crank])
        recvsl[axis] = slice(eac[crank] - sac[crank])

        mpistatus = mpiutil.MPI.Status()

        # Cyclically pass and receive array chunks across ranks
        for i in range(csize):
            send_to = (crank + i) % csize
            recv_from = (crank - i) % csize

            # Write send data into send buffer location
            sendsl[axis] = slice(eac[send_to] - sac[send_to])
            send_buffer[tuple(sendsl)] = blocks[send_to]

            self.comm.Sendrecv(
                sendbuf=(send_buffer, buf_type),
                dest=send_to,
                sendtag=(csize * crank + send_to),
                recvbuf=(recv_buffer, buf_type),
                source=recv_from,
                recvtag=(csize * recv_from + crank),
                status=mpistatus,
            )

                if mpistatus.error != mpiutil.MPI.SUCCESS:
                    logger.error(
                        f"**** ERROR in MPI SEND/RECV "
                        f"(rank={crank}, "
                        f"target={send_to}, "
                        f"receive={recv_from}) ****"
                    )
            if mpistatus.error != mpiutil.MPI.SUCCESS:
                logger.error(
                    f"**** ERROR in MPI SEND/RECV "
                    f"(rank={crank}, "
                    f"target={send_to}, "
                    f"receive={recv_from}) ****"
                )

                # Write buffer into target location
                targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
                recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])
            # Write buffer into target location
            targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
            recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])

                target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]
            target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]

        return dist_arr
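
For readers following the multi-process path: the loop above is a cyclic pairwise exchange. In round i, rank r packs the block destined for rank (r + i) % csize into a fixed-size contiguous send buffer and, within the same Sendrecv call, receives the block arriving from rank (r - i) % csize, so no non-blocking requests or extra synchronisation are needed. The tag csize * sender_rank + receiver_rank is computed identically on both ends (sendtag by the sender, recvtag by the receiver), so messages pair up unambiguously. The sketch below reproduces only that exchange pattern using mpi4py's pickle-based sendrecv; the cyclic_exchange helper and the toy data are illustrative and are not caput's internal buffer handling.

    import numpy as np
    from mpi4py import MPI


    def cyclic_exchange(blocks, comm):
        # Toy version of the loop in the diff: blocks[r] is the chunk this
        # rank wants to deliver to rank r.  Returns the chunks received from
        # every rank, indexed by source rank.
        size, rank = comm.size, comm.rank
        received = [None] * size
        for i in range(size):
            send_to = (rank + i) % size
            recv_from = (rank - i) % size
            # One call both sends and receives, so the cyclic pattern cannot
            # deadlock regardless of how the rounds interleave across ranks.
            received[recv_from] = comm.sendrecv(
                blocks[send_to], dest=send_to, source=recv_from
            )
        return received


    if __name__ == "__main__":
        comm = MPI.COMM_WORLD
        # Each rank starts with comm.size rows labelled by its own rank and
        # splits them into one block per destination rank.
        local = np.full((comm.size, 4), comm.rank)
        pieces = np.array_split(local, comm.size, axis=0)
        gathered = cyclic_exchange(pieces, comm)
        # After the exchange, row r on every rank originated on rank r.
        result = np.concatenate(gathered, axis=0)
        print(comm.rank, result[:, 0])

The real implementation instead keeps one pre-allocated send buffer and one receive buffer sized to the largest block, which keeps every Sendrecv contiguous and avoids reallocating per round; the sketch trades that for clarity by pickling one block per call.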
