Skip to content

Commit

Permalink
fix: get basic global indexing working
Browse files Browse the repository at this point in the history
  • Loading branch information
anjakefala committed Feb 22, 2021
1 parent 57da2cb commit 8165556
Showing 1 changed file with 2 additions and 7 deletions.
9 changes: 2 additions & 7 deletions caput/mpiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,11 @@ def __getitem__(self, slobj):
slobj = tuple(slice(None, None, None) if sl is None else sl for sl in slobj)

# Return an MPIArray view
arr = self.array[slobj]

# Figure out which is the distributed axis after the slicing, by
# removing slice axes which are just ints from the mapping
dist_axis = [
index for index, sl in enumerate(slobj) if not isinstance(sl, int)
].index(self.axis)
dist_axis = [index for index, sl in enumerate(slobj) if not isinstance(sl, int)].index(self.axis)

return MPIArray.wrap(arr, axis=dist_axis, comm=self.array._comm)
return MPIArray.wrap(self.array.view(np.ndarray)[slobj], axis=dist_axis, comm=self.array._comm)

def __setitem__(self, slobj, value):

Expand Down Expand Up @@ -1159,7 +1155,6 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
if reduction_axis < dist_axis:
dist_axis -= 1

# Wrapping the results back into MPIArrays, distributed across the appropriate axis
ret = []

for result, output in zip(results, outputs):
Expand Down

0 comments on commit 8165556

Please sign in to comment.