From 04f4e64107171b6502ea8e804f470656777b12e3 Mon Sep 17 00:00:00 2001
From: Justin Fu
Date: Wed, 14 Aug 2024 14:36:41 -0700
Subject: [PATCH] [Pallas] Add lowering for threefry PRNG.

PiperOrigin-RevId: 663058435
---
 jax/_src/pallas/core.py                |  1 +
 jax/_src/pallas/mosaic/BUILD           |  7 +++
 jax/_src/pallas/mosaic/lowering.py     | 73 ++++++++++++++++++++++----
 jax/_src/pallas/mosaic/random.py       |  6 +++
 tests/pallas/ops_test.py               | 38 ++++++++++++++++++++++++++++++++++++++
 tests/pallas/tpu_pallas_random_test.py | 32 +++++++++++
 6 files changed, 147 insertions(+), 10 deletions(-)

diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index 0ff463562355..7eeaf58f2d85 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -792,6 +792,7 @@ def _convert_block_spec_to_block_mapping(
     grid: GridMappingGrid,
     mapped_dims: tuple[int, ...],
 ) -> BlockMapping:
+  array_aval = jax_core.physical_aval(array_aval)
   if block_spec is no_block_spec:
     block_spec = BlockSpec(None, None)
   return block_spec.to_block_mapping(
diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD
index ae76a00a6c17..f52ba9ddd6cd 100644
--- a/jax/_src/pallas/mosaic/BUILD
+++ b/jax/_src/pallas/mosaic/BUILD
@@ -57,9 +57,16 @@ py_library(
     name = "primitives",
     srcs = ["primitives.py"],
     deps = [
+        ":core",
         "//jax",
         "//jax:core",
+        "//jax:dtypes",
         "//jax:mlir",
+        "//jax:pretty_printer",
+        "//jax:tree_util",
+        "//jax:typing",
+        "//jax:util",
+        "//jax/_src/pallas",
     ],
 )
 
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index bac1db773ee4..d8396f206ee4 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -55,6 +55,7 @@
 from jax._src.pallas.mosaic import core as tpu_core
 from jax._src.pallas.mosaic import error_handling
 from jax._src.pallas.mosaic import primitives as tpu_primitives
+from jax._src.pallas.mosaic import random as pl_random
 from jax._src.state import discharge as state_discharge
 from jax._src.state import indexing
 from jax._src.state import primitives as state_primitives
@@ -201,10 +202,13 @@ def aval_to_ir_type(aval,
     return ir.MemRefType.get((), sem_type, memory_space=memspace)
   if dtypes.issubdtype(aval.dtype, dtypes.prng_key):
     shape = aval.dtype._impl.key_shape
-    if memory_space is None:
-      memory_space = TPUMemorySpace.SMEM
-    if memory_space != TPUMemorySpace.SMEM:
-      raise ValueError(f"PRNG keys must be stored in SMEM. Got {memory_space}")
+    if pl_random.is_pallas_impl(aval.dtype._impl):
+      if memory_space is None:
+        memory_space = TPUMemorySpace.SMEM
+      if memory_space != TPUMemorySpace.SMEM:
+        raise ValueError(
+            f"PRNG keys must be stored in SMEM. Got {memory_space}"
+        )
     memspace = _memory_space_to_mosaic_attribute(memory_space)
     return ir.MemRefType.get(shape, _dtype_to_ir_type(np.dtype(np.uint32)),
                              memory_space=memspace)
@@ -481,7 +485,8 @@ def err_details():
         "only blocks having the same block shape as the array shape "
         "and a trivial index_map (returning all 0s)." + err_details())
 
-  unmapped_bs = [1 if bs is pallas_core.mapped else bs for bs in bm.block_shape]
+  unmapped_bs = [
+      1 if bs is pallas_core.mapped else bs for bs in bm.block_shape]
   bs0, as0 = unmapped_bs[-1], bm.array_shape_dtype.shape[-1]
   if rank >= 2:
     bs1, as1 = unmapped_bs[-2], bm.array_shape_dtype.shape[-2]
@@ -1131,7 +1136,9 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
   ref_type = ir.MemRefType(ref.type)
   is_smem_load = str(ref_type.memory_space) == "#tpu.memory_space<smem>"
   (aval_out,) = ctx.avals_out
-  if isinstance(aval_out.dtype, prng.KeyTy):
+  if isinstance(aval_out.dtype, prng.KeyTy) and pl_random.is_pallas_impl(
+      aval_out.dtype._impl
+  ):
     if not is_smem_load:
       raise ValueError("PRNG keys must be loaded from SMEM. Did you set "
                        "the memory space to TPUMemorySpace.SMEM in the "
@@ -2914,8 +2921,13 @@ def random_bits_lowering(ctx, keys, *, bit_width, shape):
   assert bit_width == 32, "Only 32-bit PRNG supported."
   aval, = ctx.avals_in
   impl = aval.dtype._impl
-  bits_lowering = lower_fun(
-      impl.random_bits, multiple_results=False)
+  _proxy_fn = impl.random_bits
+  if not pl_random.is_pallas_impl(impl):
+    def new_lowering(key, bit_width, shape):
+      key = jax.random.key_data(key).astype(jnp.uint32)
+      return impl.random_bits(key, bit_width, shape)
+    _proxy_fn = new_lowering
+  bits_lowering = lower_fun(_proxy_fn, multiple_results=False)
   return bits_lowering(ctx, keys, bit_width=bit_width, shape=shape)
 
 lowering_rules[prng.random_bits_p] = random_bits_lowering
@@ -2930,7 +2942,10 @@ def random_fold_in_lowering(ctx, keys, msgs):
 
 
 def random_unwrap_lowering(ctx, key):
-  del ctx
+  keys_aval = ctx.avals_in[0]
+  impl = keys_aval.dtype._impl
+  if not pl_random.is_pallas_impl(impl):
+    return key
   assert isinstance(key, KeyScalarBundle)
   # Convert to a vector.
   if tuple(key.key_shape) != (1, 1):
@@ -2947,7 +2962,9 @@ def random_unwrap_lowering(ctx, key):
 
 
 def random_wrap_lowering(ctx, key_data, *, impl):
-  del ctx, impl
+  del ctx
+  if not pl_random.is_pallas_impl(impl):
+    return key_data
   if isinstance(key_data.type, ir.VectorType):
     # If the key data lives in vregs, need to unpack it to sregs.
     key_data_list = []
@@ -2970,6 +2987,42 @@ def random_wrap_lowering(ctx, key_data, *, impl):
 lowering_rules[prng.random_wrap_p] = random_wrap_lowering
 
 
+def _threefry2x32_lowering(ctx, k1, k2, m1, m2):
+  def _lower_fun(k1, k2, m1, m2):
+    with jax.named_scope("threefry2x32"):
+      res = prng._threefry2x32_lowering(k1, k2, m1, m2, use_rolled_loops=False)
+    return res
+
+  threefry_lowering = lower_fun(_lower_fun, multiple_results=True)
+  return threefry_lowering(ctx, k1, k2, m1, m2)
+
+
+lowering_rules[prng.threefry2x32_p] = _threefry2x32_lowering
+
+
+def _iota_2x32_shape_lowering(ctx, *, shape):
+  total_elements = np.prod(shape)
+  if total_elements > np.iinfo(jnp.int32).max:
+    raise NotImplementedError(f"Iota with >{np.iinfo(jnp.int32).max} items.")
+
+  def _lower_fun(shape):
+    iota_data = jnp.zeros(shape, dtype=jnp.int32)
+    multiplier = 1
+    for dim in range(len(shape)-1, -1, -1):
+      counts_lo = lax.broadcasted_iota(
+          dtype=jnp.int32, shape=shape, dimension=dim
+      )
+      iota_data += counts_lo * multiplier
+      multiplier *= shape[dim]
+    counts_hi = jnp.zeros(shape, dtype=jnp.int32)
+    return counts_hi, iota_data
+
+  iota_lowering = lower_fun(_lower_fun, multiple_results=True)
+  return iota_lowering(ctx, shape=shape)
+
+
+lowering_rules[prng.iota_2x32_shape_p] = _iota_2x32_shape_lowering
+
+
 # Lowering for shard_map
 
 # Technically this is not a lowering rule, but a discharge rule. When we use
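[Editor's note, not part of the patch: the two new lowering rules above carry the change. `threefry2x32_p` is lowered through JAX's unrolled reference implementation, and `iota_2x32_shape_p` is rebuilt from per-dimension `broadcasted_iota` sums, producing the flattened row-major index of each element in the low 32 bits with the high 32 bits zeroed (shapes above 2**31 - 1 elements are rejected first). The following standalone sketch mirrors the `_lower_fun` body above and runs on CPU with plain JAX; `iota_2x32_reference` is a made-up name for illustration.

import jax.numpy as jnp
import numpy as np
from jax import lax

def iota_2x32_reference(shape):
  iota_data = jnp.zeros(shape, dtype=jnp.int32)
  multiplier = 1
  # Sum per-dimension iotas scaled by row-major strides: the last dimension
  # counts by 1, the next by shape[-1], and so on.
  for dim in range(len(shape) - 1, -1, -1):
    counts_lo = lax.broadcasted_iota(dtype=jnp.int32, shape=shape, dimension=dim)
    iota_data += counts_lo * multiplier
    multiplier *= shape[dim]
  counts_hi = jnp.zeros(shape, dtype=jnp.int32)  # high 32 bits are always zero
  return counts_hi, iota_data

hi, lo = iota_2x32_reference((2, 3))
np.testing.assert_array_equal(lo, jnp.arange(6, dtype=jnp.int32).reshape(2, 3))
]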
diff --git a/jax/_src/pallas/mosaic/random.py b/jax/_src/pallas/mosaic/random.py
index 68a4fe508917..16dc5ee1fe56 100644
--- a/jax/_src/pallas/mosaic/random.py
+++ b/jax/_src/pallas/mosaic/random.py
@@ -46,6 +46,12 @@ def to_pallas_key(key: jax_prng.PRNGKeyArray) -> jax_prng.PRNGKeyArray:
   pallas_key_data = (jax.vmap(generate_key))(key)
   return jax_api_random.wrap_key_data(pallas_key_data, impl="pallas_tpu")
 
+
+def is_pallas_impl(impl: jax_prng.PRNGImpl) -> bool:
+  """Returns True if the PRNGImpl is a Pallas-specific implementation."""
+  return impl == tpu_key_impl or impl == tpu_internal_stateful_impl
+
+
 def _seed_func(seed: jnp.int32):
   seed_data = jnp.zeros(tpu_key_impl.key_shape, dtype=jnp.int32)
   return (seed_data + seed).astype(jnp.uint32)
diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py
index 979870ff351f..fd3e2acbc507 100644
--- a/tests/pallas/ops_test.py
+++ b/tests/pallas/ops_test.py
@@ -993,6 +993,44 @@ def kernel(x_ref, y_ref, o_ref):
 
     np.testing.assert_allclose(f(x, y), kernel(x, y))
 
+  @parameterized.named_parameters(
+      (f"{fn.__name__}_{dtype}", fn, dtype)
+      for args in BINARY_OPS
+      for fn, dtype in itertools.product(*args)
+  )
+  def test_binary_scalar(self, f, dtype):
+    if not jtu.test_device_matches(["tpu"]):
+      self.skipTest("Test only supported on TPU.")
+    if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
+      self.skipTest("16-bit types are not supported on TPU")
+    # TODO: skipped due to https://github.com/jax-ml/jax/issues/24027
+    if (
+        jtu.test_device_matches(["tpu"])
+        and f == jnp.remainder
+        and not self.INTERPRET
+    ):
+      self.skipTest("jnp.remainder on TPU is only supported in interpret mode")
+
+    # TODO: skipped due to https://github.com/jax-ml/jax/issues/23972
+    if jtu.test_device_matches(["tpu"]) and dtype == "uint32":
+      self.skipTest("Not supported on TPU")
+
+    @functools.partial(
+        self.pallas_call,
+        in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
+                  pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
+        ],
+        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
+        out_shape=jax.ShapeDtypeStruct((1,), dtype), grid=1
+    )
+    def kernel(x_ref, y_ref, o_ref):
+      o_ref[0] = f(x_ref[0], y_ref[0])
+
+    x = jnp.array([1,]).astype(dtype)
+    y = jnp.array([18,]).astype(dtype)
+
+    np.testing.assert_allclose(f(x, y), kernel(x, y))
+
   @parameterized.parameters(
       ((8, 4), jnp.int32, 0),
       ((8, 16), jnp.float32, 1),
diff --git a/tests/pallas/tpu_pallas_random_test.py b/tests/pallas/tpu_pallas_random_test.py
index e3d43125c9ab..cbef2852ad05 100644
--- a/tests/pallas/tpu_pallas_random_test.py
+++ b/tests/pallas/tpu_pallas_random_test.py
@@ -220,5 +220,37 @@ def body(key_ref, o_ref):
     np.testing.assert_array_equal(result_16x128, result_32x256)
 
 
+class ThreefryTest(parameterized.TestCase):
+
+  def setUp(self):
+    if not jtu.test_device_matches(["tpu"]):
+      self.skipTest("Need TPU devices")
+    super().setUp()
+
+  @parameterized.parameters(
+      ((8, 128),),
+      ((32, 256),),
+      ((4, 16, 128),),
+  )
+  def test_uniform_matches_jax_threefry(self, shape):
+    def body(key_ref, o_ref):
+      o_ref[...] = jax_random.uniform(
+          key_ref[0, ...], shape=o_ref[...].shape, minval=0.0, maxval=1.0
+      )
+
+    threefry_key = jax_random.key(0, impl="threefry2x32").reshape((1,))
+    o_shape = jax.ShapeDtypeStruct(shape, jnp.float32)
+    with jax.threefry_partitionable(True):
+      result = pl.pallas_call(
+          body,
+          in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)],
+          out_shape=o_shape,
+      )(threefry_key)
+    jax_result = jax_random.uniform(
+        threefry_key[0], shape=o_shape.shape, minval=0.0, maxval=1.0
+    )
+    np.testing.assert_array_equal(result, jax_result)
+
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
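[Editor's note, not part of the patch: with this change, a standard `threefry2x32` key can be consumed directly inside a Pallas TPU kernel, since non-Pallas key impls now bypass the SMEM key path and `threefry2x32_p` has a lowering rule. A minimal usage sketch mirroring the new test above, assuming a TPU backend; the seed 0 and the (8, 128) block shape are arbitrary illustration choices.

import jax
import jax.numpy as jnp
from jax import random as jax_random
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def kernel(key_ref, o_ref):
  # The ref holds a single threefry key; draw a block of uniforms from it.
  o_ref[...] = jax_random.uniform(key_ref[0, ...], shape=o_ref[...].shape)

key = jax_random.key(0, impl="threefry2x32").reshape((1,))
with jax.threefry_partitionable(True):
  out = pl.pallas_call(
      kernel,
      in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)],
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )(key)
# Under threefry_partitionable, this matches the same draw outside the
# kernel: jax_random.uniform(key[0], shape=(8, 128))
]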