From 9cf952a535518da59cdcecc9145dba287beddca2 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Tue, 8 Oct 2024 15:47:36 -0700 Subject: [PATCH] [Pallas] Add support for runtime checking of grid bounds using checkify. PiperOrigin-RevId: 683791662 --- jax/_src/pallas/pallas_call.py | 83 ++++++++++++++++++++++++++++++++++ tests/pallas/pallas_test.py | 49 ++++++++++++++++++-- 2 files changed, 127 insertions(+), 5 deletions(-) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 6114afd020f4..44dad819bc09 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -993,6 +993,85 @@ def checkify_pallas_kernel_body_jaxpr( body_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals) return checked_jaxpr, out_tree, error_effects +def pallas_call_checkify_oob_grid(error: checkify.Error, + enabled_errors, + args: jax_core.Value, + grid_mapping: GridMapping, + input_output_aliases) -> checkify.Error: + if checkify.OOBError not in enabled_errors: + return error + dynamic_grid_args, args = split_list( + args, [grid_mapping.num_dynamic_grid_bounds] + ) + output_args = _initialize_output_vals(grid_mapping.block_mappings_output, + args, input_output_aliases) + scalars, input_args, _ = split_list( + args, [grid_mapping.num_index_operands, + grid_mapping.num_inputs], + ) + dynamic_grid_args_iter = iter(dynamic_grid_args) + grid = tuple( + a if a is not pallas_core.dynamic_grid_dim + else next(dynamic_grid_args_iter) + for a in grid_mapping.grid + ) + grid_start_indices = (jnp.int32(0),) * len(grid) + if grid: + num_iterations = reduce(jnp.multiply, grid) + else: + # Base case is always one iteration when grid is () + num_iterations = 1 + + is_indexing_dim = [ + tuple(b is pallas_core.mapped for b in bm.block_shape) + for bm in grid_mapping.block_mappings + ] + block_shapes = [ + None if iid is None + else tuple(1 if i else b for i, b in zip(iid, bm.block_shape)) + for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings) + ] + # The scan carry: (i, loop_idx, *consts, *ins, *outs, *scratch) + # i:int32 is the interation index + # loop_idx: tuple[int32] are the program ids for each grid axis + def cond(carry): + i, *_ = carry + return i < num_iterations + def body(carry): + i, loop_idx = carry + if grid_mapping.local_grid_env is not None: + local_grid_env = grid_mapping.local_grid_env(loop_idx, grid) + else: + local_grid_env = tuple( + pallas_core.GridAxis(idx, b) + for dim, (idx, b) in enumerate(zip(loop_idx, grid)) + if dim not in grid_mapping.vmapped_dims + ) + with pallas_core.grid_env(local_grid_env): + start_indices = [ + None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars) + for bm in grid_mapping.block_mappings] + # We perform a dynamic slice on the i/o blocks, which will be checked by + # checkify for OOB accesses. + map(_maybe_dynamic_slice, start_indices, block_shapes, + [*input_args, *output_args], is_indexing_dim) + return (i + 1, _get_next_indices(grid, loop_idx)) + def f(_): + lax.while_loop( + cond, body, (jnp.int32(0), grid_start_indices) + ) + flat_args, jaxpr_in_tree = jax.tree_util.tree_flatten((jnp.int32(0),)) + wrapped_loop, _ = api_util.flatten_fun_nokwargs( + lu.wrap_init(f), jaxpr_in_tree) + with pallas_core.tracing_grid_env(grid_mapping.grid, ()): + avals_in = map(jax_core.get_aval, flat_args) + traced_loop, _, consts, () = pe.trace_to_jaxpr_dynamic( + wrapped_loop, list(avals_in)) + traced_loop = jax_core.ClosedJaxpr(traced_loop, consts) + out_error, _ = checkify.checkify_jaxpr( + traced_loop, checkify.index_checks, error, flat_args) + return out_error + def pallas_call_checkify_rule(error: checkify.Error, enabled_errors, *args: jax_core.Value, @@ -1002,6 +1081,10 @@ def pallas_call_checkify_rule(error: checkify.Error, grid_mapping: GridMapping, out_avals: tuple[jax_core.AbstractValue, ...], **kwargs): + # Check for OOB accesses in the grid. + error = pallas_call_checkify_oob_grid(error, enabled_errors, + args, grid_mapping, + input_output_aliases) # We implement the checkify rule in 4 steps: # 1) First, trace the kernel body to get the expected error shapes. # 2) Checkify the kernel body to obtain a jaxpr with errors as inputs diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 6df31b55f8e7..6c1ae2f423fa 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -2033,11 +2033,12 @@ def _(): np.testing.assert_allclose(out, expected, atol=atol) -class PallasCheckifyInterpretTest(PallasBaseTest): - # TODO(b/346651778): Support non-interpret mode checkify. - INTERPRET = True +class PallasCheckifyTest(PallasBaseTest): + INTERPRET = False def test_no_checkify(self,): + if jtu.test_device_matches(["gpu"]): + self.skipTest("Not supported on GPU.") def kernel(y_ref): y_ref[...] = jnp.zeros_like(y_ref[...]) out_shape = jax.ShapeDtypeStruct((2, 2), jnp.float32) @@ -2049,6 +2050,8 @@ def kernel(y_ref): np.testing.assert_allclose(result, jnp.zeros_like(result)) def test_does_not_clobber_previous_error(self,): + if jtu.test_device_matches(["gpu"]): + self.skipTest("Not supported on GPU.") def kernel(y_ref): y_ref[...] = jnp.zeros_like(y_ref[...]) checkify.check(False, "error in kernel") @@ -2067,6 +2070,8 @@ def error_before_call(): @parameterized.parameters((False,), (True,)) def test_trivial_check(self, assert_cond): + if jtu.test_device_matches(["gpu"]): + self.skipTest("Not supported on GPU.") def kernel(x_ref, y_ref): y_ref[...] = x_ref[...] checkify.check(assert_cond, "pallas check failed") @@ -2083,6 +2088,8 @@ def kernel(x_ref, y_ref): np.testing.assert_allclose(result, input) def test_nan_error(self): + if not self.INTERPRET: + self.skipTest("Not supported in non-interpret mode.") def kernel(x_ref, y_ref): y_ref[...] = jnp.log(x_ref[...]) input = jnp.arange(4, dtype=jnp.float32) - 2 @@ -2090,7 +2097,7 @@ def kernel(x_ref, y_ref): pallas_call = self.pallas_call(kernel, out_shape=out_shape) checked_call = checkify.checkify(pallas_call, - errors=checkify.all_checks) + errors=checkify.nan_checks) err, result = checked_call(input) with self.assertRaisesRegex( checkify.JaxRuntimeError, "nan generated by primitive: log"): @@ -2119,6 +2126,8 @@ def kernel(x_ref, y_ref): @parameterized.parameters((5, 0), (8, 3), (4, 3)) def test_checkify_returns_first_error_in_grid( self, num_loops, fail_iteration): + if not self.INTERPRET: + self.skipTest("Not supported in non-interpret mode.") # Check that checkify returns the first error that occurs # TODO(justinfu): This test doesn't make sense on GPU, where threads run # in parallel. Update checkify to return a grid of errors. @@ -2137,12 +2146,42 @@ def kernel(x_ref, _): out_shape=out_shape) checked_call = checkify.checkify(pallas_call, - errors=checkify.all_checks) + errors=checkify.user_checks) err, _ = checked_call(input_arr) with self.assertRaisesRegex( checkify.JaxRuntimeError, f"failed on loop {fail_iteration}"): err.throw() + def test_checkify_on_oob_grid_access(self): + if not self.INTERPRET: + self.skipTest("Not supported in non-interpret mode.") + if config.enable_x64.value: + self.skipTest("Not supported in x64 mode.") + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + input_arr = jnp.arange(18, dtype=jnp.float32) + in_specs = [pl.BlockSpec((8,), lambda x: (x,))] + out_specs = pl.BlockSpec((8,), lambda x: (x,)) + out_shape = jax.ShapeDtypeStruct((18,), dtype=jnp.float32) + pallas_call = self.pallas_call(kernel, + grid=(3,), + in_specs=in_specs, + out_specs=out_specs, + out_shape=out_shape) + + checked_call = checkify.checkify(pallas_call, + errors=checkify.index_checks) + err, result = checked_call(input_arr) + with self.assertRaisesRegex(checkify.JaxRuntimeError, + (r"out-of-bounds indexing for array of shape \(18,\): index 16 " + r"is out of bounds for axis 0 with size 18")): + err.throw() + np.testing.assert_array_equal(result, input_arr) + + +class PallasCheckifyInterpretTest(PallasCheckifyTest): + INTERPRET = True + class PallasCallNamedGridTest(PallasBaseTest): def test_named_grid(self):