Skip to content

Commit

Permalink
deal with warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
jrs65 committed May 30, 2022
1 parent 80f19d5 commit 783c084
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions caput/tests/test_mpiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def test_global_getslice():

# Initialise the distributed array
for li, _ in darr.enumerate(axis=0):
darr[li] = 10 * (10 * rank + li) + np.arange(20)
darr.local_array[li] = 10 * (10 * rank + li) + np.arange(20)

# Construct numpy array which should be equivalent to the global array
whole_array = (
Expand Down Expand Up @@ -358,17 +358,20 @@ def test_global_getslice():
assert dslice.local_shape == (10, 5)
assert dslice.axis == 1

# Check that directly indexing into distributed axis returns a numpy array equal to local array indexing
# Check that directly indexing into distributed axis returns a numpy array equal to
# local array indexing
darr = mpiarray.MPIArray((size,), axis=0)
assert (darr[0] == darr.local_array[0]).all()
with pytest.warns(UserWarning):
assert (darr[0] == darr.local_array[0]).all()

# Check that a single index into a non-parallel axis works
darr = mpiarray.MPIArray((4, size), axis=1)
darr[:] = rank
assert (darr[0] == rank).all()
assert darr[0].axis == 0
# check that direct slicing into distributed axis returns a numpy array for local array slicing
assert (darr[2, 0] == darr.local_array[2, 0]).all()
with pytest.warns(UserWarning):
assert (darr[2, 0] == darr.local_array[2, 0]).all()

darr = mpiarray.MPIArray((20, size * 5), axis=1)
darr[:] = rank
Expand Down

0 comments on commit 783c084

Please sign in to comment.