Skip to content

Commit

Permalink
[pallas] Added API docs for Triton and Mosaic GPU backends
Browse files Browse the repository at this point in the history
I've left the TPU backend docs a stub for now. Hopefully, someone working
on Pallas TPU can fill them in later.
  • Loading branch information
superbobry committed Oct 10, 2024
1 parent 351187d commit 475a992
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 1 deletion.
41 changes: 41 additions & 0 deletions docs/jax.experimental.pallas.mosaic_gpu.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
``jax.experimental.pallas.mosaic_gpu`` module
=============================================

.. automodule:: jax.experimental.pallas.mosaic_gpu

Classes
-------

.. autosummary::
:toctree: _autosummary

Barrier
GPUBlockSpec
GPUCompilerParams
GPUMemorySpace
TilingTransform
TransposeTransform
WGMMAAccumulatorRef

Functions
---------

.. autosummary::
:toctree: _autosummary

copy_gmem_to_smem
copy_smem_to_gmem
wait_barrier
wait_smem_to_gmem
wgmma
wgmma_wait

Aliases
-------

.. autosummary::
:toctree: _autosummary

ACC
GMEM
SMEM
10 changes: 10 additions & 0 deletions docs/jax.experimental.pallas.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@

.. automodule:: jax.experimental.pallas

Backends
--------

.. toctree::
:maxdepth: 1

jax.experimental.pallas.mosaic_gpu
jax.experimental.pallas.triton
jax.experimental.pallas.tpu

Classes
-------

Expand Down
16 changes: 16 additions & 0 deletions docs/jax.experimental.pallas.tpu.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
``jax.experimental.pallas.tpu`` module
======================================

.. automodule:: jax.experimental.pallas.tpu

Classes
-------

.. autosummary::
:toctree: _autosummary

Functions
---------

.. autosummary::
:toctree: _autosummary
22 changes: 22 additions & 0 deletions docs/jax.experimental.pallas.triton.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
``jax.experimental.pallas.triton`` module
=========================================

.. automodule:: jax.experimental.pallas.triton

Classes
-------

.. autosummary::
:toctree: _autosummary

TritonCompilerParams

Functions
---------

.. autosummary::
:toctree: _autosummary

approx_tanh
debug_barrier
elementwise_inline_asm
3 changes: 3 additions & 0 deletions jax/_src/pallas/mosaic_gpu/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,11 @@ class GPUCompilerParams(pallas_core.CompilerParams):


class GPUMemorySpace(enum.Enum):
#: Global memory.
GMEM = "gmem"
#: Shared memory.
SMEM = "smem"
#: Registers.
REGS = "regs"

def __str__(self) -> str:
Expand Down
6 changes: 5 additions & 1 deletion jax/experimental/pallas/mosaic_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,17 @@
from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace
from jax._src.pallas.mosaic_gpu.core import TilingTransform
from jax._src.pallas.mosaic_gpu.core import TransposeTransform
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem
from jax._src.pallas.mosaic_gpu.primitives import wait_barrier
from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem
from jax._src.pallas.mosaic_gpu.primitives import wgmma
from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait

#: Alias of :class:`jax.experimental.pallas.mosaic_gpu.WGMMAAccumulatorRef`.
ACC = WGMMAAccumulatorRef
#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.GMEM`.
GMEM = GPUMemorySpace.GMEM
#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.SMEM`.
SMEM = GPUMemorySpace.SMEM

0 comments on commit 475a992

Please sign in to comment.