Skip to content

Commit

Permalink
[pallas] Added API docs for Triton and Mosaic GPU backends
Browse files Browse the repository at this point in the history
I've left the TPU backend docs a stub for now. Hopefully, someone working
on Pallas TPU can fill them in later.
  • Loading branch information
superbobry committed Oct 10, 2024
1 parent 351187d commit 46e65b5
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 8 deletions.
41 changes: 41 additions & 0 deletions docs/jax.experimental.pallas.mosaic_gpu.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
``jax.experimental.pallas.mosaic_gpu`` module
=============================================

.. automodule:: jax.experimental.pallas.mosaic_gpu

Classes
-------

.. autosummary::
:toctree: _autosummary

Barrier
GPUBlockSpec
GPUCompilerParams
GPUMemorySpace
TilingTransform
TransposeTransform
WGMMAAccumulatorRef

Functions
---------

.. autosummary::
:toctree: _autosummary

copy_gmem_to_smem
copy_smem_to_gmem
wait_barrier
wait_smem_to_gmem
wgmma
wgmma_wait

Aliases
-------

.. autosummary::
:toctree: _autosummary

ACC
GMEM
SMEM
18 changes: 16 additions & 2 deletions docs/jax.experimental.pallas.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@

.. automodule:: jax.experimental.pallas

Backends
--------

.. toctree::
:maxdepth: 1

jax.experimental.pallas.mosaic_gpu
jax.experimental.pallas.triton
jax.experimental.pallas.tpu

Classes
-------

Expand Down Expand Up @@ -36,7 +46,11 @@ Functions
atomic_min
atomic_or
atomic_xchg

atomic_xor
broadcast_to
debug_print

dot
max_contiguous
multiple_of
run_scoped
when
16 changes: 16 additions & 0 deletions docs/jax.experimental.pallas.tpu.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
``jax.experimental.pallas.tpu`` module
======================================

.. automodule:: jax.experimental.pallas.tpu

Classes
-------

.. autosummary::
:toctree: _autosummary

Functions
---------

.. autosummary::
:toctree: _autosummary
22 changes: 22 additions & 0 deletions docs/jax.experimental.pallas.triton.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
``jax.experimental.pallas.triton`` module
=========================================

.. automodule:: jax.experimental.pallas.triton

Classes
-------

.. autosummary::
:toctree: _autosummary

TritonCompilerParams

Functions
---------

.. autosummary::
:toctree: _autosummary

approx_tanh
debug_barrier
elementwise_inline_asm
12 changes: 6 additions & 6 deletions docs/jax.experimental.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ Experimental Modules

jax.experimental.array_api
jax.experimental.checkify
jax.experimental.pjit
jax.experimental.sparse
jax.experimental.jet
jax.experimental.custom_partitioning
jax.experimental.multihost_utils
jax.experimental.compilation_cache
jax.experimental.custom_partitioning
jax.experimental.jet
jax.experimental.key_reuse
jax.experimental.mesh_utils
jax.experimental.multihost_utils
jax.experimental.pallas
jax.experimental.pjit
jax.experimental.serialize_executable
jax.experimental.shard_map
jax.experimental.pallas
jax.experimental.sparse

Experimental APIs
-----------------
Expand Down
3 changes: 3 additions & 0 deletions jax/_src/pallas/mosaic_gpu/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,11 @@ class GPUCompilerParams(pallas_core.CompilerParams):


class GPUMemorySpace(enum.Enum):
#: Global memory.
GMEM = "gmem"
#: Shared memory.
SMEM = "smem"
#: Registers.
REGS = "regs"

def __str__(self) -> str:
Expand Down
3 changes: 3 additions & 0 deletions jax/experimental/pallas/mosaic_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace
from jax._src.pallas.mosaic_gpu.core import TilingTransform
from jax._src.pallas.mosaic_gpu.core import TransposeTransform
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem
Expand All @@ -31,5 +32,7 @@
from jax._src.pallas.mosaic_gpu.primitives import wgmma
from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait

#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.GMEM`.
GMEM = GPUMemorySpace.GMEM
#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.SMEM`.
SMEM = GPUMemorySpace.SMEM

0 comments on commit 46e65b5

Please sign in to comment.