From 46e65b5982895741ee254e92bf5c7404d9ff5b03 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 8 Oct 2024 11:13:04 +0100 Subject: [PATCH] [pallas] Added API docs for Triton and Mosaic GPU backends I've left the TPU backend docs a stub for now. Hopefully, someone working on Pallas TPU can fill them in later. --- docs/jax.experimental.pallas.mosaic_gpu.rst | 41 +++++++++++++++++++++ docs/jax.experimental.pallas.rst | 18 ++++++++- docs/jax.experimental.pallas.tpu.rst | 16 ++++++++ docs/jax.experimental.pallas.triton.rst | 22 +++++++++++ docs/jax.experimental.rst | 12 +++--- jax/_src/pallas/mosaic_gpu/core.py | 3 ++ jax/experimental/pallas/mosaic_gpu.py | 3 ++ 7 files changed, 107 insertions(+), 8 deletions(-) create mode 100644 docs/jax.experimental.pallas.mosaic_gpu.rst create mode 100644 docs/jax.experimental.pallas.tpu.rst create mode 100644 docs/jax.experimental.pallas.triton.rst diff --git a/docs/jax.experimental.pallas.mosaic_gpu.rst b/docs/jax.experimental.pallas.mosaic_gpu.rst new file mode 100644 index 000000000000..76884a12211b --- /dev/null +++ b/docs/jax.experimental.pallas.mosaic_gpu.rst @@ -0,0 +1,41 @@ +``jax.experimental.pallas.mosaic_gpu`` module +============================================= + +.. automodule:: jax.experimental.pallas.mosaic_gpu + +Classes +------- + +.. autosummary:: + :toctree: _autosummary + + Barrier + GPUBlockSpec + GPUCompilerParams + GPUMemorySpace + TilingTransform + TransposeTransform + WGMMAAccumulatorRef + +Functions +--------- + +.. autosummary:: + :toctree: _autosummary + + copy_gmem_to_smem + copy_smem_to_gmem + wait_barrier + wait_smem_to_gmem + wgmma + wgmma_wait + +Aliases +------- + +.. autosummary:: + :toctree: _autosummary + + ACC + GMEM + SMEM diff --git a/docs/jax.experimental.pallas.rst b/docs/jax.experimental.pallas.rst index 1cddbc177e6f..c945f939fa4d 100644 --- a/docs/jax.experimental.pallas.rst +++ b/docs/jax.experimental.pallas.rst @@ -3,6 +3,16 @@ .. automodule:: jax.experimental.pallas +Backends +-------- + +.. toctree:: + :maxdepth: 1 + + jax.experimental.pallas.mosaic_gpu + jax.experimental.pallas.triton + jax.experimental.pallas.tpu + Classes ------- @@ -36,7 +46,11 @@ Functions atomic_min atomic_or atomic_xchg - + atomic_xor + broadcast_to debug_print - + dot + max_contiguous + multiple_of run_scoped + when diff --git a/docs/jax.experimental.pallas.tpu.rst b/docs/jax.experimental.pallas.tpu.rst new file mode 100644 index 000000000000..ae4e2c2253e4 --- /dev/null +++ b/docs/jax.experimental.pallas.tpu.rst @@ -0,0 +1,16 @@ +``jax.experimental.pallas.tpu`` module +====================================== + +.. automodule:: jax.experimental.pallas.tpu + +Classes +------- + +.. autosummary:: + :toctree: _autosummary + +Functions +--------- + +.. autosummary:: + :toctree: _autosummary \ No newline at end of file diff --git a/docs/jax.experimental.pallas.triton.rst b/docs/jax.experimental.pallas.triton.rst new file mode 100644 index 000000000000..76b0896ccf17 --- /dev/null +++ b/docs/jax.experimental.pallas.triton.rst @@ -0,0 +1,22 @@ +``jax.experimental.pallas.triton`` module +========================================= + +.. automodule:: jax.experimental.pallas.triton + +Classes +------- + +.. autosummary:: + :toctree: _autosummary + + TritonCompilerParams + +Functions +--------- + +.. autosummary:: + :toctree: _autosummary + + approx_tanh + debug_barrier + elementwise_inline_asm \ No newline at end of file diff --git a/docs/jax.experimental.rst b/docs/jax.experimental.rst index 4f7afd787286..7672c94c6b52 100644 --- a/docs/jax.experimental.rst +++ b/docs/jax.experimental.rst @@ -16,17 +16,17 @@ Experimental Modules jax.experimental.array_api jax.experimental.checkify - jax.experimental.pjit - jax.experimental.sparse - jax.experimental.jet - jax.experimental.custom_partitioning - jax.experimental.multihost_utils jax.experimental.compilation_cache + jax.experimental.custom_partitioning + jax.experimental.jet jax.experimental.key_reuse jax.experimental.mesh_utils + jax.experimental.multihost_utils + jax.experimental.pallas + jax.experimental.pjit jax.experimental.serialize_executable jax.experimental.shard_map - jax.experimental.pallas + jax.experimental.sparse Experimental APIs ----------------- diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 7eb51ebbf77d..7ca10d01a581 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -59,8 +59,11 @@ class GPUCompilerParams(pallas_core.CompilerParams): class GPUMemorySpace(enum.Enum): + #: Global memory. GMEM = "gmem" + #: Shared memory. SMEM = "smem" + #: Registers. REGS = "regs" def __str__(self) -> str: diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index 451955451b58..15a1772d988a 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -23,6 +23,7 @@ from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace from jax._src.pallas.mosaic_gpu.core import TilingTransform from jax._src.pallas.mosaic_gpu.core import TransposeTransform +from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem @@ -31,5 +32,7 @@ from jax._src.pallas.mosaic_gpu.primitives import wgmma from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait +#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.GMEM`. GMEM = GPUMemorySpace.GMEM +#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.SMEM`. SMEM = GPUMemorySpace.SMEM