feat(mpiarray): ufunc, __array_finalize__, and __getitem__ handling for MPIArray #162
Conversation
@anjakefala we touched on this yesterday, but what I think would help a lot would be to collect a list of MPIArray calls that we would like to work and agree on what their behaviour should be (for which we might want to circulate around to chime-analysis for opinions). Then we can turn them all into unit test cases. Something like:

```python
from mpi4py import MPI
from caput import mpiarray

# Setup array
comm = MPI.COMM_WORLD
dist_array = mpiarray.MPIArray((comm.size, 4), comm=comm, axis=0)
dist_array[:] = comm.rank

# I think this one is uncontroversial
assert dist_array.sum(axis=1) == 4 * comm.rank

# Should this be allowed?
# assert dist_array.sum(axis=0) == ???

# What should this do? It seems the two sensible options are that it sums
# over all axes, giving the same scalar everywhere...
assert dist_array.sum() == 4 * comm.size * (comm.size - 1) // 2

# ... or that it just ignores the distributed axis, giving a distributed 1-d array
assert dist_array.sum() == 4 * comm.rank
```
caput/tests/test_mpiarray.py (Outdated)

```python
dist_arr_add = dist_arr + dist_arr

# Check that you can add two numpy arrays,
# if they are distributed along the same axes
# Check that you can multiply a numpy array against a scalar
assert (dist_arr_add == dist_arr_scalar).all()
```
Also, note that `.all()` on an MPIArray is a reduction over all axes, and we hadn't clearly worked out the semantics for that yet. Seeing how it works with `.all()` maybe pushes it in the direction that you should get one value per rank.
Linter is failing due to pylint updates. This PR is ready for another conversation now.
"Generally accepted style in Python is to avoid staticmethods unless you have a good reason"
😭
If used with draco, requires: radiocosmology/draco#125
Subclassing NumPy arrays 101
- `view(MPIArray)` -> `__array_finalize__`
- `MPIArray[slice]` -> `__getitem__` -> `__array_finalize__`
- `MPIArray()` -> `__new__` -> `__array_finalize__`
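A quick plain-NumPy sketch (no MPI; `Tracked` is a hypothetical stand-in for MPIArray) showing that all three paths end in `__array_finalize__`:

```python
import numpy as np

class Tracked(np.ndarray):
    def __array_finalize__(self, obj):
        # obj is None for an explicit __new__ call, a plain ndarray for a
        # view(), and a Tracked instance when slicing an existing Tracked.
        print("__array_finalize__ from", type(obj).__name__)

a = np.arange(4).view(Tracked)  # view(Tracked)  -> __array_finalize__
b = a[1:3]                      # Tracked[slice] -> __getitem__ -> __array_finalize__
c = Tracked((2,))               # Tracked()      -> __new__     -> __array_finalize__
```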
ufuncs are the universal functions that are applied element by element in nparrays, such as `np.add()` and `np.multiply()`. If they end up summing over an axis, they are called through the ufunc's `reduce` method. If they are applied to every pair of elements from two arrays, that is the ufunc's `outer` method. If they operate in-place at given indices, that is the `at` method.

ufunc -> `__array_ufunc__`
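For concreteness, here is how the three ufunc methods behave on plain NumPy arrays:

```python
import numpy as np

a = np.array([1, 2, 3])

np.add.reduce(a)                     # 6: folds the array along an axis
np.add.outer(a, np.array([10, 20]))  # 3x2 array of all pairwise sums
np.add.at(a, [0, 0, 2], 1)           # in-place: a becomes [3, 2, 4]
```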
More links on the role all these various functions play in writing subclasses for NumPy arrays can be found here: https://github.com/chime-experiment/Pipeline/issues/81
New Exceptions

The new `AxisException` will be added. It will be raised when there are issues involving the integrity of the distributed axis with MPIArrays.

getitem

- If a user uses `global_slice` to index into a distributed axis, it will return an array on the rank on which the index exists, and `None` otherwise.
- `__getitem__` will return an MPIArray, whose distributed axis number might be lower than the original array's, depending on whether the slice results in an axis reduction. It is assumed that the distributed axis length is unchanged. An `mpiutil.split_local` call will be made.
- If the slice would compromise the integrity of the distributed axis, an `AxisException` is raised.
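A sketch of the `global_slice` behaviour described above, using the same setup as the earlier comment:

```python
from mpi4py import MPI
from caput import mpiarray

comm = MPI.COMM_WORLD
arr = mpiarray.MPIArray((comm.size, 3), axis=0, comm=comm)
arr[:] = comm.rank

# Index global row 0 of the distributed axis: the rank that owns that row
# gets its data back, every other rank gets None.
row = arr.global_slice[0]
print(comm.rank, row)
```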
ufunc

- The MPIArray inputs are downcast to standard nparrays, and then the nparray ufunc is called.
- The distributed axes of the inputs are checked for consistency; an `AxisException` will be raised if they are not consistent.
- For `reduce` methods, the `keepdims` kwarg is handled.
- If an `axis` is passed to the ufunc, it must not be the distributed axis. Operations along the distributed axis are not allowed.
- To wrap the output, `MPIArray.wrap(..)` is called. This means there is a call to `mpiutil.split_local` and `mpiutil.allreduce`.
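A simplified, hypothetical sketch of that flow (plain NumPy, with the MPI axis checks omitted; the real MPIArray logic is more involved):

```python
import numpy as np

class DistLike(np.ndarray):
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Downcast any DistLike inputs to plain ndarrays...
        args = [np.asarray(x) if isinstance(x, DistLike) else x
                for x in inputs]
        # ...call the underlying numpy ufunc (or its reduce/outer/at method)...
        result = getattr(ufunc, method)(*args, **kwargs)
        # ...then re-wrap the output; MPIArray uses MPIArray.wrap() here,
        # which is where mpiutil.split_local and mpiutil.allreduce come in.
        if isinstance(result, np.ndarray):
            result = result.view(DistLike)
        return result
```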
array_finalize

This is called whenever a user uses `.view()`, `__new__`, or broadcasts on an MPIArray. It finalizes the creation of the output MPIArray. `view()`s can occur with `__getitem__` calls.

- If it is a `__new__` call, it does nothing.
- If we are in an `np.ndarray.view()` call, it does nothing. This should only occur when we are within a `wrap()`.
- If we are in a `view()`, it grabs the attributes from the origin array.
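Continuing the hypothetical `DistLike` sketch, that branching might look like this (the `axis` attribute is illustrative, not the real MPIArray internals):

```python
import numpy as np

class DistLike(np.ndarray):
    def __array_finalize__(self, obj):
        if obj is None:
            # Explicit __new__ call: do nothing.
            return
        if not isinstance(obj, DistLike):
            # np.ndarray.view() from a plain ndarray, e.g. inside wrap():
            # do nothing; wrap() sets the attributes itself afterwards.
            return
        # A view() of an existing DistLike: grab attributes from the origin.
        self.axis = getattr(obj, "axis", None)
```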
Misc

If a user wishes to create an MPIArray from an ndarray, they should use `MPIArray.wrap()`. They should not use `ndarray.view(MPIArray)`.
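For example (a sketch; the shape and dtype are arbitrary):

```python
import numpy as np
from mpi4py import MPI
from caput import mpiarray

comm = MPI.COMM_WORLD

# Each rank holds its local section of the distributed axis.
local = np.full((2, 5), comm.rank, dtype=np.float64)

# Correct: wrap the local ndarray, declaring which axis is distributed.
arr = mpiarray.MPIArray.wrap(local, axis=0, comm=comm)

# Incorrect: local.view(mpiarray.MPIArray) would skip setting up the
# distribution metadata.
```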
https://github.com/chime-experiment/Pipeline/issues/81