Skip to content

Commit

Permalink
[Pallas] Add support for runtime checking of grid bounds using checkify.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 683791662
  • Loading branch information
justinjfu authored and Google-ML-Automation committed Oct 8, 2024
1 parent 9748e2a commit 9cf952a
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 5 deletions.
83 changes: 83 additions & 0 deletions jax/_src/pallas/pallas_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,85 @@ def checkify_pallas_kernel_body_jaxpr(
body_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals)
return checked_jaxpr, out_tree, error_effects

def pallas_call_checkify_oob_grid(error: checkify.Error,
                                  enabled_errors,
                                  args: tuple[jax_core.Value, ...],
                                  grid_mapping: GridMapping,
                                  input_output_aliases) -> checkify.Error:
  """Checks the block accesses implied by the grid for out-of-bounds indexing.

  Traces a loop over all grid iterations which, for every block mapping,
  computes the block start indices and slices the corresponding input/output
  value. The traced loop is then checkified with ``checkify.index_checks`` so
  that any out-of-bounds slice is recorded in the returned error. The kernel
  body is not evaluated by this function.

  Args:
    error: The error accumulated so far; it is passed on to the checkified
      loop.
    enabled_errors: The enabled checkify errors. If ``checkify.OOBError`` is
      not among them, ``error`` is returned unchanged and nothing is traced.
    args: The flat operands, laid out as the dynamic grid bounds, followed by
      the index operands, followed by the inputs, followed by any remaining
      operands (which are ignored here).
    grid_mapping: The grid mapping describing the grid and the block mappings.
    input_output_aliases: Forwarded to ``_initialize_output_vals`` when
      creating the output values.

  Returns:
    The error produced by checkifying the traced grid loop.
  """
  if checkify.OOBError not in enabled_errors:
    return error
  dynamic_grid_args, args = split_list(
      args, [grid_mapping.num_dynamic_grid_bounds]
  )
  output_args = _initialize_output_vals(grid_mapping.block_mappings_output,
                                        args, input_output_aliases)
  scalars, input_args, _ = split_list(
      args, [grid_mapping.num_index_operands,
             grid_mapping.num_inputs],
  )
  # Fill in the dynamic grid dimensions with the values supplied at runtime.
  dynamic_grid_args_iter = iter(dynamic_grid_args)
  grid = tuple(
      a if a is not pallas_core.dynamic_grid_dim
      else next(dynamic_grid_args_iter)
      for a in grid_mapping.grid
  )
  grid_start_indices = (jnp.int32(0),) * len(grid)
  if grid:
    num_iterations = reduce(jnp.multiply, grid)
  else:
    # Base case is always one iteration when grid is ()
    num_iterations = 1

  # For every block mapping, mark which dimensions are `mapped`; those
  # dimensions get a block extent of 1 below.
  is_indexing_dim = [
      tuple(b is pallas_core.mapped for b in bm.block_shape)
      for bm in grid_mapping.block_mappings
  ]
  block_shapes = [
      tuple(1 if i else b for i, b in zip(iid, bm.block_shape))
      for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings)
  ]
  sliced_vals = [*input_args, *output_args]

  # The loop carry is (i, loop_idx):
  #   i: int32 is the iteration index
  #   loop_idx: tuple[int32] are the program ids for each grid axis
  def cond(carry):
    i, *_ = carry
    return i < num_iterations

  def body(carry):
    i, loop_idx = carry
    if grid_mapping.local_grid_env is not None:
      local_grid_env = grid_mapping.local_grid_env(loop_idx, grid)
    else:
      local_grid_env = tuple(
          pallas_core.GridAxis(idx, b)
          for dim, (idx, b) in enumerate(zip(loop_idx, grid))
          if dim not in grid_mapping.vmapped_dims
      )
    with pallas_core.grid_env(local_grid_env):
      start_indices = [
          bm.compute_start_indices_interpret(loop_idx, *scalars)
          for bm in grid_mapping.block_mappings]
    # We perform a dynamic slice on the i/o blocks, which will be checked by
    # checkify for OOB accesses. The slice results are discarded. An explicit
    # loop is used so that the slices are traced even if `map` is the lazy
    # builtin.
    for start_idx, block_shape, val, iid in zip(
        start_indices, block_shapes, sliced_vals, is_indexing_dim):
      _maybe_dynamic_slice(start_idx, block_shape, val, iid)
    return (i + 1, _get_next_indices(grid, loop_idx))

  def f(_):
    # The argument is a dummy and the loop result is dropped: only the
    # operations traced inside the loop matter for the index checks.
    lax.while_loop(
        cond, body, (jnp.int32(0), grid_start_indices)
    )

  flat_args, jaxpr_in_tree = jax.tree_util.tree_flatten((jnp.int32(0),))
  wrapped_loop, _ = api_util.flatten_fun_nokwargs(
      lu.wrap_init(f), jaxpr_in_tree)
  with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
    avals_in = [jax_core.get_aval(x) for x in flat_args]
    traced_loop, _, consts, () = pe.trace_to_jaxpr_dynamic(
        wrapped_loop, avals_in)
    traced_loop = jax_core.ClosedJaxpr(traced_loop, consts)
  # NOTE(review): `flat_args` is passed as a single positional argument here;
  # confirm against the signature of `checkify.checkify_jaxpr`.
  out_error, _ = checkify.checkify_jaxpr(
      traced_loop, checkify.index_checks, error, flat_args)
  return out_error

def pallas_call_checkify_rule(error: checkify.Error,
enabled_errors,
*args: jax_core.Value,
Expand All @@ -1002,6 +1081,10 @@ def pallas_call_checkify_rule(error: checkify.Error,
grid_mapping: GridMapping,
out_avals: tuple[jax_core.AbstractValue, ...],
**kwargs):
# Check for OOB accesses in the grid.
error = pallas_call_checkify_oob_grid(error, enabled_errors,
args, grid_mapping,
input_output_aliases)
# We implement the checkify rule in 4 steps:
# 1) First, trace the kernel body to get the expected error shapes.
# 2) Checkify the kernel body to obtain a jaxpr with errors as inputs
Expand Down
49 changes: 44 additions & 5 deletions tests/pallas/pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2033,11 +2033,12 @@ def _():
np.testing.assert_allclose(out, expected, atol=atol)


class PallasCheckifyInterpretTest(PallasBaseTest):
# TODO(b/346651778): Support non-interpret mode checkify.
INTERPRET = True
class PallasCheckifyTest(PallasBaseTest):
INTERPRET = False

def test_no_checkify(self,):
if jtu.test_device_matches(["gpu"]):
self.skipTest("Not supported on GPU.")
def kernel(y_ref):
y_ref[...] = jnp.zeros_like(y_ref[...])
out_shape = jax.ShapeDtypeStruct((2, 2), jnp.float32)
Expand All @@ -2049,6 +2050,8 @@ def kernel(y_ref):
np.testing.assert_allclose(result, jnp.zeros_like(result))

def test_does_not_clobber_previous_error(self,):
if jtu.test_device_matches(["gpu"]):
self.skipTest("Not supported on GPU.")
def kernel(y_ref):
y_ref[...] = jnp.zeros_like(y_ref[...])
checkify.check(False, "error in kernel")
Expand All @@ -2067,6 +2070,8 @@ def error_before_call():

@parameterized.parameters((False,), (True,))
def test_trivial_check(self, assert_cond):
if jtu.test_device_matches(["gpu"]):
self.skipTest("Not supported on GPU.")
def kernel(x_ref, y_ref):
y_ref[...] = x_ref[...]
checkify.check(assert_cond, "pallas check failed")
Expand All @@ -2083,14 +2088,16 @@ def kernel(x_ref, y_ref):
np.testing.assert_allclose(result, input)

def test_nan_error(self):
if not self.INTERPRET:
self.skipTest("Not supported in non-interpret mode.")
def kernel(x_ref, y_ref):
y_ref[...] = jnp.log(x_ref[...])
input = jnp.arange(4, dtype=jnp.float32) - 2
out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype)
pallas_call = self.pallas_call(kernel,
out_shape=out_shape)
checked_call = checkify.checkify(pallas_call,
errors=checkify.all_checks)
errors=checkify.nan_checks)
err, result = checked_call(input)
with self.assertRaisesRegex(
checkify.JaxRuntimeError, "nan generated by primitive: log"):
Expand Down Expand Up @@ -2119,6 +2126,8 @@ def kernel(x_ref, y_ref):
@parameterized.parameters((5, 0), (8, 3), (4, 3))
def test_checkify_returns_first_error_in_grid(
self, num_loops, fail_iteration):
if not self.INTERPRET:
self.skipTest("Not supported in non-interpret mode.")
# Check that checkify returns the first error that occurs
# TODO(justinfu): This test doesn't make sense on GPU, where threads run
# in parallel. Update checkify to return a grid of errors.
Expand All @@ -2137,12 +2146,42 @@ def kernel(x_ref, _):
out_shape=out_shape)

checked_call = checkify.checkify(pallas_call,
errors=checkify.all_checks)
errors=checkify.user_checks)
err, _ = checked_call(input_arr)
with self.assertRaisesRegex(
checkify.JaxRuntimeError, f"failed on loop {fail_iteration}"):
err.throw()

def test_checkify_on_oob_grid_access(self):
  """Index checks must flag a grid whose last block runs past the array.

  An 18-element array is copied with a grid of 3 blocks of 8 elements, so the
  expected error reports index 16 as out of bounds. The copied result is still
  compared against the input.
  """
  if not self.INTERPRET:
    self.skipTest("Not supported in non-interpret mode.")
  if config.enable_x64.value:
    self.skipTest("Not supported in x64 mode.")

  def kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...]

  input_arr = jnp.arange(18, dtype=jnp.float32)
  copy_call = self.pallas_call(
      kernel,
      grid=(3,),
      in_specs=[pl.BlockSpec((8,), lambda x: (x,))],
      out_specs=pl.BlockSpec((8,), lambda x: (x,)),
      out_shape=jax.ShapeDtypeStruct((18,), dtype=jnp.float32),
  )
  checked_copy = checkify.checkify(copy_call, errors=checkify.index_checks)
  err, result = checked_copy(input_arr)

  expected_message = (
      r"out-of-bounds indexing for array of shape \(18,\): index 16 "
      r"is out of bounds for axis 0 with size 18"
  )
  with self.assertRaisesRegex(checkify.JaxRuntimeError, expected_message):
    err.throw()
  np.testing.assert_array_equal(result, input_arr)


class PallasCheckifyInterpretTest(PallasCheckifyTest):
  # Inherits the test methods of PallasCheckifyTest with INTERPRET overridden
  # to True, so tests guarded by `if not self.INTERPRET` are not skipped by
  # that guard here.
  INTERPRET = True


class PallasCallNamedGridTest(PallasBaseTest):
def test_named_grid(self):
Expand Down

0 comments on commit 9cf952a

Please sign in to comment.