Unimplemented primitive in Pallas GPU lowering: dynamic_slice #24145

Open
chaoming0625 opened this issue Oct 6, 2024 · 3 comments
Labels
bug: Something isn't working
pallas: Issues pertaining to Pallas (GPU or TPU)

Comments

@chaoming0625
Description

I wrote an operator using the following syntax:

    def body(j):
      def true_fn():
        y_ref[ind_ref[j, i_start: i_end]] += 1.0
      jax.lax.cond(bool_vec[j], true_fn, lambda: None)
      return j + 1

    jax.lax.while_loop(lambda j: j < length, body, 0)

But I got the error:

Unimplemented primitive in Pallas GPU lowering: dynamic_slice.

I do not understand why the GPU lowering cannot handle this.
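
For context, `dynamic_slice` is the primitive JAX records when the start offset of a slice is only known at run time. The sketch below is independent of Pallas and only shows the primitive appearing in a traced program. The function name `take_window` and the array `x` are hypothetical and do not come from the kernel above. Which expression in the kernel produces the primitive cannot be determined from the snippet alone.

```python
import jax
import jax.numpy as jnp


def take_window(x, start):
    # `start` is traced when this function is staged out, so the slice
    # offset is unknown at trace time and JAX records `dynamic_slice`.
    return jax.lax.dynamic_slice(x, (start,), (2,))


x = jnp.arange(8.0)

# Inspect the traced program; the printed jaxpr lists the primitive by name.
jaxpr = jax.make_jaxpr(take_window)(x, 3)
print(jaxpr)

# Run it eagerly to see the sliced values.
window = take_window(x, 3)
print(window)
```

Printing the jaxpr of a kernel body in the same way is one way to see which indexing expression introduces the primitive.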

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.34
jaxlib: 0.4.34
numpy:  1.26.4
python: 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='niplab-ubuntu22-0', release='6.8.0-40-generic', version='#40~22.04.3-Ubuntu SMP PREEMPT_DYNAMIC Tue Jul 30 17:30:19 UTC 2', machine='x86_64')



chaoming0625 added the bug label on Oct 6, 2024
superbobry added the pallas label on Oct 7, 2024
@superbobry
Collaborator

Looking at triton-lang/triton#656, we will not be able to support this in Pallas GPU, because Triton itself doesn't support dynamic slices.

@superbobry
Collaborator

superbobry commented Oct 7, 2024

Actually, thinking about this a bit more, maybe we can support dynamic_slice in a restricted way. Can you share the full reproducer, please? The snippet above is missing a few definitions.
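
As a possible starting point for such a reproducer, here is a plain JAX sketch (no Pallas) of what the loop in the description appears to compute. It is not the reporter's code: the shapes, the values of `ind` and `bool_vec`, the static slice bounds, and the name `reference` are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical shapes and values; none of these come from the original report.
length, width, n_out = 4, 6, 10
i_start, i_end = 1, 4  # static slice bounds, assumed for illustration

ind = jnp.arange(length * width, dtype=jnp.int32).reshape(length, width) % n_out
bool_vec = jnp.array([True, False, True, True])


def reference(ind, bool_vec):
    y = jnp.zeros(n_out, dtype=jnp.float32)

    def body(j, y):
        # Indices selected for row j.
        idx = ind[j, i_start:i_end]
        # Add 1.0 at those positions, but only keep the update
        # when bool_vec[j] is set.
        updated = y.at[idx].add(1.0)
        return jnp.where(bool_vec[j], updated, y)

    return jax.lax.fori_loop(0, length, body, y)


y = reference(ind, bool_vec)
print(y)
```

Note that `.at[idx].add` accumulates repeated indices, which may differ from an in-place `+=` on a reference when `idx` contains duplicates. The indices chosen here are distinct within each row, so the two agree for this input.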

@ayaka14732
Member

There are related issues for `slice` (#19214) and for TPU (#18897) as well.
