[pallas:mosaic_gpu] pl.run_scoped now supports scoped barriers
PiperOrigin-RevId: 684449776
superbobry authored and Google-ML-Automation committed Oct 10, 2024
1 parent 94abaf4 commit 70ee8e1
Showing 2 changed files with 146 additions and 36 deletions.
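
For orientation before the lowering details: this commit lets pl.run_scoped allocate a Mosaic GPU barrier alongside SMEM scratch and hand it to the async-copy primitives. Below is a minimal end-to-end sketch mirroring the test added in this commit; the kernel name is illustrative and the import paths are assumed to match the ones the test module uses.

import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
    in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
)
def add_one(x_ref_gmem, o_ref):
  def body(barrier_ref):
    def inner_body(scratch_ref):
      # The scoped barrier is handed to the async copy and then awaited.
      plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier=barrier_ref)
      plgpu.wait_barrier(barrier_ref)
      o_ref[...] = scratch_ref[...] + 1

    pl.run_scoped(inner_body, plgpu.SMEM((256,), jnp.float32))

  # New in this commit: a plgpu.Barrier can be scoped just like SMEM scratch.
  pl.run_scoped(body, plgpu.Barrier(num_arrivals=1))


x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_array_equal(add_one(x), x + 1.0)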
161 changes: 127 additions & 34 deletions jax/_src/pallas/mosaic_gpu/lowering.py
@@ -16,12 +16,13 @@

from __future__ import annotations

from collections.abc import Sequence
import collections
from collections.abc import MutableMapping, MutableSequence, Sequence
import dataclasses
import functools
import itertools as it
import math
from typing import Any, cast
from typing import Any, Protocol, cast

import jax
from jax import lax
@@ -59,44 +60,101 @@
partial = functools.partial
SMEM = gpu_core.SMEM

_smem_estimators = {}

@dataclasses.dataclass(kw_only=True, frozen=True)
class Resources:
smem_scratch_bytes: int
barriers: collections.Counter[mgpu.Barrier] = dataclasses.field(
default_factory=collections.Counter
)

def __add__(self, other: Resources) -> Resources:
# TODO(slebedev): Optimize this.
#
# At the moment, if we have run_scoped(b1) followed by run_scoped(b2)
# we will allocate two barriers, even though one would be enough.
return Resources(
smem_scratch_bytes=self.smem_scratch_bytes + other.smem_scratch_bytes,
barriers=self.barriers + other.barriers,
)

def __or__(self, other: Resources) -> Resources:
return Resources(
smem_scratch_bytes=max(
self.smem_scratch_bytes, other.smem_scratch_bytes
),
barriers=self.barriers | other.barriers,
)
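
The two combinators encode how estimates compose: "+" sums both the SMEM byte count and the barrier counts (used where a scope's own allocations are combined with its body's estimate), while "|" takes an elementwise maximum (used for alternatives such as cond branches). A tiny self-contained sketch of that arithmetic, with plain strings standing in for mgpu.Barrier specs (a simplified stand-in, not the class above):

import collections
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyResources:
  smem_scratch_bytes: int
  barriers: collections.Counter

  def __add__(self, other):
    # Allocations that coexist (e.g. a scope plus its body) accumulate.
    return ToyResources(
        self.smem_scratch_bytes + other.smem_scratch_bytes,
        self.barriers + other.barriers,
    )

  def __or__(self, other):
    # Alternatives (e.g. cond branches) only need the maximum of each field.
    return ToyResources(
        max(self.smem_scratch_bytes, other.smem_scratch_bytes),
        self.barriers | other.barriers,
    )


scope_a = ToyResources(128, collections.Counter({"Barrier(num_arrivals=1)": 1}))
scope_b = ToyResources(256, collections.Counter({"Barrier(num_arrivals=1)": 1}))

print(scope_a + scope_b)  # 384 bytes, barrier count 2: sums both fields.
print(scope_a | scope_b)  # 256 bytes, barrier count 1: maximum of each field.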


class ResourceEstimator(Protocol):

def _regiter_smem_estimator(primitive: jax_core.Primitive):
def __call__(self, *args: Any, **params: Any) -> Resources:
...


_resource_estimators: dict[jax_core.Primitive, ResourceEstimator] = {}


def _register_resource_estimator(primitive: jax_core.Primitive):
def deco(fn):
_smem_estimators[primitive] = fn
_resource_estimators[primitive] = fn
return fn

return deco


def _estimate_smem_scratch_bytes(jaxpr: jax_core.Jaxpr) -> int:
"""Estimates the amount of SMEM scratch bytes required by the kernel."""
max_used = 0
def _estimate_resources(jaxpr: jax_core.Jaxpr) -> Resources:
"""Estimates the resources required by the kernel."""
rs = Resources(smem_scratch_bytes=0)
for eqn in jaxpr.eqns:
# TODO(slebedev): Add support for other primitives, notably control flow.
rule = _smem_estimators.get(eqn.primitive)
rule = _resource_estimators.get(eqn.primitive)
if rule is None:
# Assume that unsupported primitives are neutral wrt SMEM usage.
# Assume that unsupported primitives are neutral wrt resource usage.
continue
max_used = max(
max_used, rule(*(invar.aval for invar in eqn.invars), **eqn.params)
)
return max_used
rs |= rule(*(invar.aval for invar in eqn.invars), **eqn.params)
return rs


@_register_resource_estimator(lax.cond_p)
def _cond_resource_estimator(*args, branches) -> int:
del args # Unused.
return functools.reduce(
lambda a, b: a | b,
(_estimate_resources(branch.jaxpr) for branch in branches),
)

@_regiter_smem_estimator(primitives.run_scoped_p)
def _run_scoped_smem_estimator(*consts, jaxpr: jax_core.Jaxpr) -> int:

@_register_resource_estimator(lax.scan_p)
def _scan_resource_estimator(*args, jaxpr: jax_core.ClosedJaxpr, **params) -> int:
del args, params # Unused.
return _estimate_resources(jaxpr)


@_register_resource_estimator(primitives.run_scoped_p)
def _run_scoped_resource_estimator(*consts, jaxpr: jax_core.Jaxpr) -> int:
del consts # Unused.
in_avals = (v.aval.inner_aval for v in jaxpr.invars)
return sum(math.prod(aval.shape) * aval.dtype.itemsize for aval in in_avals)
smem_scratch_bytes = 0
barriers = []
for v in jaxpr.invars:
aval = v.aval
if isinstance(aval.dtype, gpu_core.BarrierType):
barriers.append(mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape))
else:
smem_scratch_bytes += math.prod(aval.shape) * aval.dtype.itemsize
rs = Resources(
smem_scratch_bytes=smem_scratch_bytes,
barriers=collections.Counter(barriers),
)
return rs + _estimate_resources(jaxpr)


@_regiter_smem_estimator(lax.reduce_sum_p)
def _reduce_sum_smem_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int:
@_register_resource_estimator(lax.reduce_sum_p)
def _reduce_sum_resource_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int:
if axes != (0,):
raise NotImplementedError("No support for axes other than 0 yet")
return 4 * x_aval.dtype.itemsize
return Resources(smem_scratch_bytes=4 * x_aval.dtype.itemsize)


@dataclasses.dataclass
@@ -106,7 +164,21 @@ class ModuleContext:
program_ids: Sequence[ir.Value] | None
approx_math: bool
runtime_smem: ir.Value # ir.MemRefType
smem_used_bytes: int = 0
smem_used_bytes: int
runtime_barriers: MutableMapping[
mgpu.Barrier, MutableSequence[mgpu.BarrierRef]
]

def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef:
"""Reserves a barrier.
Raises:
RuntimeError: If the barrier is already reserved.
"""
available = self.runtime_barriers.get(barrier, [])
if not available:
raise RuntimeError(f"Barrier {barrier} is already reserved")
return available.pop()

# TODO(cperivol): Only return the shapes and figure out the sizes when freeing.
def scratch_view(
@@ -352,7 +424,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
in_buffers_smem, out_buffers_smem = util.split_list(
buffers_smem, [grid_mapping.num_inputs]
)
barriers, *extra_barriers = barriers
barriers, runtime_barriers, extra_barriers = barriers

parallel_count = it.count()
program_ids_template = [
@@ -367,9 +439,21 @@ def make_program_ids(step: ir.Value):
step = arith_dialect.index_cast(ir.IntegerType.get_signless(32), step)
return [step if pid is None else pid for pid in program_ids_template]

grouped_barriers = collections.defaultdict(list)
for barrier, barrier_ref in zip(
sorted(rs.barriers.elements()), runtime_barriers
):
grouped_barriers[barrier].append(barrier_ref)
module_ctx = ModuleContext(
name_and_src_info.name, grid_mapping, None, approx_math, runtime_smem
name_and_src_info.name,
grid_mapping,
None,
approx_math,
runtime_smem,
smem_used_bytes=0,
runtime_barriers=grouped_barriers,
)
del runtime_smem, grouped_barriers, runtime_barriers

smem_scratch_it = iter(scratch_buffers_smem)
scratch_buffers_template = []
@@ -611,6 +695,7 @@ def _(step, carry):
"All scratch operands must be SMEM references or accumulators (ACC),"
f" but got: {scratch_avals}"
)
rs = _estimate_resources(jaxpr)
extra_barriers = [
mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape)
for aval in scratch_avals
@@ -624,7 +709,7 @@ def _(step, carry):
]
smem_scratch_bytes = compiler_params.get("smem_scratch_bytes")
if smem_scratch_bytes is None:
smem_scratch_bytes = _estimate_smem_scratch_bytes(jaxpr)
smem_scratch_bytes = rs.smem_scratch_bytes
extra_smem_scratch.append(
jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8)
)
@@ -641,7 +726,8 @@ def _(step, carry):
*extra_smem_scratch,
(
mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps),
*extra_barriers,
[*sorted(rs.barriers.elements())],
extra_barriers,
),
),
module_name=name_and_src_info.name,
@@ -979,21 +1065,28 @@ def _run_scoped_lowering_rule(
input_refs = []
bytes_allocated = 0
should_discharge = []
for a in jaxpr.invars:
a = a.aval
if isinstance(a, gpu_core.WGMMAAbstractAccumulatorRef):
mlir_dtype = mlir.dtype_to_ir_type(a.dtype)
input_refs.append(mgpu.WGMMAAccumulator.zero(*a.shape, mlir_dtype))
for v in jaxpr.invars:
aval = v.aval
if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef):
mlir_dtype = mlir.dtype_to_ir_type(aval.dtype)
input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, mlir_dtype))
should_discharge.append(True)
elif a.memory_space == gpu_core.SMEM:
elif isinstance(aval.dtype, gpu_core.BarrierType):
input_refs.append(
ctx.module_ctx.reserve_barrier(
mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape)
)
)
should_discharge.append(False)
elif aval.memory_space == gpu_core.SMEM:
ref_bytes, [input_ref] = ctx.module_ctx.scratch_view(
[jax.ShapeDtypeStruct(shape=a.shape, dtype=a.dtype)]
[jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype)]
)
bytes_allocated += ref_bytes
input_refs.append(input_ref)
should_discharge.append(False)
else:
raise ValueError(f"Can't convert to ref: {a}")
raise ValueError(f"Can't convert to ref: {aval}")

if any(should_discharge):
# We convert consts to args, because we only have ir.Values and
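
A side note on the barrier plumbing in the lowering above: the estimator's barrier counter decides how many physical barriers are allocated at launch; the lowering groups the resulting refs by the spec that produced them, and each scoped barrier later pops one unused ref from the matching pool via reserve_barrier. A toy model of that hand-off, with plain strings standing in for mgpu.Barrier specs and barrier refs (names are illustrative, not the actual lowering API):

import collections

# Specs counted by the estimator (sorted, as in the lowering) and the concrete
# refs produced for them at launch, in the same order.
specs = sorted(["Barrier(num_arrivals=1)", "Barrier(num_arrivals=1)", "Barrier(num_arrivals=2)"])
refs = ["ref_0", "ref_1", "ref_2"]

grouped = collections.defaultdict(list)
for spec, ref in zip(specs, refs):
  grouped[spec].append(ref)


def reserve_barrier(spec):
  # Mirrors ModuleContext.reserve_barrier: hand out an unused ref or fail.
  available = grouped.get(spec, [])
  if not available:
    raise RuntimeError(f"Barrier {spec} is already reserved")
  return available.pop()


print(reserve_barrier("Barrier(num_arrivals=1)"))  # ref_1
print(reserve_barrier("Barrier(num_arrivals=2)"))  # ref_2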
21 changes: 19 additions & 2 deletions tests/pallas/mosaic_gpu_test.py
@@ -253,6 +253,24 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
x = jnp.arange(128).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)

def test_copy_gmem_to_smem_in_run_scoped(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
)
def kernel(x_ref_gmem, o_ref):
def body(barrier_ref):
def inner_body(scratch_ref):
plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier=barrier_ref)
plgpu.wait_barrier(barrier_ref)
o_ref[...] = scratch_ref[...] + 1
pl.run_scoped(inner_body, plgpu.SMEM((256,), jnp.float32))
pl.run_scoped(body, plgpu.Barrier(num_arrivals=1))

x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)

def test_add_doubled_sum(self):
@functools.partial(
pl.pallas_call,
@@ -375,7 +393,7 @@ def kernel(x_ref, o_ref):

self.assertIn(f"x: [1, 0, 43, 23]/{in_shape}: 6871\n", output())

def test_scoped_allocation(self):
def test_run_scoped(self):
def kernel(x_ref, o_ref):
def body(tmp_ref):
self.assertEqual(tmp_ref.shape, (8, 128))
@@ -611,7 +629,6 @@ def scope(acc_ref):
)(a, b)
np.testing.assert_allclose(res, a @ b, rtol=1e-3)


def test_input_output_aliases(self):
# Note that we're writing to the input pointer, which should alias b_ptr.
def kernel(a_ref, b_ref):