Skip to content

Commit

Permalink
docs(mpiarray): more documentation addition
Browse files Browse the repository at this point in the history
  • Loading branch information
anjakefala committed Feb 4, 2021
1 parent 8210548 commit bcbd75f
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions caput/mpiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,7 +1017,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
'''

args = []
input_mpi = []
input_mpi = [] # container for original input MPIArrays
distr_axis = None # the distributed axis

# convert all local arrays into ndarrays
Expand All @@ -1028,7 +1028,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
else:
assert (
distr_axis == input_.axis
), "The distributed axis for all MPIArrays in an exp should be the same"
), "The distributed axis for all MPIArrays in an expression should be the same"
input_mpi.append(input_)
args.append(input_.local_array.view(np.ndarray))
else:
Expand All @@ -1040,6 +1040,9 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
f"operations along the distributed axis (in this case, {mpi_array.axis}) are not allowed."
)

# 'out' kwargs contain arrays that the ufunc places the results into
# this broadcasts the local part of the output arrays into an ndarray
# that the ufunc knows how to work with
outputs = kwargs.get("out", None)
if outputs:
out_args = []
Expand Down Expand Up @@ -1067,12 +1070,12 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
ret = []

for result, output in zip(results, outputs):
if output is not None: # results were placed in the array specified by `out`
if output is not None: # case: results were placed in the array specified by `out`; return as is
ret.append(output)
else:
if result.shape: # then result is an array; convert it into an MPIArray
if result.shape: # case: the result is an array; convert back it into an MPIArray
ret.append(MPIArray.wrap(result, axis=distr_axis))
else: # result is a scalar, return as is
else: # case: result is a scalar; return as is
ret.append(result)

return ret[0] if len(ret) == 1 else tuple(ret)
Expand All @@ -1083,19 +1086,17 @@ def __array_finalize__(self, obj):
if obj is None:
return

# we are in a ufunc, use the attributes from the original MPIArray
# what *will* the attributes actually be? I need other use-cases
# why did obj lose its attributes?
# we are in a ufunc, rebuild the attributes from the original MPIArray
if isinstance(obj, MPIArray):
comm = getattr(obj, "comm", mpiutil.world)
axis = getattr(obj, "axis", 0)
axis = getattr(obj, "axis", 0) # probably not a good default! How would we find this out?

# get axis length
axlen = self.shape[axis]
totallen = mpiutil.allreduce(axlen, comm=comm)

# Figure out what the distributed layout is
local_num, local_start, local_end = mpiutil.split_local(totallen, comm=comm)
_, local_start, _ = mpiutil.split_local(totallen, comm=comm)

# Get shape and offset
lshape = self.shape
Expand Down

0 comments on commit bcbd75f

Please sign in to comment.