-
Notifications
You must be signed in to change notification settings - Fork 85
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: adding support for numpy.real, imag, round, angle, (#3053)
* 2917: Adding ufunc-like functions real, imag, and angle. These need unit tests still. * Adding unit tests for real, imag, angle. * More unit tests, fixing bugs Fixing bugs in TypeTracer.angle. Simplifying test_complex_ops. Adding tests of TypeTracer real, imag, angle. * Adding round function (almost numpy.ufunc) * style: pre-commit fixes * Fix dependence on numpy type conversions As per Jim's recommendation, for real, imag, angle. * Splitting real, imag, angle * Adding to mostly-augen docs, removing round(.., out=...) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jim Pivarski <jpivarski@users.noreply.github.com>
- Loading branch information
1 parent
b45c312
commit a6444b0
Showing
11 changed files
with
422 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE | ||
|
||
from __future__ import annotations | ||
|
||
import awkward as ak | ||
from awkward._backends.numpy import NumpyBackend | ||
from awkward._dispatch import high_level_function | ||
from awkward._layout import HighLevelContext | ||
from awkward._nplikes.numpy_like import NumpyMetadata | ||
|
||
__all__ = ("angle",) | ||
|
||
np = NumpyMetadata.instance() | ||
cpu = NumpyBackend.instance() | ||
|
||
|
||
@ak._connect.numpy.implements("angle") | ||
@high_level_function() | ||
def angle(val, deg=False, highlevel=True, behavior=None, attrs=None): | ||
""" | ||
Args: | ||
val : array_like | ||
Input array. | ||
deg (bool, default is False): If True, returns angles in degrees, | ||
otherwise in radians. | ||
highlevel (bool, default is True): If True, return an #ak.Array; | ||
otherwise, return a low-level #ak.contents.Content subclass. | ||
behavior (None or dict): Custom #ak.behavior for the output array, if | ||
high-level. | ||
attrs (None or dict): Custom attributes for the output array, if | ||
high-level. | ||
Returns the counterclockwise angle from the positive real axis on the complex | ||
plane in the range ``(-pi, pi]``, with dtype as a float. | ||
""" | ||
# Dispatch | ||
yield (val,) | ||
|
||
# Implementation | ||
return _impl_angle(val, deg, highlevel, behavior, attrs) | ||
|
||
|
||
def _impl_angle(val, deg, highlevel, behavior, attrs): | ||
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: | ||
layout = ctx.unwrap(val, allow_record=False, primitive_policy="error") | ||
|
||
# A closure over deg: | ||
def action_angle(layout, backend, **kwargs): | ||
if isinstance(layout, ak.contents.NumpyArray): | ||
return ak.contents.NumpyArray(backend.nplike.angle(layout.data, deg)) | ||
else: | ||
return None | ||
|
||
out = ak._do.recursively_apply(layout, action_angle) | ||
return ctx.wrap(out, highlevel=highlevel) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE | ||
|
||
from __future__ import annotations | ||
|
||
import awkward as ak | ||
from awkward._backends.numpy import NumpyBackend | ||
from awkward._dispatch import high_level_function | ||
from awkward._layout import HighLevelContext | ||
from awkward._nplikes.numpy_like import NumpyMetadata | ||
|
||
__all__ = ("imag",) | ||
|
||
np = NumpyMetadata.instance() | ||
cpu = NumpyBackend.instance() | ||
|
||
|
||
@ak._connect.numpy.implements("imag") | ||
@high_level_function() | ||
def imag(val, highlevel=True, behavior=None, attrs=None): | ||
""" | ||
Args: | ||
val : array_like | ||
Input array. | ||
highlevel (bool, default is True): If True, return an #ak.Array; | ||
otherwise, return a low-level #ak.contents.Content subclass. | ||
behavior (None or dict): Custom #ak.behavior for the output array, if | ||
high-level. | ||
attrs (None or dict): Custom attributes for the output array, if | ||
high-level. | ||
Returns the imaginary components of the given array elements. | ||
If the arrays have complex elements, the returned arrays are floats. | ||
""" | ||
# Dispatch | ||
yield (val,) | ||
|
||
# Implementation | ||
return _impl_imag(val, highlevel, behavior, attrs) | ||
|
||
|
||
def _impl_imag(val, highlevel, behavior, attrs): | ||
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: | ||
layout = ctx.unwrap(val, allow_record=False, primitive_policy="error") | ||
|
||
out = ak._do.recursively_apply(layout, _action_imag) | ||
return ctx.wrap(out, highlevel=highlevel) | ||
|
||
|
||
def _action_imag(layout, backend, **kwargs): | ||
if isinstance(layout, ak.contents.NumpyArray): | ||
return ak.contents.NumpyArray(backend.nplike.imag(layout.data)) | ||
else: | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE | ||
|
||
from __future__ import annotations | ||
|
||
import awkward as ak | ||
from awkward._backends.numpy import NumpyBackend | ||
from awkward._dispatch import high_level_function | ||
from awkward._layout import HighLevelContext | ||
from awkward._nplikes.numpy_like import NumpyMetadata | ||
|
||
__all__ = ("real",) | ||
|
||
np = NumpyMetadata.instance() | ||
cpu = NumpyBackend.instance() | ||
|
||
|
||
@ak._connect.numpy.implements("real") | ||
@high_level_function() | ||
def real(val, highlevel=True, behavior=None, attrs=None): | ||
""" | ||
Args: | ||
val : array_like | ||
Input array. | ||
highlevel (bool, default is True): If True, return an #ak.Array; | ||
otherwise, return a low-level #ak.contents.Content subclass. | ||
behavior (None or dict): Custom #ak.behavior for the output array, if | ||
high-level. | ||
attrs (None or dict): Custom attributes for the output array, if | ||
high-level. | ||
Returns the real components of the given array elements. | ||
If the arrays have complex elements, the returned arrays are floats. | ||
""" | ||
# Dispatch | ||
yield (val,) | ||
|
||
# Implementation | ||
return _impl_real(val, highlevel, behavior, attrs) | ||
|
||
|
||
def _impl_real(val, highlevel, behavior, attrs): | ||
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: | ||
layout = ctx.unwrap(val, allow_record=False, primitive_policy="error") | ||
|
||
out = ak._do.recursively_apply(layout, _action_real) | ||
return ctx.wrap(out, highlevel=highlevel) | ||
|
||
|
||
def _action_real(layout, backend, **kwargs): | ||
if isinstance(layout, ak.contents.NumpyArray): | ||
return ak.contents.NumpyArray(backend.nplike.real(layout.data)) | ||
else: | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE | ||
|
||
from __future__ import annotations | ||
|
||
import awkward as ak | ||
from awkward._connect.numpy import UNSUPPORTED | ||
from awkward._dispatch import high_level_function | ||
from awkward._layout import HighLevelContext | ||
from awkward._nplikes.numpy_like import NumpyMetadata | ||
|
||
__all__ = ("round",) | ||
|
||
np = NumpyMetadata.instance() | ||
|
||
|
||
@ak._connect.numpy.implements("round") | ||
@high_level_function() | ||
def round( | ||
array, | ||
decimals: int = 0, | ||
out=UNSUPPORTED, | ||
highlevel=True, | ||
behavior=None, | ||
attrs=None, | ||
): | ||
""" | ||
Args: | ||
array : array_like | ||
Input array. | ||
decimals : int, optional | ||
Number of decimal places to round to (default: 0). If | ||
decimals is negative, it specifies the number of positions to | ||
the left of the decimal point. | ||
out : unsupported optional argument | ||
highlevel (bool, default is True): If True, return an #ak.Array; | ||
otherwise, return a low-level #ak.contents.Content subclass. | ||
behavior (None or dict): Custom #ak.behavior for the output array, if | ||
high-level. | ||
attrs (None or dict): Custom attributes for the output array, if | ||
high-level. | ||
Returns the real components of the given array elements. | ||
If the arrays have complex elements, the returned arrays are floats. | ||
""" | ||
# Dispatch | ||
yield (array,) | ||
|
||
# Implementation | ||
return _impl(array, decimals, highlevel, behavior, attrs) | ||
|
||
|
||
def _impl(array, decimals, highlevel, behavior, attrs): | ||
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: | ||
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") | ||
|
||
# A closure over deg: | ||
def action(layout, backend, **kwargs): | ||
if isinstance(layout, ak.contents.NumpyArray): | ||
return ak.contents.NumpyArray(backend.nplike.round(layout.data, decimals)) | ||
else: | ||
return None | ||
|
||
out = ak._do.recursively_apply(layout, action) | ||
return ctx.wrap(out, highlevel=highlevel) |
Oops, something went wrong.