
[Pallas] Add lowering for threefry PRNG.
PiperOrigin-RevId: 663058435
justinjfu authored and Google-ML-Automation committed Oct 9, 2024
1 parent f52b016 commit 8ac0543
Showing 5 changed files with 148 additions and 10 deletions.
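
For context, the sketch below (condensed from the new test added in tests/pallas/tpu_pallas_random_test.py) shows what this lowering enables: a standard threefry2x32 key can now be consumed directly inside a Pallas TPU kernel and reproduces the host-side jax.random stream. The (8, 128) output shape is just one of the cases the test covers, and the snippet assumes a TPU backend using the Mosaic lowering.

import jax
import jax.numpy as jnp
import numpy as np
from jax import random as jax_random
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def body(key_ref, o_ref):
  # Rebuild a standard threefry2x32 key from the raw uint32 data passed in,
  # then sample uniforms directly inside the kernel.
  key = jax.random.wrap_key_data(key_ref[0, ...], impl="threefry2x32")
  o_ref[...] = jax_random.uniform(key, shape=o_ref[...].shape)

threefry_key = jax_random.key(0, impl="threefry2x32").reshape((1,))
out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)
with jax.threefry_partitionable(True):
  result = pl.pallas_call(
      body,
      in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)],
      out_shape=out_shape,
  )(jax.random.key_data(threefry_key))

# The kernel output should match the host-side threefry result bit-for-bit.
np.testing.assert_array_equal(
    result, jax_random.uniform(threefry_key[0], shape=(8, 128)))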
7 changes: 7 additions & 0 deletions jax/_src/pallas/mosaic/BUILD
@@ -57,9 +57,16 @@ py_library(
name = "primitives",
srcs = ["primitives.py"],
deps = [
":core",
"//jax",
"//jax:core",
"//jax:dtypes",
"//jax:mlir",
"//jax:pretty_printer",
"//jax:tree_util",
"//jax:typing",
"//jax:util",
"//jax/_src/pallas",
],
)

73 changes: 63 additions & 10 deletions jax/_src/pallas/mosaic/lowering.py
Expand Up @@ -55,6 +55,7 @@
from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import error_handling
from jax._src.pallas.mosaic import primitives as tpu_primitives
from jax._src.pallas.mosaic import random as pl_random
from jax._src.state import discharge as state_discharge
from jax._src.state import indexing
from jax._src.state import primitives as state_primitives
@@ -201,10 +202,13 @@ def aval_to_ir_type(aval,
return ir.MemRefType.get((), sem_type, memory_space=memspace)
if dtypes.issubdtype(aval.dtype, dtypes.prng_key):
shape = aval.dtype._impl.key_shape
if memory_space is None:
memory_space = TPUMemorySpace.SMEM
if memory_space != TPUMemorySpace.SMEM:
raise ValueError(f"PRNG keys must be stored in SMEM. Got {memory_space}")
if pl_random.is_pallas_impl(aval.dtype._impl):
if memory_space is None:
memory_space = TPUMemorySpace.SMEM
if memory_space != TPUMemorySpace.SMEM:
raise ValueError(
f"PRNG keys must be stored in SMEM. Got {memory_space}"
)
memspace = _memory_space_to_mosaic_attribute(memory_space)
return ir.MemRefType.get(shape, _dtype_to_ir_type(np.dtype(np.uint32)),
memory_space=memspace)
@@ -481,7 +485,8 @@ def err_details():
"only blocks having the same block shape as the array shape "
"and a trivial index_map (returning all 0s)." + err_details())

unmapped_bs = [1 if bs is pallas_core.mapped else bs for bs in bm.block_shape]
unmapped_bs = [
1 if bs is pallas_core.mapped else bs for bs in bm.block_shape]
bs0, as0 = unmapped_bs[-1], bm.array_shape_dtype.shape[-1]
if rank >= 2:
bs1, as1 = unmapped_bs[-2], bm.array_shape_dtype.shape[-2]
@@ -1131,7 +1136,9 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
ref_type = ir.MemRefType(ref.type)
is_smem_load = str(ref_type.memory_space) == "#tpu.memory_space<smem>"
(aval_out,) = ctx.avals_out
if isinstance(aval_out.dtype, prng.KeyTy):
if isinstance(aval_out.dtype, prng.KeyTy) and pl_random.is_pallas_impl(
aval_out.dtype._impl
):
if not is_smem_load:
raise ValueError("PRNG keys must be loaded from SMEM. Did you set "
"the memory space to TPUMemorySpace.SMEM in the "
@@ -2918,8 +2925,13 @@ def random_bits_lowering(ctx, keys, *, bit_width, shape):
assert bit_width == 32, "Only 32-bit PRNG supported."
aval, = ctx.avals_in
impl = aval.dtype._impl
bits_lowering = lower_fun(
impl.random_bits, multiple_results=False)
_proxy_fn = impl.random_bits
if not pl_random.is_pallas_impl(impl):
def new_lowering(key, bit_width, shape):
key = jax.random.key_data(key).astype(jnp.uint32)
return impl.random_bits(key, bit_width, shape)
_proxy_fn = new_lowering
bits_lowering = lower_fun(_proxy_fn, multiple_results=False)
return bits_lowering(ctx, keys, bit_width=bit_width, shape=shape)
lowering_rules[prng.random_bits_p] = random_bits_lowering

@@ -2934,7 +2946,10 @@ def random_fold_in_lowering(ctx, keys, msgs):


def random_unwrap_lowering(ctx, key):
del ctx
keys_aval = ctx.avals_in[0]
impl = keys_aval.dtype._impl
if not pl_random.is_pallas_impl(impl):
return key
assert isinstance(key, KeyScalarBundle)
# Convert to a vector.
if tuple(key.key_shape) != (1, 1):
@@ -2951,7 +2966,9 @@ def random_unwrap_lowering(ctx, key):


def random_wrap_lowering(ctx, key_data, *, impl):
del ctx, impl
del ctx
if not pl_random.is_pallas_impl(impl):
return key_data
if isinstance(key_data.type, ir.VectorType):
# If the key data lives in vregs, need to unpack it to sregs.
key_data_list = []
@@ -2974,6 +2991,42 @@ def random_wrap_lowering(ctx, key_data, *, impl):
lowering_rules[prng.random_wrap_p] = random_wrap_lowering


def _threefry2x32_lowering(ctx, k1, k2, m1, m2):
def _lower_fun(k1, k2, m1, m2):
with jax.named_scope("threefry2x32"):
res = prng._threefry2x32_lowering(k1, k2, m1, m2, use_rolled_loops=False)
return res

threefry_lowering = lower_fun(_lower_fun, multiple_results=True)
return threefry_lowering(ctx, k1, k2, m1, m2)


lowering_rules[prng.threefry2x32_p] = _threefry2x32_lowering


def _iota_2x32_shape_lowering(ctx, *, shape):
total_elements = np.prod(shape)
if total_elements > np.iinfo(jnp.int32).max:
raise NotImplementedError(f"Iota with >{np.iinfo(jnp.int32).max} items.")

def _lower_fun(shape):
iota_data = jnp.zeros(shape, dtype=jnp.int32)
multiplier = 1
for dim in range(len(shape)-1, -1, -1):
counts_lo = lax.broadcasted_iota(
dtype=jnp.int32, shape=shape, dimension=dim
)
iota_data += counts_lo * multiplier
multiplier *= shape[dim]
counts_hi = jnp.zeros(shape, dtype=jnp.int32)
return counts_hi, iota_data

iota_lowering = lower_fun(_lower_fun, multiple_results=True)
return iota_lowering(ctx, shape=shape)


lowering_rules[prng.iota_2x32_shape_p] = _iota_2x32_shape_lowering

# Lowering for shard_map

# Technically this is not a lowering rule, but a discharge rule. When we use
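
A note on _iota_2x32_shape_lowering above: the standalone numpy check below (illustrative, not part of the commit) mirrors its accumulation loop, confirming that the low 32-bit word is simply the row-major flattened index while the high word stays zero, which is why the rule rejects shapes with more than int32-max elements.

import numpy as np

# Per-dimension iotas weighted by row-major strides, as in _lower_fun above.
shape = (4, 16, 128)
iota_lo = np.zeros(shape, dtype=np.int32)
multiplier = 1
for dim in range(len(shape) - 1, -1, -1):
  # np.arange along `dim`, broadcast against the full shape, mirrors
  # lax.broadcasted_iota(..., dimension=dim).
  counts_lo = np.arange(shape[dim], dtype=np.int32).reshape(
      [-1 if d == dim else 1 for d in range(len(shape))])
  iota_lo = iota_lo + counts_lo * multiplier
  multiplier *= shape[dim]

# The accumulated low word equals a plain flattened index reshaped to `shape`.
assert np.array_equal(
    iota_lo, np.arange(np.prod(shape), dtype=np.int32).reshape(shape))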
6 changes: 6 additions & 0 deletions jax/_src/pallas/mosaic/random.py
@@ -46,6 +46,12 @@ def to_pallas_key(key: jax_prng.PRNGKeyArray) -> jax_prng.PRNGKeyArray:
pallas_key_data = (jax.vmap(generate_key))(key)
return jax_api_random.wrap_key_data(pallas_key_data, impl="pallas_tpu")


def is_pallas_impl(impl: jax_prng.PRNGImpl) -> bool:
"""Returns True if the PRNGImpl is a Pallas-specific implementation."""
return impl == tpu_key_impl or impl == tpu_internal_stateful_impl


def _seed_func(seed: jnp.int32):
seed_data = jnp.zeros(tpu_key_impl.key_shape, dtype=jnp.int32)
return (seed_data + seed).astype(jnp.uint32)
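
As a brief illustration (not part of the commit) of what the new is_pallas_impl helper distinguishes: the lowering rules keep the SMEM scalar-bundle handling for the Pallas-specific key impls and let standard impls such as threefry2x32 fall through to the generic threefry path added above.

import jax
from jax._src.pallas.mosaic import random as pl_random

threefry_key = jax.random.key(0, impl="threefry2x32")

assert pl_random.is_pallas_impl(pl_random.tpu_key_impl)        # Pallas key impl
assert not pl_random.is_pallas_impl(threefry_key.dtype._impl)  # standard threefry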
38 changes: 38 additions & 0 deletions tests/pallas/ops_test.py
@@ -993,6 +993,44 @@ def kernel(x_ref, y_ref, o_ref):

np.testing.assert_allclose(f(x, y), kernel(x, y))

@parameterized.named_parameters(
(f"{fn.__name__}_{dtype}", fn, dtype)
for args in BINARY_OPS
for fn, dtype in itertools.product(*args)
)
def test_binary_scalar(self, f, dtype):
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Test only supported on TPU.")
if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
self.skipTest("16-bit types are not supported on TPU")
# TODO: skipped due to https://github.com/jax-ml/jax/issues/24027
if (
jtu.test_device_matches(["tpu"])
and f == jnp.remainder
and not self.INTERPRET
):
self.skipTest("jnp.remainder on TPU is only supported in interpret mode")

# TODO: skipped due to https://github.com/jax-ml/jax/issues/23972
if jtu.test_device_matches(["tpu"]) and dtype == "uint32":
self.skipTest("Not supported on TPU")

@functools.partial(
self.pallas_call,
in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
],
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
out_shape=jax.ShapeDtypeStruct((1,), dtype), grid=1
)
def kernel(x_ref, y_ref, o_ref):
o_ref[0] = f(x_ref[0], y_ref[0])

x = jnp.array([1,]).astype(dtype)
y = jnp.array([18,]).astype(dtype)

np.testing.assert_allclose(f(x, y), kernel(x, y))

@parameterized.parameters(
((8, 4), jnp.int32, 0),
((8, 16), jnp.float32, 1),
34 changes: 34 additions & 0 deletions tests/pallas/tpu_pallas_random_test.py
@@ -220,5 +220,39 @@ def body(key_ref, o_ref):
np.testing.assert_array_equal(result_16x128, result_32x256)


class ThreefryTest(parameterized.TestCase):

def setUp(self):
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Need TPU devices")
super().setUp()

@parameterized.parameters(
((8, 128),),
((32, 256),),
((4, 16, 128),),
)
def test_uniform_matches_jax_threefry(self, shape):
def body(key_ref, o_ref):
key = jax.random.wrap_key_data(key_ref[0, ...], impl='threefry2x32')
o_ref[...] = jax_random.uniform(
key, shape=o_ref[...].shape, minval=0.0, maxval=1.0
)

threefry_key = jax_random.key(0, impl="threefry2x32").reshape((1,))
o_shape = jax.ShapeDtypeStruct(shape, jnp.float32)
with jax.threefry_partitionable(True):
# TODO(justinfu): support passing keys into VMEM.
result = pl.pallas_call(
body,
in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)],
out_shape=o_shape,
)(jax.random.key_data(threefry_key))
jax_result = jax_random.uniform(
threefry_key[0], shape=o_shape.shape, minval=0.0, maxval=1.0
)
np.testing.assert_array_equal(result, jax_result)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
