Scalar type promotion not working #3128

Open
nsmith- opened this issue May 24, 2024 · 2 comments
Labels
bug (unverified) The problem described would be a bug, but needs to be triaged

nsmith- (Contributor) commented May 24, 2024

Version of Awkward Array

2.6.4

Description and code to reproduce

The following code reproduces the problem:

import numpy as np
import awkward as ak
from enum import IntEnum

class ParticleOrigin(IntEnum):
    NonDefined: int = 0
    SingleElec: int = 1
    SingleMuon: int = 2


# works as expected
print(np.arange(10) == ParticleOrigin.SingleElec)
# raises TypeError with awkward 2.6.4
print(ak.Array(np.arange(10)) == ParticleOrigin.SingleElec)

NumPy recognizes that the IntEnum member can be promoted to int64, but Awkward fails with the error:

Traceback (most recent call last):
  File "/Users/ncsmith/src/tmp.py", line 16, in <module>
    print(ak.Array(np.arange(10)) == ParticleOrigin.SingleElec)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/_operators.py", line 53, in func
    return ufunc(self, other)
           ^^^^^^^^^^^^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/highlevel.py", line 1516, in __array_ufunc__
    return ak._connect.numpy.array_ufunc(ufunc, method, inputs, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/_connect/numpy.py", line 466, in array_ufunc
    out = ak._broadcasting.broadcast_and_apply(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/_broadcasting.py", line 968, in broadcast_and_apply
    out = apply_step(
          ^^^^^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/_broadcasting.py", line 946, in apply_step
    return continuation()
           ^^^^^^^^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/_broadcasting.py", line 915, in continuation
    return broadcast_any_list()
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/_broadcasting.py", line 622, in broadcast_any_list
    outcontent = apply_step(
                 ^^^^^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/_broadcasting.py", line 928, in apply_step
    result = action(
             ^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/_connect/numpy.py", line 432, in action
    result = backend.nplike.apply_ufunc(ufunc, method, input_args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/_nplikes/array_module.py", line 208, in apply_ufunc
    return self._apply_ufunc_nep_50(ufunc, method, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ncsmith/src/commonenv/lib/python3.12/site-packages/awkward/_nplikes/array_module.py", line 235, in _apply_ufunc_nep_50
    resolved_dtypes = ufunc.resolve_dtypes(arg_dtypes)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Provided dtype must be a valid NumPy dtype, int, float, complex, or None.

This error occurred while calling

    numpy.equal.__call__(
        <Array [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] type='10 * int64'>
        <ParticleOrigin.SingleElec: 1>
    )
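The last frame of the traceback shows the TypeError coming from ufunc.resolve_dtypes inside Awkward's _apply_ufunc_nep_50. Judging from the message, resolve_dtypes accepts a NumPy dtype or exactly int, float, complex, or None, and it appears that Awkward passes along the scalar's Python type, ParticleOrigin (a subclass of int), which is then rejected. Until this is fixed, one possible user-side workaround is to normalize such scalars to their built-in base type before the comparison. The sketch below assumes that diagnosis; to_builtin_scalar is a hypothetical helper, not part of Awkward or NumPy.

```python
from enum import IntEnum


class ParticleOrigin(IntEnum):
    NonDefined = 0
    SingleElec = 1
    SingleMuon = 2


def to_builtin_scalar(value):
    """Hypothetical helper: convert instances of int/float/complex subclasses
    (such as IntEnum members) to the exact built-in type; return anything
    else unchanged."""
    # bool is a subclass of int; keep it as-is so it is not turned into an int.
    if isinstance(value, bool):
        return value
    for base in (int, float, complex):
        if isinstance(value, base) and type(value) is not base:
            return base(value)
    return value


converted = to_builtin_scalar(ParticleOrigin.SingleElec)
print(type(converted).__name__, converted)  # int 1
```

With this helper, ak.Array(np.arange(10)) == to_builtin_scalar(ParticleOrigin.SingleElec) should behave like a comparison against the plain integer 1.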

cc @kratsg

nsmith- added the bug (unverified) label May 24, 2024
agoose77 (Collaborator) commented May 24, 2024

This should be supported, but currently fails. Even the Array API (which we don't promise to conform to, but which we take as inspiration for the promotion rules) supports this: https://data-apis.org/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars

I will probably address this over the weekend.
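Roughly, the Array API rule linked above says that a Python scalar mixed with an array is converted to the array's dtype when the kinds are compatible: a bool for boolean arrays; an int within the dtype's bounds for integer arrays; an int or float for real floating-point arrays; and an int, float, or complex for complex arrays. Because an IntEnum member is an instance of int, an isinstance-based check accepts it. The function below is a simplified, pure-Python model of that rule, not Awkward's implementation; the string dtype names and the name scalar_operand_dtype are assumptions for illustration. Raising OverflowError for out-of-range ints and TypeError for other combinations is a choice made here, since the specification only requires the listed cases to be supported.

```python
from enum import IntEnum


class ParticleOrigin(IntEnum):
    NonDefined = 0
    SingleElec = 1
    SingleMuon = 2


# Hypothetical string-based dtype model, used only for this sketch.
INT_BOUNDS = {
    f"int{bits}": (-(2 ** (bits - 1)), 2 ** (bits - 1) - 1) for bits in (8, 16, 32, 64)
}
INT_BOUNDS.update({f"uint{bits}": (0, 2**bits - 1) for bits in (8, 16, 32, 64)})
REAL_FLOAT = {"float32", "float64"}
COMPLEX_FLOAT = {"complex64", "complex128"}


def scalar_operand_dtype(array_dtype: str, scalar) -> str:
    """Return the dtype a Python scalar is converted to in `array <op> scalar`,
    following a simplified reading of the Array API rule."""
    if isinstance(scalar, bool):  # check bool first: bool is a subclass of int
        if array_dtype == "bool":
            return array_dtype
    elif isinstance(scalar, int):  # also true for IntEnum members
        if array_dtype in INT_BOUNDS:
            low, high = INT_BOUNDS[array_dtype]
            if not low <= scalar <= high:
                # Choice made for this sketch; only in-range ints are required.
                raise OverflowError(f"{scalar} out of range for {array_dtype}")
            return array_dtype
        if array_dtype in REAL_FLOAT or array_dtype in COMPLEX_FLOAT:
            return array_dtype
    elif isinstance(scalar, float):
        if array_dtype in REAL_FLOAT or array_dtype in COMPLEX_FLOAT:
            return array_dtype
    elif isinstance(scalar, complex):
        if array_dtype in COMPLEX_FLOAT:
            return array_dtype
    # Unsupported or unspecified combination (e.g. a float with an int array).
    raise TypeError(f"cannot mix {type(scalar).__name__} with a {array_dtype} array")


print(scalar_operand_dtype("int64", ParticleOrigin.SingleElec))  # int64
```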

kratsg commented May 24, 2024

Note that

print(ak.Array(np.arange(10)) == ParticleOrigin.SingleElec.value)

still works (that is, comparing against the member's underlying plain int value is fine).
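That workaround works because an IntEnum member is an instance of the enum class, which merely subclasses int, while its .value is stored as an exact built-in int. A quick check using only the standard library:

```python
from enum import IntEnum


class ParticleOrigin(IntEnum):
    NonDefined = 0
    SingleElec = 1
    SingleMuon = 2


member = ParticleOrigin.SingleElec
# The member is an instance of an int subclass, not an exact int ...
print(type(member) is int, isinstance(member, int))  # False True
# ... while .value and int(member) are exact built-in ints.
print(type(member.value) is int, type(int(member)) is int)  # True True
```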
