From a52d18781e12171c69d33ab90412fbb7a25f486a Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 11 Dec 2023 12:03:48 -0800 Subject: [PATCH] Add experimental static key reuse checking --- docs/jax.experimental.key_reuse.rst | 13 + docs/jax.experimental.rst | 1 + jax/BUILD | 1 + jax/__init__.py | 1 + jax/_src/config.py | 6 + jax/_src/core.py | 7 + jax/_src/lax/control_flow/loops.py | 9 +- jax/_src/pjit.py | 5 + jax/_src/state/types.py | 4 +- jax/_src/test_util.py | 1 + jax/experimental/jax2tf/jax2tf.py | 5 + .../jax2tf/tests/back_compat_test.py | 3 +- .../tests/jax_primitives_coverage_test.py | 3 +- jax/experimental/jax2tf/tests/tf_test_util.py | 3 +- jax/experimental/key_reuse/__init__.py | 45 + jax/experimental/key_reuse/_common.py | 77 ++ jax/experimental/key_reuse/_core.py | 49 ++ jax/experimental/key_reuse/_forwarding.py | 294 +++++++ jax/experimental/key_reuse/_simple.py | 267 ++++++ tests/BUILD | 6 + tests/batching_test.py | 3 +- tests/core_test.py | 12 + tests/key_reuse_test.py | 801 ++++++++++++++++++ tests/random_lax_test.py | 7 +- tests/shape_poly_test.py | 3 + 25 files changed, 1617 insertions(+), 9 deletions(-) create mode 100644 docs/jax.experimental.key_reuse.rst create mode 100644 jax/experimental/key_reuse/__init__.py create mode 100644 jax/experimental/key_reuse/_common.py create mode 100644 jax/experimental/key_reuse/_core.py create mode 100644 jax/experimental/key_reuse/_forwarding.py create mode 100644 jax/experimental/key_reuse/_simple.py create mode 100644 tests/key_reuse_test.py diff --git a/docs/jax.experimental.key_reuse.rst b/docs/jax.experimental.key_reuse.rst new file mode 100644 index 000000000000..27975a9f153e --- /dev/null +++ b/docs/jax.experimental.key_reuse.rst @@ -0,0 +1,13 @@ +``jax.experimental.key_reuse`` module +===================================== + +.. automodule:: jax.experimental.key_reuse + +API +--- + +.. autosummary:: + :toctree: _autosummary + + unconsumed_copy + KeyReuseError diff --git a/docs/jax.experimental.rst b/docs/jax.experimental.rst index 95c89fc9cd58..15c9077463db 100644 --- a/docs/jax.experimental.rst +++ b/docs/jax.experimental.rst @@ -24,6 +24,7 @@ Experimental Modules jax.experimental.custom_partitioning jax.experimental.multihost_utils jax.experimental.compilation_cache + jax.experimental.key_reuse Experimental APIs ----------------- diff --git a/jax/BUILD b/jax/BUILD index 2346f9c7ab68..9bd928d044f2 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -212,6 +212,7 @@ py_library_providing_imports_info( "_src/scipy/**/*.py", "_src/state/**/*.py", "_src/third_party/**/*.py", + "experimental/key_reuse/**/*.py", "image/**/*.py", "interpreters/**/*.py", "lax/**/*.py", diff --git a/jax/__init__.py b/jax/__init__.py index 76e16e5a99fc..2e5d39433244 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -52,6 +52,7 @@ from jax._src.config import ( config as config, enable_checks as enable_checks, + enable_key_reuse_checks as enable_key_reuse_checks, check_tracer_leaks as check_tracer_leaks, checking_leaks as checking_leaks, enable_custom_prng as enable_custom_prng, diff --git a/jax/_src/config.py b/jax/_src/config.py index 1c858f27f697..12d29fc55a82 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -775,6 +775,12 @@ def update_thread_local_jit_state(**kw): default=False, help='Turn on invariant checking for JAX internals. Makes things slower.') +enable_key_reuse_checks = define_bool_state( + name='jax_enable_key_reuse_checks', + default=False, + help="Turn on experimental key reuse checking." +) + check_tracer_leaks = define_bool_state( name='jax_check_tracer_leaks', default=False, diff --git a/jax/_src/core.py b/jax/_src/core.py index 1aec4b432ae9..c2e586a2ff9e 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2766,6 +2766,13 @@ def ctx_factory(): msg = "\n\n".join([msg, "while checking jaxpr:", jaxpr_str]) raise JaxprTypeError(msg) from None + # Run key reuse checker after validating jaxpr: + if config.enable_key_reuse_checks.value: + # Import here to avoid circular imports + from jax.experimental.key_reuse._core import check_key_reuse_jaxpr # pytype: disable=import-error + check_key_reuse_jaxpr(jaxpr) + + def _check_jaxpr( ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]], jaxpr: Jaxpr diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 9eaa375772f7..04deab04ba4b 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -372,12 +372,19 @@ def cond_fun(vals): def body_fun(vals): [i], carry, ys = split_list(vals, [1, num_carry]) i_ = length - i - 1 if reverse else i - x = _map(partial(_dynamic_index_array, i_), x_avals, xs) + # TODO(jakevdp)[key-reuse]: this key reuse logic is not quite right, + # because the scan body may consume any keys within it. + # Import here to avoid circular imports + from jax.experimental import key_reuse + xs_unconsumed = _map(key_reuse.unconsumed_copy, xs) + x = _map(partial(_dynamic_index_array, i_), x_avals, xs_unconsumed) out_flat = f_impl(*consts, *carry, *x) carry_out, y_updates = split_list(out_flat, [num_carry]) ys_out = _map(partial(_update_array, i_), y_avals, ys, y_updates) return [i + 1] + carry_out + ys_out + # TODO(jakevdp)[key-reuse]: mark xs consumed here if f_impl consumes them. + ys_init = _map(partial(_empty_array, length), y_avals) if length == 0: return init + ys_init diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index ee01d1c39c12..95cac8afc392 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -937,6 +937,11 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths): if not config.dynamic_shapes.value: jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths()) + if config.enable_key_reuse_checks.value: + # Import here to avoid circular imports + from jax.experimental.key_reuse._core import check_key_reuse_jaxpr + check_key_reuse_jaxpr(jaxpr) + if any(isinstance(c, core.Tracer) for c in consts): closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) final_consts = consts diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index e7845f41a615..fda7ef10d99a 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -110,13 +110,13 @@ def join(self, other): @property def shape(self): if not isinstance(self.inner_aval, core.ShapedArray): - raise ValueError(f"`Ref{{{self.inner_aval.str_short()}}} has no `shape`.") + raise AttributeError(f"`Ref{{{self.inner_aval.str_short()}}} has no `shape`.") return self.inner_aval.shape @property def dtype(self): if not isinstance(self.inner_aval, core.UnshapedArray): - raise ValueError(f"`Ref{{{self.inner_aval.str_short()}}} has no `dtype`.") + raise AttributeError(f"`Ref{{{self.inner_aval.str_short()}}} has no `dtype`.") return self.inner_aval.dtype @core.aval_property diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 9befc8a361b5..b8594fcb412b 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -913,6 +913,7 @@ class JaxTestCase(parameterized.TestCase): """Base class for JAX tests including numerical checks and boilerplate.""" _default_config = { 'jax_enable_checks': True, + 'jax_enable_key_reuse_checks': True, 'jax_numpy_dtype_promotion': 'strict', 'jax_numpy_rank_promotion': 'raise', 'jax_traceback_filtering': 'off', diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 616a5e954ccf..7a8319721430 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -65,6 +65,7 @@ from jax._src.lax import windowed_reductions as lax_windowed_reductions from jax._src.lib import xla_client from jax._src.numpy.ufuncs import logaddexp +from jax.experimental.key_reuse._common import unconsumed_copy_p import tensorflow as tf # type: ignore[import] @@ -1503,8 +1504,12 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): "tridiagonal", "eigh_jacobi", "platform_index", + "assert_consumed_value", + "consume", ] +tf_impl[unconsumed_copy_p] = lambda x: x + tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient tf_impl[ad_util.zeros_like_p] = tf.zeros_like diff --git a/jax/experimental/jax2tf/tests/back_compat_test.py b/jax/experimental/jax2tf/tests/back_compat_test.py index d866e0c2b89e..fda558a2f9ed 100644 --- a/jax/experimental/jax2tf/tests/back_compat_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_test.py @@ -65,7 +65,8 @@ config.parse_flags_with_absl() -@jtu.with_config(jax_legacy_prng_key='allow') +@jtu.with_config(jax_legacy_prng_key='allow', + jax_enable_key_reuse_checks=False) class CompatTest(bctu.CompatTestBase): def test_dummy(self): # Tests the testing mechanism. Let this test run on all platforms diff --git a/jax/experimental/jax2tf/tests/jax_primitives_coverage_test.py b/jax/experimental/jax2tf/tests/jax_primitives_coverage_test.py index 451e04e355f8..ec0123324e60 100644 --- a/jax/experimental/jax2tf/tests/jax_primitives_coverage_test.py +++ b/jax/experimental/jax2tf/tests/jax_primitives_coverage_test.py @@ -41,7 +41,8 @@ from jax._src.internal_test_util import test_harnesses -@jtu.with_config(jax_legacy_prng_key='allow') +@jtu.with_config(jax_legacy_prng_key='allow', + jax_enable_key_reuse_checks=False) class JaxPrimitiveTest(jtu.JaxTestCase): # This test runs for all primitive harnesses. For each primitive "xxx" the diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index 6751f7b86472..a27a2668a50c 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -156,7 +156,8 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence, # TODO(necula): clean up the test harnesses to not require these flags @jtu.with_config(jax_numpy_rank_promotion="allow", jax_numpy_dtype_promotion='standard', - jax_legacy_prng_key="allow") + jax_legacy_prng_key="allow", + jax_enable_key_reuse_checks=False) class JaxToTfTestCase(jtu.JaxTestCase): # We want most tests to use the maximum available version, from the locally # installed tfxla module and export. diff --git a/jax/experimental/key_reuse/__init__.py b/jax/experimental/key_reuse/__init__.py new file mode 100644 index 000000000000..c7ca0f9d0b62 --- /dev/null +++ b/jax/experimental/key_reuse/__init__.py @@ -0,0 +1,45 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Experimental Key Reuse Checking +------------------------------- + +This module contains **experimental** functionality for detecting re-use of random +keys within JAX programs. It is under active development and the APIs here are likely +to change. + +Key reuse checking can be enabled on `jit`-compiled functions using the +:func:`jax.enable_key_reuse_checks` configuration:: + + >>> import jax + >>> @jax.jit + ... def f(key): + ... return jax.random.uniform(key) + jax.random.normal(key) + ... + >>> key = jax.random.key(0) + >>> with jax.enable_key_reuse_checks(): + ... f(key) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + KeyReuseError: In random_bits, key values a are already consumed. + +This flag can also be set globally if you wish to enagle key reuse checks in +every JIT-compiled function. +""" + +from jax.experimental.key_reuse._common import ( + unconsumed_copy as unconsumed_copy, + KeyReuseError as KeyReuseError, +) diff --git a/jax/experimental/key_reuse/_common.py b/jax/experimental/key_reuse/_common.py new file mode 100644 index 000000000000..f5aeb9eec226 --- /dev/null +++ b/jax/experimental/key_reuse/_common.py @@ -0,0 +1,77 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import NamedTuple, Union +from jax import core +from jax.interpreters import batching, mlir +import numpy as np + + +class Sink(NamedTuple): + idx: int + mask: Union[bool, np.ndarray] = True + + +class Source(NamedTuple): + idx: int + mask: Union[bool, np.ndarray] = True + + +class KeyReuseSignature(NamedTuple): + sinks: list[Sink] + sources: list[Source] + + +class KeyReuseError(RuntimeError): + pass + +consume_p = core.Primitive("consume") +consume_p.def_impl(lambda x: x) +consume_p.def_abstract_eval(lambda x: x) +batching.defvectorized(consume_p) +mlir.register_lowering( + consume_p, + mlir.lower_fun(lambda x: x, multiple_results=False)) + +def consume(key): + """Consume the key and return a consumed copy.""" + return consume_p.bind(key) + +unconsumed_copy_p = core.Primitive("unconsumed_copy") +unconsumed_copy_p.def_impl(lambda x: x) +unconsumed_copy_p.def_abstract_eval(lambda x: x) +batching.defvectorized(unconsumed_copy_p) +mlir.register_lowering( + unconsumed_copy_p, + mlir.lower_fun(lambda x: x, multiple_results=False)) + +def unconsumed_copy(key): + """Return a copy of key marked as unconsumed.""" + return unconsumed_copy_p.bind(key) + +assert_consumed_value_p = core.Primitive("assert_consumed_value") +assert_consumed_value_p.def_impl(lambda x, *, value: x) +assert_consumed_value_p.def_abstract_eval(lambda x, *, value: x) +batching.defvectorized(assert_consumed_value_p) +mlir.register_lowering( + assert_consumed_value_p, + mlir.lower_fun(lambda x, *, value: x, multiple_results=False)) + +def assert_unconsumed(key): + """Assert that a key is unconsumed""" + assert_consumed_value_p.bind(key, value=False) + +def assert_consumed(key, value=True): + """Assert that a key is consumed""" + assert_consumed_value_p.bind(key, value=value) diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py new file mode 100644 index 000000000000..a6c1409c9045 --- /dev/null +++ b/jax/experimental/key_reuse/_core.py @@ -0,0 +1,49 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Optional, Union + +from jax import core +from jax.experimental.key_reuse import _forwarding +from jax.experimental.key_reuse import _simple +import numpy as np + +# TODO(jakevdp) fix this +KeyReuseSignature = Any + + +def check_key_reuse(fun: Callable[..., Any], /, *args: Any, + use_forwarding: bool = True) -> KeyReuseSignature: + """Function to statically check key reuse.""" + if use_forwarding: + return _forwarding.check_key_reuse(fun, *args) + else: + return _simple.check_key_reuse(fun, *args) + + +def check_key_reuse_jaxpr(jaxpr: core.Jaxpr, *, use_forwarding: bool = True): + """Check the jaxpr for key reuse.""" + get_jaxpr_type_signature(jaxpr, use_forwarding=use_forwarding) + + +def get_jaxpr_type_signature( + jaxpr: core.Jaxpr, *, + consumed_inputs: Optional[list[Union[bool, np.ndarray]]] = None, + use_forwarding: bool = True, + ) -> KeyReuseSignature: + """Parse the jaxpr to determine key reuse signature""" + if use_forwarding: + return _forwarding.get_jaxpr_type_signature(jaxpr, consumed_inputs) + else: + return _simple.get_jaxpr_type_signature(jaxpr, consumed_inputs) diff --git a/jax/experimental/key_reuse/_forwarding.py b/jax/experimental/key_reuse/_forwarding.py new file mode 100644 index 000000000000..4b4c682f2466 --- /dev/null +++ b/jax/experimental/key_reuse/_forwarding.py @@ -0,0 +1,294 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections import defaultdict +from functools import reduce +from typing import Any, Callable, NamedTuple, Optional, Union + +import jax +from jax import core +from jax import lax +from jax import tree_util +from jax._src import api_util +from jax._src import linear_util as lu +from jax._src import pjit +from jax._src import prng +from jax._src import random +from jax._src import util +from jax._src.debugging import debug_callback_p +from jax._src.interpreters import partial_eval as pe + +from jax.experimental.key_reuse._common import ( + consume_p, unconsumed_copy_p, assert_consumed_value_p, KeyReuseError, + Sink, Source, KeyReuseSignature +) +import numpy as np + +class Forward(NamedTuple): + in_idx: int + out_idx: int + + +class KeyReuseSignatureWithForwards(NamedTuple): + sinks: list[Sink] + sources: list[Source] + forwards: list[Forward] = [] + +# The behavior of most primitives can be described via simple signatures. +key_reuse_signatures: dict[core.Primitive, KeyReuseSignatureWithForwards] = {} + +key_reuse_signatures[consume_p] = KeyReuseSignatureWithForwards([Sink(0)], [], [Forward(0, 0)]) +key_reuse_signatures[unconsumed_copy_p] = KeyReuseSignatureWithForwards([], [Source(0)]) +key_reuse_signatures[prng.random_bits_p] = KeyReuseSignatureWithForwards([Sink(0)], []) +key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignatureWithForwards([Sink(0)], [Source(0)]) +key_reuse_signatures[prng.random_seed_p] = KeyReuseSignatureWithForwards([], [Source(0)]) +key_reuse_signatures[prng.random_split_p] = KeyReuseSignatureWithForwards([Sink(0)], [Source(0)]) +key_reuse_signatures[random.random_gamma_p] = KeyReuseSignatureWithForwards([Sink(0)], []) +key_reuse_signatures[lax.broadcast_in_dim_p] = KeyReuseSignatureWithForwards([], [], [Forward(0, 0)]) +key_reuse_signatures[lax.copy_p] = KeyReuseSignatureWithForwards([], [], [Forward(0, 0)]) +key_reuse_signatures[lax.convert_element_type_p] = KeyReuseSignatureWithForwards([], [], [Forward(0, 0)]) +key_reuse_signatures[lax.device_put_p] = KeyReuseSignatureWithForwards([], [], [Forward(0, 0)]) +key_reuse_signatures[lax.reshape_p] = KeyReuseSignatureWithForwards([], [], [Forward(0, 0)]) +key_reuse_signatures[lax.squeeze_p] = KeyReuseSignatureWithForwards([], [], [Forward(0, 0)]) +key_reuse_signatures[prng.random_wrap_p] = KeyReuseSignatureWithForwards([], [Source(0)], []) +key_reuse_signatures[prng.random_unwrap_p] = KeyReuseSignatureWithForwards([Sink(0)], [], []) +key_reuse_signatures[debug_callback_p] = KeyReuseSignatureWithForwards([], []) +key_reuse_signatures[lax.dynamic_slice_p] = KeyReuseSignatureWithForwards([], [], [Forward(0, 0)]) +key_reuse_signatures[lax.dynamic_update_slice_p] = KeyReuseSignatureWithForwards([], [], []) + +# Rules which require more dynamic logic. +key_reuse_signatures_dynamic: dict[core.Primitive, Callable[..., KeyReuseSignatureWithForwards]] = {} + +# The default signature will Sink all key inputs, and not Source any. +def unknown_signature(eqn, args_consumed): + def is_key(var: core.Atom): + return hasattr(var.aval, "dtype") and jax.dtypes.issubdtype(var.aval.dtype, jax.dtypes.prng_key) + return KeyReuseSignatureWithForwards( + sinks=[Sink(idx, True) for idx, var in enumerate(eqn.invars) if is_key(var)], + sources=[], + ) + +def get_jaxpr_type_signature( + jaxpr: core.Jaxpr, + consumed_inputs: Optional[list[Union[bool, np.ndarray]]] = None, + forwarded_inputs: Optional[dict[int, int]] = None, + ) -> KeyReuseSignatureWithForwards: + """Parse the jaxpr to determine key reuse signature""" + consumed: dict[core.Atom, Union[bool, np.ndarray]] = {} + forwards: dict[core.Atom, core.Atom] = {} # map forwarded outputs to inputs. + + def resolve_forwards(var: core.Atom) -> core.Atom: + if not forwards: + return var + for _ in range(len(forwards) + 1): + if isinstance(var, core.Literal): + return var + if var in forwards: + var = forwards[var] + else: + return var + raise ValueError("forwarding cycle detected") + + def is_key(var: core.Atom): + return hasattr(var.aval, "dtype") and jax.dtypes.issubdtype(var.aval.dtype, jax.dtypes.prng_key) + + def sink(var: core.Atom, mask=True): + if not is_key(var): + return + var = resolve_forwards(var) + assert not isinstance(var, core.Literal) + if np.any(np.logical_and(consumed.get(var, False), mask)): + return True + consumed[var] = np.logical_or(consumed.get(var, False), mask) + + + def source(var: core.Atom, mask=False): + if not is_key(var): + return + var = resolve_forwards(var) + assert not isinstance(var, core.Literal) + consumed[var] = mask + + def is_consumed(var: core.Atom): + var = resolve_forwards(var) + if isinstance(var, core.Literal): + return False + return consumed.get(var, False) + + if forwarded_inputs: + for i, j in forwarded_inputs.items(): + forwards[jaxpr.invars[i]] = jaxpr.invars[j] + + if consumed_inputs: + for var, mask in util.safe_zip(jaxpr.invars, consumed_inputs): + if not isinstance(var, core.Literal): + source(var, mask) + + for eqn in jaxpr.eqns: + if eqn.primitive in key_reuse_signatures: + signature = key_reuse_signatures[eqn.primitive] + elif eqn.primitive in key_reuse_signatures_dynamic: + args_consumed = [is_consumed(var) for var in eqn.invars] + signature = key_reuse_signatures_dynamic[eqn.primitive](eqn, args_consumed) + else: + args_consumed = [is_consumed(var) for var in eqn.invars] + signature = unknown_signature(eqn, args_consumed) + for in_idx, out_idx in signature.forwards: + forwards[eqn.outvars[out_idx]] = eqn.invars[in_idx] + + for snk in signature.sinks: + if sink(eqn.invars[snk.idx], snk.mask): + raise KeyReuseError(f"In {eqn.primitive}, key values {eqn.invars[snk.idx]} are already consumed.\n" + f"eqn: {eqn}\njaxpr:\n{jaxpr}") + for var in eqn.outvars: + if not isinstance(var, core.Literal) and var not in forwards: + source(var, True) # consumed unless in a Source. + for src in signature.sources: + source(eqn.outvars[src.idx]) + + return KeyReuseSignatureWithForwards( + sinks=[Sink(i, consumed[v]) for i, v in enumerate(jaxpr.invars) + if is_key(v) and np.any(consumed.get(v, False))], + sources=[Source(i) for i, v in enumerate(jaxpr.outvars) + if is_key(v) and resolve_forwards(v) not in jaxpr.invars and not consumed.get(v, False)], + forwards=[Forward(jaxpr.invars.index(resolve_forwards(outvar)), idx_out) # type: ignore[arg-type] + for idx_out, outvar in enumerate(jaxpr.outvars) + if is_key(outvar) and resolve_forwards(outvar) in jaxpr.invars] + ) + + +def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> KeyReuseSignatureWithForwards: + """Function to statically check key reuse.""" + args_flat, in_tree = tree_util.tree_flatten(args) + in_avals_flat = [core.get_aval(arg) for arg in args_flat] + wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) + jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat) + return get_jaxpr_type_signature(jaxpr) + + +#---------------------------------------------------------------------------------- +# key reuse rules for particular primitives: + +def _slice_signature(eqn, args_consumed): + del args_consumed # unused here + in_aval = eqn.invars[0].aval + start_indices = eqn.params['start_indices'] + limit_indices = eqn.params['limit_indices'] + strides = eqn.params['strides'] or (1,) * len(start_indices) + idx = tuple(slice(*tup) for tup in util.safe_zip(start_indices, limit_indices, strides)) + mask = np.zeros(in_aval.shape, dtype=bool) + mask[idx] = True + return KeyReuseSignatureWithForwards([Sink(0, mask)], [Source(0)]) + +key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature + +def _pjit_key_type_signature(eqn, args_consumed): + jaxpr = eqn.params['jaxpr'] + forwarded_inputs = {i: eqn.invars.index(var) for i, var in enumerate(eqn.invars) + if var in eqn.invars[:i]} + return get_jaxpr_type_signature(jaxpr.jaxpr, consumed_inputs=args_consumed, + forwarded_inputs=forwarded_inputs) + +key_reuse_signatures_dynamic[pjit.pjit_p] = _pjit_key_type_signature + +def _assert_consumed_value_key_type_signature(eqn, args_consumed): + actual = args_consumed[0] + expected = eqn.params['value'] + if not np.all(actual == expected): + if np.all(expected): + raise AssertionError(f"Expected key to be consumed in {eqn}") + elif not np.any(expected): + raise AssertionError(f"Expected key to not be consumed in {eqn}") + else: + raise AssertionError(f"Expected {expected}, got {actual} in {eqn}") + return KeyReuseSignatureWithForwards([], [], [Forward(0, 0)]) + +key_reuse_signatures_dynamic[assert_consumed_value_p] = _assert_consumed_value_key_type_signature + +def _cond_key_type_signature(eqn, args_consumed): + signatures = [get_jaxpr_type_signature(branch.jaxpr, consumed_inputs=args_consumed[1:]) + for branch in eqn.params['branches']] + sinks = defaultdict(lambda: []) + sources = defaultdict(lambda: []) + for sig in signatures: + for sink in sig.sinks: + sinks[sink.idx].append(sink.mask) + for source in sig.sources: + sources[source.idx].append(source.mask) + + combined_sinks = [Sink(i + 1, reduce(np.logical_or, m)) for i, m in sinks.items()] + combined_sources = [Source(i + 1, reduce(np.logical_and, m)) for i, m in sources.items()] + combined_forwards = [Forward(f.in_idx + 1, f.out_idx) for f in + set.intersection(*(set(sig.forwards) for sig in signatures))] + return KeyReuseSignatureWithForwards(combined_sinks, combined_sources, combined_forwards) + +key_reuse_signatures_dynamic[lax.cond_p] = _cond_key_type_signature + +def _scan_key_type_signature(eqn, args_consumed): + jaxpr = eqn.params['jaxpr'].jaxpr + num_consts = eqn.params['num_consts'] + num_carry = eqn.params['num_carry'] + length = eqn.params['length'] + signature = get_jaxpr_type_signature(jaxpr, args_consumed) + + # scan body should not consume key in constants + if any(np.any(s.mask) for s in signature.sinks if s.idx < num_consts): + raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed: {signature=}") + + # scan carry should only consume keys that are sourced on output. + carry_sinks = {s.idx - num_consts: s.mask for s in signature.sinks if 0 <= s.idx - num_consts < num_carry} + carry_sources = {s.idx: s.mask for s in signature.sources if 0 <= s.idx < num_carry} + if carry_sinks.keys() != carry_sources.keys(): # TODO(jakevdp): check that masks match + raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed: {signature=}") + return signature + +key_reuse_signatures_dynamic[jax.lax.scan_p] = _scan_key_type_signature + +def _while_key_type_signature(eqn, args_consumed): + cond_jaxpr = eqn.params['cond_jaxpr'].jaxpr + cond_nconsts = eqn.params['cond_nconsts'] + body_jaxpr = eqn.params['body_jaxpr'].jaxpr + body_nconsts = eqn.params['body_nconsts'] + + # TODO(jakevdp): pass args_consumed here? + cond_signature = get_jaxpr_type_signature(cond_jaxpr) + body_signature = get_jaxpr_type_signature(body_jaxpr) + + # Error if there are sinks among consts. + if any(np.any(s.mask) for s in cond_signature.sinks if s.idx < cond_nconsts): + raise KeyReuseError("while_loop cond function leads to key reuse when repeatedly executed: " + f"{cond_signature=}") + if any(np.any(s.mask) for s in body_signature.sinks if s.idx < body_nconsts): + raise KeyReuseError("while_loop body function leads to key reuse when repeatedly executed: " + f"{body_signature=}") + + # carry should only consume keys that are sourced on output. + body_carry_sinks = {s.idx - body_nconsts: s.mask for s in body_signature.sinks if s.idx >= body_nconsts} + cond_carry_sinks = {s.idx - cond_nconsts: s.mask for s in cond_signature.sinks if s.idx >= cond_nconsts} + carry_sources = {s.idx: s.mask for s in body_signature.sources} + # TODO(jakevdp): check masks at each index? + if not (cond_carry_sinks.keys() <= carry_sources.keys()): + raise KeyReuseError("while_loop cond function leads to key reuse when repeatedly executed: " + f"{cond_signature=}") + if not (body_carry_sinks.keys() <= carry_sources.keys()): + raise KeyReuseError("while_loop body function leads to key reuse when repeatedly executed: " + f"{body_signature=}") + if body_carry_sinks.keys() & cond_carry_sinks.keys(): + raise KeyReuseError("while_loop cond and body functions both use the same key: " + f"{cond_signature=} {body_signature=}") + return body_signature + +key_reuse_signatures_dynamic[jax.lax.while_p] = _while_key_type_signature diff --git a/jax/experimental/key_reuse/_simple.py b/jax/experimental/key_reuse/_simple.py new file mode 100644 index 000000000000..b06630e35ce0 --- /dev/null +++ b/jax/experimental/key_reuse/_simple.py @@ -0,0 +1,267 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections import defaultdict +from functools import reduce +from typing import Any, Callable, NamedTuple, Optional, Union + +import jax +from jax import core +from jax import lax +from jax import tree_util +from jax._src import api_util +from jax._src import linear_util as lu +from jax._src import pjit +from jax._src import prng +from jax._src import random +from jax._src import util +from jax._src.debugging import debug_callback_p +from jax._src.interpreters import partial_eval as pe + +from jax.experimental.key_reuse._common import ( + consume_p, unconsumed_copy_p, assert_consumed_value_p, KeyReuseError, + Sink, Source, KeyReuseSignature +) +import numpy as np + +# The behavior of most primitives can be described via simple signatures. +key_reuse_signatures: dict[core.Primitive, KeyReuseSignature] = {} + +key_reuse_signatures[consume_p] = KeyReuseSignature([Sink(0)], []) +key_reuse_signatures[unconsumed_copy_p] = KeyReuseSignature([], [Source(0)]) +key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature([Sink(0)], []) +key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([Sink(0)], [Source(0)]) +key_reuse_signatures[prng.random_seed_p] = KeyReuseSignature([], [Source(0)]) +key_reuse_signatures[prng.random_split_p] = KeyReuseSignature([Sink(0)], [Source(0)]) +key_reuse_signatures[random.random_gamma_p] = KeyReuseSignature([Sink(0)], []) +key_reuse_signatures[lax.broadcast_in_dim_p] = KeyReuseSignature([Sink(0)], [Source(0)]) +key_reuse_signatures[lax.copy_p] = KeyReuseSignature([Sink(0)], [Source(0)]) +key_reuse_signatures[lax.convert_element_type_p] = KeyReuseSignature([Sink(0)], [Source(0)]) +key_reuse_signatures[lax.device_put_p] = KeyReuseSignature([Sink(0)], [Source(0)]) +key_reuse_signatures[lax.reshape_p] = KeyReuseSignature([Sink(0)], [Source(0)]) +key_reuse_signatures[lax.squeeze_p] = KeyReuseSignature([Sink(0)], [Source(0)]) +key_reuse_signatures[prng.random_wrap_p] = KeyReuseSignature([], [Source(0)]) +key_reuse_signatures[prng.random_unwrap_p] = KeyReuseSignature([Sink(0)], []) +key_reuse_signatures[debug_callback_p] = KeyReuseSignature([], []) +key_reuse_signatures[lax.dynamic_slice_p] = KeyReuseSignature([Sink(0)], [Source(0)]) +key_reuse_signatures[lax.dynamic_update_slice_p] = KeyReuseSignature([], []) + +# Rules which require more dynamic logic. +key_reuse_signatures_dynamic: dict[core.Primitive, Callable[..., KeyReuseSignature]] = {} + + +# The default signature will Sink all key inputs, and not Source any. +def unknown_signature(eqn, args_consumed): + def is_key(var: core.Atom): + return hasattr(var.aval, "dtype") and jax.dtypes.issubdtype(var.aval.dtype, jax.dtypes.prng_key) + return KeyReuseSignature( + sinks=[Sink(idx, True) for idx, var in enumerate(eqn.invars) if is_key(var)], + sources=[], + ) + + +def get_jaxpr_type_signature( + jaxpr: core.Jaxpr, + consumed_inputs: Optional[list[Union[bool, np.ndarray]]] = None, + ) -> KeyReuseSignature: + """Parse the jaxpr to determine key reuse signature""" + consumed: dict[core.Atom, Union[bool, np.ndarray]] = {} + + def is_key(var: core.Atom): + return hasattr(var.aval, "dtype") and jax.dtypes.issubdtype(var.aval.dtype, jax.dtypes.prng_key) + + def sink(var: core.Atom, mask=True): + if not is_key(var): + return + assert not isinstance(var, core.Literal) + if np.any(np.logical_and(consumed.get(var, False), mask)): + return True + consumed[var] = np.logical_or(consumed.get(var, False), mask) + + def source(var: core.Atom, mask=False): + if not is_key(var): + return + assert not isinstance(var, core.Literal) + consumed[var] = mask + + def is_consumed(var: core.Atom): + if isinstance(var, core.Literal): + return False + return consumed.get(var, False) + + if consumed_inputs: + for var, mask in util.safe_zip(jaxpr.invars, consumed_inputs): + if not isinstance(var, core.Literal): + source(var, mask) + + for eqn in jaxpr.eqns: + if eqn.primitive in key_reuse_signatures: + signature = key_reuse_signatures[eqn.primitive] + elif eqn.primitive in key_reuse_signatures_dynamic: + args_consumed = [is_consumed(var) for var in eqn.invars] + signature = key_reuse_signatures_dynamic[eqn.primitive](eqn, args_consumed) + else: + args_consumed = [is_consumed(var) for var in eqn.invars] + signature = unknown_signature(eqn, args_consumed) + + for snk in signature.sinks: + if sink(eqn.invars[snk.idx], snk.mask): + raise KeyReuseError(f"In {eqn.primitive}, key values {eqn.invars[snk.idx]} are already consumed.\n" + f"eqn: {eqn}\njaxpr:\n{jaxpr}") + for var in eqn.outvars: + if not isinstance(var, core.Literal): + source(var, True) # consumed unless in a Source. + for src in signature.sources: + source(eqn.outvars[src.idx]) + + forwards = [v for v in jaxpr.outvars + if is_key(v) and v in jaxpr.invars and not np.any(consumed.get(v, False))] + sinks = [v for v in jaxpr.invars if is_key(v) and np.any(consumed.get(v, False))] + sources = [v for v in jaxpr.outvars if is_key(v) and not np.any(consumed.get(v, False))] + return KeyReuseSignature( + sinks=[ + Sink(i, True if v in forwards else consumed[v]) + for i, v in enumerate(jaxpr.invars) + if v in forwards or v in sinks + ], + sources=[ + Source(i) for i, v in enumerate(jaxpr.outvars) + if (v in forwards or v in sources) + and v not in jaxpr.outvars[:i] # Only source the first of duplicate return values + ], + ) + + +def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> KeyReuseSignature: + """Function to statically check key reuse.""" + args_flat, in_tree = tree_util.tree_flatten(args) + in_avals_flat = [core.get_aval(arg) for arg in args_flat] + wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) + jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat) + return get_jaxpr_type_signature(jaxpr) + + +#---------------------------------------------------------------------------------- +# key reuse rules for particular primitives: + +def _slice_signature(eqn, args_consumed): + del args_consumed # unused here + in_aval = eqn.invars[0].aval + start_indices = eqn.params['start_indices'] + limit_indices = eqn.params['limit_indices'] + strides = eqn.params['strides'] or (1,) * len(start_indices) + idx = tuple(slice(*tup) for tup in util.safe_zip(start_indices, limit_indices, strides)) + mask = np.zeros(in_aval.shape, dtype=bool) + mask[idx] = True + return KeyReuseSignature([Sink(0, mask)], [Source(0)]) + +key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature + +def _pjit_key_type_signature(eqn, args_consumed): + jaxpr = eqn.params['jaxpr'] + non_literal_invars = [v for v in eqn.invars if not isinstance(v, core.Literal)] + if len(set(non_literal_invars)) != len(non_literal_invars): + raise ValueError(f"pjit with duplicate inputs: {eqn.invars=}") + return get_jaxpr_type_signature(jaxpr.jaxpr, consumed_inputs=args_consumed) + +key_reuse_signatures_dynamic[pjit.pjit_p] = _pjit_key_type_signature + +def _assert_consumed_value_key_type_signature(eqn, args_consumed): + actual = args_consumed[0] + expected = eqn.params['value'] + if not np.all(actual == expected): + if np.all(expected): + raise AssertionError(f"Expected key to be consumed in {eqn}") + elif not np.any(expected): + raise AssertionError(f"Expected key to not be consumed in {eqn}") + else: + raise AssertionError(f"Expected {expected}, got {actual} in {eqn}") + return KeyReuseSignature([], []) + +key_reuse_signatures_dynamic[assert_consumed_value_p] = _assert_consumed_value_key_type_signature + +def _cond_key_type_signature(eqn, args_consumed): + signatures = [get_jaxpr_type_signature(branch.jaxpr, consumed_inputs=args_consumed[1:]) + for branch in eqn.params['branches']] + sinks = defaultdict(lambda: []) + sources = defaultdict(lambda: []) + for sig in signatures: + for sink in sig.sinks: + sinks[sink.idx].append(sink.mask) + for source in sig.sources: + sources[source.idx].append(source.mask) + + combined_sinks = [Sink(i + 1, reduce(np.logical_or, m)) for i, m in sinks.items()] + combined_sources = [Source(i + 1, reduce(np.logical_and, m)) for i, m in sources.items()] + return KeyReuseSignature(combined_sinks, combined_sources) + +key_reuse_signatures_dynamic[lax.cond_p] = _cond_key_type_signature + +def _scan_key_type_signature(eqn, args_consumed): + jaxpr = eqn.params['jaxpr'].jaxpr + num_consts = eqn.params['num_consts'] + num_carry = eqn.params['num_carry'] + length = eqn.params['length'] + signature = get_jaxpr_type_signature(jaxpr, args_consumed) + + # scan body should not consume key in constants + if any(np.any(s.mask) for s in signature.sinks if s.idx < num_consts): + raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed: {signature=}") + + # scan carry should only consume keys that are sourced on output. + carry_sinks = {s.idx - num_consts: s.mask for s in signature.sinks if 0 <= s.idx - num_consts < num_carry} + carry_sources = {s.idx: s.mask for s in signature.sources if 0 <= s.idx < num_carry} + if carry_sinks.keys() != carry_sources.keys(): # TODO(jakevdp): check that masks match + raise KeyReuseError(f"scan body function leads to key reuse when repeatedly executed: {signature=}") + return signature + +key_reuse_signatures_dynamic[jax.lax.scan_p] = _scan_key_type_signature + +def _while_key_type_signature(eqn, args_consumed): + cond_jaxpr = eqn.params['cond_jaxpr'].jaxpr + cond_nconsts = eqn.params['cond_nconsts'] + body_jaxpr = eqn.params['body_jaxpr'].jaxpr + body_nconsts = eqn.params['body_nconsts'] + + # TODO(jakevdp): pass args_consumed here? + cond_signature = get_jaxpr_type_signature(cond_jaxpr) + body_signature = get_jaxpr_type_signature(body_jaxpr) + + # Error if there are sinks among consts. + if any(np.any(s.mask) for s in cond_signature.sinks if s.idx < cond_nconsts): + raise KeyReuseError("while_loop cond function leads to key reuse when repeatedly executed: " + f"{cond_signature=}") + if any(np.any(s.mask) for s in body_signature.sinks if s.idx < body_nconsts): + raise KeyReuseError("while_loop body function leads to key reuse when repeatedly executed: " + f"{body_signature=}") + + # carry should only consume keys that are sourced on output. + body_carry_sinks = {s.idx - body_nconsts: s.mask for s in body_signature.sinks if s.idx >= body_nconsts} + cond_carry_sinks = {s.idx - cond_nconsts: s.mask for s in cond_signature.sinks if s.idx >= cond_nconsts} + carry_sources = {s.idx: s.mask for s in body_signature.sources} + # TODO(jakevdp): check masks at each index? + if not (cond_carry_sinks.keys() <= carry_sources.keys()): + raise KeyReuseError("while_loop cond function leads to key reuse when repeatedly executed: " + f"{cond_signature=}") + if not (body_carry_sinks.keys() <= carry_sources.keys()): + raise KeyReuseError("while_loop body function leads to key reuse when repeatedly executed: " + f"{body_signature=}") + if body_carry_sinks.keys() & cond_carry_sinks.keys(): + raise KeyReuseError("while_loop cond and body functions both use the same key: " + f"{cond_signature=} {body_signature=}") + return body_signature + +key_reuse_signatures_dynamic[jax.lax.while_p] = _while_key_type_signature diff --git a/tests/BUILD b/tests/BUILD index dd28fae5035f..0af19d1bc424 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1112,6 +1112,12 @@ jax_test( ] + py_deps("tensorflow_core"), ) + +jax_test( + name = "key_reuse_test", + srcs = ["key_reuse_test.py"], +) + jax_test( name = "x64_context_test", srcs = ["x64_context_test.py"], diff --git a/tests/batching_test.py b/tests/batching_test.py index ff03ef318752..e2376ee2eec7 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -960,7 +960,8 @@ def body_fn(uk): u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn, (1., key)) return u - print(vmap(f)(random.split(random.PRNGKey(0), 2))) # no crash + with jax.enable_key_reuse_checks(False): + print(vmap(f)(random.split(random.PRNGKey(0), 2))) # no crash def testEmptyTuples(self): # Ensure there is no crash when a vectorized input contains empty tuples. diff --git a/tests/core_test.py b/tests/core_test.py index 8e38694645f5..b3e3e1cf2b8e 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -784,6 +784,18 @@ def g(x): return x with self.assertRaisesRegex(TypeError, "inconsistently typed as"): core.check_jaxpr(jaxpr) + def test_check_jaxpr_key_reuse(self): + try: + from jax.experimental.key_reuse import KeyReuseError + except ImportError: + self.skipTest("Test requires jax.experimental.key_reuse") + def f(seed): + key = jax.random.key(seed) + return jax.random.uniform(key) + jax.random.normal(key) + with jax.enable_checks(True): + with self.assertRaises(KeyReuseError): + jax.jit(f)(0) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py new file mode 100644 index 000000000000..705dc7baa7e6 --- /dev/null +++ b/tests/key_reuse_test.py @@ -0,0 +1,801 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest, parameterized +from functools import partial + +import numpy as np +import jax +from jax import core +import jax.numpy as jnp +from jax._src import prng +from jax._src import test_util as jtu +from jax.experimental.key_reuse._common import ( + assert_consumed, assert_unconsumed, consume, consume_p, unconsumed_copy_p) +from jax.experimental.key_reuse import ( + _forwarding, _simple, KeyReuseError, unconsumed_copy) + +from jax import config +config.parse_flags_with_absl() + + +key = jax.eval_shape(jax.random.key, 0) +key1D = jax.eval_shape(lambda key: key[None], key) + + +primitives_with_static_signatures = { + consume_p: (consume, key), + unconsumed_copy_p: (unconsumed_copy, key), + prng.random_bits_p: (jax.random.bits, key), + prng.random_fold_in_p: (jax.random.fold_in, key, 2), + prng.random_seed_p: (jax.random.key, 0), + prng.random_split_p: (jax.random.split, key), + prng.random_wrap_p: (jax.random.wrap_key_data, np.uint32([0, 0])), + prng.random_unwrap_p: (jax.random.key_data, key), + jax.random.random_gamma_p: (jax.random.gamma, key, 1.0), + jax.lax.broadcast_in_dim_p: (lambda key: key[None], key), + jax.lax.copy_p: (jnp.array, key), + jax.lax.convert_element_type_p: (lambda key: jnp.array(key, dtype=key.dtype), key), + jax.lax.device_put_p: (jax.device_put, key), + jax.lax.reshape_p: (lambda key: key.reshape((1,)), key), + jax.lax.squeeze_p: (jnp.squeeze, key1D), + jax.lax.dynamic_slice_p: (partial(jax.lax.dynamic_slice, slice_sizes=(1,)), key1D, (0,)), + jax.lax.dynamic_update_slice_p: (jax.lax.dynamic_update_slice, key1D, key1D, (0,)), +} + +# Primitive that is unknown to the key reuse machinery +unknown_p = core.Primitive("unknown") +unknown_p.def_abstract_eval(lambda x: x) +unknown_p.def_impl(lambda x: x) +def apply_unknown_primitive(key): + return unknown_p.bind(key) + + +@jtu.with_config( + jax_enable_custom_prng=False, + jax_enable_key_reuse_checks=False) +class KeyReuseUnitTestSimple(jtu.JaxTestCase): + def check_key_reuse(self, *args): + return _simple.check_key_reuse(*args) + + def test_assertions(self): + key = jax.random.key(0) + self.check_key_reuse(assert_unconsumed, key) + with self.assertRaises(AssertionError): + self.check_key_reuse(assert_consumed, key) + + def test_unknown(self): + def f(key): + assert_unconsumed(key) + key2 = apply_unknown_primitive(key) + assert_consumed(key) + assert_consumed(key2) + self.check_key_reuse(f, jax.random.key(0)) + + def test_consume(self): + def f(key): + assert_unconsumed(key) + key2 = consume(key) + assert_consumed(key) + assert_consumed(key2) + self.check_key_reuse(f, jax.random.key(0)) + + def test_unconsumed_copy(self): + def f(key): + assert_unconsumed(key) + consume(key) + assert_consumed(key) + key2 = unconsumed_copy(key) + assert_unconsumed(key2) + self.check_key_reuse(f, jax.random.key(0)) + + def test_seed(self): + def f(): + key = jax.random.key(0) + assert_unconsumed(key) + self.check_key_reuse(f) + + def test_split(self): + def f(key): + assert_unconsumed(key) + key2 = jax.random.split(key) + assert_unconsumed(key2) + assert_consumed(key) + self.check_key_reuse(f, jax.random.key(0)) + + def test_fold_in(self): + def f(key): + assert_unconsumed(key) + key2 = jax.random.fold_in(key, 2) + assert_consumed(key) + assert_unconsumed(key2) + self.check_key_reuse(f, jax.random.key(0)) + + def test_bits(self): + def f(key): + assert_unconsumed(key) + bits = jax.random.bits(key, (), 'uint32') + assert_consumed(key) + return bits + self.check_key_reuse(f, jax.random.key(0)) + + def test_wrap(self): + def f(key_data): + key = jax.random.wrap_key_data(key_data) + assert_unconsumed(key) + self.check_key_reuse(f, jax.random.PRNGKey(0)) + + def test_unwrap(self): + def f(key): + assert_unconsumed(key) + key_data = jax.random.key_data(key) + assert_consumed(key) + self.check_key_reuse(f, jax.random.key(0)) + + def test_gamma(self): + def f(key): + assert_unconsumed(key) + values = jax.random.gamma(key, 1.0) + assert_consumed(key) + return values + self.check_key_reuse(f, jax.random.key(0)) + + def test_broadcast_in_dim(self): + def f(key): + assert_unconsumed(key) + key2 = key[None] + assert_consumed(key) + assert_unconsumed(key2) + self.check_key_reuse(f, jax.random.key(0)) + + def test_copy(self): + def f(key): + assert_unconsumed(key) + key2 = jnp.array(key, copy=True) + assert_consumed(key) + assert_unconsumed(key2) + self.check_key_reuse(f, jax.random.key(0)) + + def test_device_put(self): + def f(key): + assert_unconsumed(key) + key2 = jax.device_put(key) + assert_consumed(key) + assert_unconsumed(key2) + self.check_key_reuse(f, jax.random.key(0)) + + def test_squeeze(self): + def f(key): + assert_unconsumed(key) + key2 = jax.lax.squeeze(key, (0,)) + assert_consumed(key) + assert_unconsumed(key2) + self.check_key_reuse(f, jax.random.key(0)[None]) + + def test_reshape(self): + def f(key): + assert_unconsumed(key) + key2 = key.reshape(1, *key.shape) + assert_consumed(key) + assert_unconsumed(key2) + self.check_key_reuse(f, jax.random.key(0)) + + def test_slice(self): + def f(keys): + assert_unconsumed(keys) + + assert_unconsumed(keys[0]) + assert_consumed(keys, np.array([True, False])) + + assert_unconsumed(keys[1]) + assert_consumed(keys, np.array([True, True])) + self.check_key_reuse(f, jax.random.split(jax.random.key(0))) + + def test_jit_can_consume_input(self): + def f(key): + assert_unconsumed(key) + jax.jit(jax.random.bits)(key) + assert_consumed(key) + self.check_key_reuse(f, jax.random.key(0)) + + def test_jit_can_return_consumed_output(self): + def f(): + def g(): + key = jax.random.key(0) + assert_unconsumed(key) + bits = jax.random.bits(key) + assert_consumed(key) + return bits, key + _, key = jax.jit(g)() + assert_consumed(key) + self.check_key_reuse(f) + + def test_jit_duplicate_inputs(self): + def f(key): + assert_unconsumed(key) + def g(key1, key2): + return jax.random.bits(key1) + _ = jax.jit(g)(key, key) + assert_consumed(key) + # TODO(jakevdp) handle this somehow? + with self.assertRaisesRegex(ValueError, "pjit with duplicate inputs"): + self.check_key_reuse(f, jax.random.key(0)) + + def test_jit_propagates_consumption_bit(self): + def f(key): + assert_unconsumed(key) + g = jax.jit(lambda: key) + key2 = g() + assert_consumed(key) + assert_unconsumed(key2) + self.check_key_reuse(f, jax.random.key(0)) + + def test_jit_duplicate_outputs(self): + # TODO(jakevdp): implement this case + def f(key): + assert_unconsumed(key) + def g(key): + return key, key + key1, key2 = jax.jit(g)(key) + assert_consumed(key) + assert_unconsumed(key1) + assert_consumed(key2) + self.check_key_reuse(f, jax.random.key(0)) + + def test_cond_both_consumed(self): + @jax.jit + def f(flag, key): + assert_unconsumed(key) + _ = jax.lax.cond( + flag, jax.random.uniform, jax.random.normal, key) + assert_consumed(key) + self.check_key_reuse(f, True, jax.random.key(0)) + + def test_cond_one_consumed(self): + @jax.jit + def f(flag, key): + assert_unconsumed(key) + _ = jax.lax.cond( + flag, jax.random.uniform, lambda k: 1.0, key) + assert_consumed(key) + self.check_key_reuse(f, True, jax.random.key(0)) + + def test_cond_neither_consumed(self): + @jax.jit + def f(flag, key): + assert_unconsumed(key) + _ = jax.lax.cond( + flag, lambda k: 0.0, lambda k: 1.0, key) + assert_unconsumed(key) + self.check_key_reuse(f, True, jax.random.key(0)) + + def test_simple_vmap(self): + @jax.jit + def f(seed): + key = jax.random.key(seed) + assert_unconsumed(key) + result = jax.random.uniform(key) + assert_consumed(key) + return result + self.check_key_reuse(f, 0) + self.check_key_reuse(jax.vmap(f), jnp.arange(4)) + + @parameterized.parameters(*primitives_with_static_signatures) + def test_jaxpr_type_signature(self, primitive): + func, *args = primitives_with_static_signatures[primitive] + signature = _simple.key_reuse_signatures[primitive] + jaxpr = jax.make_jaxpr(func)(*args) + self.assertEqual(signature, _simple.get_jaxpr_type_signature(jaxpr.jaxpr)) + + +@jtu.with_config( + jax_enable_custom_prng=False, + jax_enable_key_reuse_checks=False) +class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase): + def check_key_reuse(self, *args): + return _forwarding.check_key_reuse(*args) + + def test_assertions(self): + key = jax.random.key(0) + self.check_key_reuse(assert_unconsumed, key) + with self.assertRaises(AssertionError): + self.check_key_reuse(assert_consumed, key) + + def test_unknown(self): + def f(key): + assert_unconsumed(key) + key2 = apply_unknown_primitive(key) + assert_consumed(key) + assert_consumed(key2) + self.check_key_reuse(f, jax.random.key(0)) + + def test_consume(self): + def f(key): + assert_unconsumed(key) + key2 = consume(key) + assert_consumed(key) + assert_consumed(key2) + self.check_key_reuse(f, jax.random.key(0)) + + def test_unconsumed_copy(self): + def f(key): + assert_unconsumed(key) + consume(key) + assert_consumed(key) + key2 = unconsumed_copy(key) + assert_unconsumed(key2) + self.check_key_reuse(f, jax.random.key(0)) + + def test_seed(self): + def f(): + key = jax.random.key(0) + assert_unconsumed(key) + self.check_key_reuse(f) + + def test_split(self): + def f(key): + assert_unconsumed(key) + key2 = jax.random.split(key) + assert_unconsumed(key2) + assert_consumed(key) + self.check_key_reuse(f, jax.random.key(0)) + + def test_fold_in(self): + def f(key): + assert_unconsumed(key) + key2 = jax.random.fold_in(key, 2) + assert_consumed(key) + assert_unconsumed(key2) + self.check_key_reuse(f, jax.random.key(0)) + + def test_bits(self): + def f(key): + assert_unconsumed(key) + bits = jax.random.bits(key, (), 'uint32') + assert_consumed(key) + return bits + self.check_key_reuse(f, jax.random.key(0)) + + def test_wrap(self): + def f(key_data): + key = jax.random.wrap_key_data(key_data) + assert_unconsumed(key) + self.check_key_reuse(f, jax.random.PRNGKey(0)) + + def test_unwrap(self): + def f(key): + assert_unconsumed(key) + key_data = jax.random.key_data(key) + assert_consumed(key) + self.check_key_reuse(f, jax.random.key(0)) + + def test_gamma(self): + def f(key): + assert_unconsumed(key) + values = jax.random.gamma(key, 1.0) + assert_consumed(key) + return values + self.check_key_reuse(f, jax.random.key(0)) + + def test_broadcast_in_dim(self): + def f(key): + assert_unconsumed(key) + key2 = key[None] + assert_unconsumed(key) + assert_unconsumed(key2) + consume(key) + assert_consumed(key) + assert_consumed(key2) + self.check_key_reuse(f, jax.random.key(0)) + + def test_copy(self): + def f(key): + assert_unconsumed(key) + key2 = jnp.array(key, copy=True) + assert_unconsumed(key) + assert_unconsumed(key2) + consume(key) + assert_consumed(key) + assert_consumed(key2) + self.check_key_reuse(f, jax.random.key(0)) + + def test_device_put(self): + def f(key): + assert_unconsumed(key) + key2 = jax.device_put(key) + assert_unconsumed(key) + assert_unconsumed(key2) + consume(key) + assert_consumed(key) + assert_consumed(key2) + self.check_key_reuse(f, jax.random.key(0)) + + def test_squeeze(self): + def f(key): + assert_unconsumed(key) + key2 = jax.lax.squeeze(key, (0,)) + assert_unconsumed(key) + assert_unconsumed(key2) + consume(key) + assert_consumed(key) + assert_consumed(key2) + self.check_key_reuse(f, jax.random.key(0)[None]) + + def test_reshape(self): + def f(key): + assert_unconsumed(key) + key2 = key.reshape(1, *key.shape) + assert_unconsumed(key) + assert_unconsumed(key2) + consume(key) + assert_consumed(key) + assert_consumed(key2) + self.check_key_reuse(f, jax.random.key(0)) + + def test_slice(self): + def f(keys): + assert_unconsumed(keys) + + assert_unconsumed(keys[0]) + assert_consumed(keys, np.array([True, False])) + + assert_unconsumed(keys[1]) + assert_consumed(keys, np.array([True, True])) + self.check_key_reuse(f, jax.random.split(jax.random.key(0))) + + def test_jit_can_consume_input(self): + def f(key): + assert_unconsumed(key) + jax.jit(jax.random.bits)(key) + assert_consumed(key) + self.check_key_reuse(f, jax.random.key(0)) + + def test_jit_can_return_consumed_output(self): + def f(): + def g(): + key = jax.random.key(0) + assert_unconsumed(key) + bits = jax.random.bits(key) + assert_consumed(key) + return bits, key + _, key = jax.jit(g)() + assert_consumed(key) + self.check_key_reuse(f) + + def test_jit_duplicate_inputs(self): + def f(key): + assert_unconsumed(key) + def g(key1, key2): + assert_unconsumed(key1) + assert_unconsumed(key2) + return jax.random.bits(key1) + _ = jax.jit(g)(key, key) + assert_consumed(key) + self.check_key_reuse(f, jax.random.key(0)) + + def test_jit_propagates_consumption_bit(self): + def f(key): + assert_unconsumed(key) + g = jax.jit(lambda: key) + key2 = g() + assert_unconsumed(key) + assert_unconsumed(key2) + consume(key) + assert_consumed(key) + assert_consumed(key2) + self.check_key_reuse(f, jax.random.key(0)) + + def test_jit_duplicate_outputs(self): + # TODO(jakevdp): implement this case + def f(key): + assert_unconsumed(key) + def g(key): + return key, key + key1, key2 = jax.jit(g)(key) + assert_unconsumed(key) + assert_unconsumed(key1) + assert_unconsumed(key2) + _ = jax.random.bits(key1) + assert_consumed(key) + assert_consumed(key1) + assert_consumed(key2) + self.check_key_reuse(f, jax.random.key(0)) + + def test_cond_both_consumed(self): + @jax.jit + def f(flag, key): + assert_unconsumed(key) + _ = jax.lax.cond( + flag, jax.random.uniform, jax.random.normal, key) + assert_consumed(key) + self.check_key_reuse(f, True, jax.random.key(0)) + + def test_cond_one_consumed(self): + @jax.jit + def f(flag, key): + assert_unconsumed(key) + _ = jax.lax.cond( + flag, jax.random.uniform, lambda k: 1.0, key) + assert_consumed(key) + self.check_key_reuse(f, True, jax.random.key(0)) + + def test_cond_neither_consumed(self): + @jax.jit + def f(flag, key): + assert_unconsumed(key) + _ = jax.lax.cond( + flag, lambda k: 0.0, lambda k: 1.0, key) + assert_unconsumed(key) + self.check_key_reuse(f, True, jax.random.key(0)) + + def test_simple_vmap(self): + @jax.jit + def f(seed): + key = jax.random.key(seed) + assert_unconsumed(key) + result = jax.random.uniform(key) + assert_consumed(key) + return result + self.check_key_reuse(f, 0) + self.check_key_reuse(jax.vmap(f), jnp.arange(4)) + + @parameterized.parameters(*primitives_with_static_signatures) + def test_jaxpr_type_signature(self, primitive): + func, *args = primitives_with_static_signatures[primitive] + signature = _forwarding.key_reuse_signatures[primitive] + jaxpr = jax.make_jaxpr(func)(*args) + self.assertEqual(signature, _forwarding.get_jaxpr_type_signature(jaxpr.jaxpr)) + + +@jtu.with_config(jax_enable_key_reuse_checks=False) +class KeyReuseIntegrationTest(jtu.JaxTestCase): + use_forwarding = True + random_bits_error = "In random_bits, key values .+ are already consumed.*" + random_split_error = "In random_split, key values .+ are already consumed.*" + generic_error = ".*key values .+ are already consumed.*" + + def check_key_reuse(self, f, *args): + if self.use_forwarding: + return _forwarding.check_key_reuse(f, *args) + else: + return _simple.check_key_reuse(f, *args) + + def test_reuse(self): + def f(): + key = jax.random.key(0) + return jax.random.uniform(key) + jax.random.uniform(key) + + with self.assertRaisesRegex(KeyReuseError, self.random_bits_error): + self.check_key_reuse(f) + + def test_reuse_after_split(self): + def f_good(): + key = jax.random.key(0) + key1, key2 = jax.random.split(key) + return jax.random.uniform(key1) + jax.random.uniform(key2) + self.check_key_reuse(f_good) + + def f_bad(): + key = jax.random.key(0) + _ = jax.random.split(key) + return jax.random.uniform(key) + + with self.assertRaisesRegex(KeyReuseError, self.random_bits_error): + self.check_key_reuse(f_bad) + + def f_bad_2(): + key = jax.random.key(0) + _ = jax.random.split(key) + key1, _ = jax.random.split(key) + return jax.random.uniform(key1) + + with self.assertRaisesRegex(KeyReuseError, self.random_split_error): + self.check_key_reuse(f_bad_2) + + def test_reuse_after_fold_in(self): + def f(): + key = jax.random.key(0) + _ = jax.random.fold_in(key, 1) + return jax.random.uniform(key) + + with self.assertRaisesRegex(KeyReuseError, self.random_bits_error): + self.check_key_reuse(f) + + def test_reuse_after_broadcast(self): + def f(): + key = jax.random.key(0) + key2 = key[None] + return jax.random.bits(key) + jax.random.bits(key2) + + with self.assertRaisesRegex(KeyReuseError, self.random_bits_error): + self.check_key_reuse(f) + + def test_reuse_after_reshape(self): + def f(): + key = jax.random.key(0) + key2 = key.reshape((1,)) + return jax.random.bits(key) + jax.random.bits(key2.squeeze()) + + with self.assertRaisesRegex(KeyReuseError, self.random_bits_error): + self.check_key_reuse(f) + + def test_reuse_after_squeeze(self): + def f(): + key = jax.random.split(jax.random.key(0), 1) + key2 = jax.lax.squeeze(key, (0,)) + return jax.random.bits(key.squeeze()) + jax.random.bits(key2) + + with self.assertRaisesRegex(KeyReuseError, self.generic_error): + self.check_key_reuse(f) + + def test_reuse_after_cond(self): + def f_good(key, condition): + return jax.lax.cond(condition, jax.random.uniform, jax.random.normal, key) + key = jax.random.key(0) + self.check_key_reuse(f_good, key, True) + self.check_key_reuse(f_good, key, False) + + # Check where both branches consume the key + def f_bad(key, condition): + r1 = jax.lax.cond(condition, jax.random.uniform, jax.random.normal, key) + return r1 + jax.random.uniform(key) + + with self.assertRaisesRegex(KeyReuseError, self.random_bits_error): + self.check_key_reuse(f_bad, key, True) + + # Check where only one branch consumes the key + def f_bad_2(key, condition): + r1 = jax.lax.cond(condition, jax.random.uniform, lambda key: 1.0, key) + return r1 + jax.random.uniform(key) + + with self.assertRaisesRegex(KeyReuseError, self.random_bits_error): + self.check_key_reuse(f_bad_2, key, True) + + def test_simple_scan(self): + def f_good(key): + def body_fun(key, _): + key, subkey = jax.random.split(key) + return key, jax.random.bits(subkey) + return jax.lax.scan(body_fun, key, xs=jnp.arange(10)) + self.check_key_reuse(f_good, jax.random.key(0)) + + def test_scan_sink_on_consts(self): + def f(key): + def body_fun(carry, _): + return carry, jax.random.uniform(key) + return jax.lax.scan(body_fun, None, xs=jnp.arange(10)) + with self.assertRaisesRegex(KeyReuseError, "scan body function leads to key reuse"): + self.check_key_reuse(f, jax.random.key(0)) + + def test_scan_reuse_in_body(self): + def f_bad(key): + def body_fun(key, _): + return key, jax.random.bits(key) + return jax.lax.scan(body_fun, key, xs=jnp.arange(10)) + with self.assertRaisesRegex(KeyReuseError, "scan body function leads to key reuse"): + self.check_key_reuse(f_bad, jax.random.key(0)) + + def test_scan_good_over_keys(self): + def f_scan_over_keys(key): + keys = jax.random.split(key, 5) + return jax.lax.map(jax.random.bits, keys) + self.check_key_reuse(f_scan_over_keys, jax.random.key(0)) + + def test_vmap(self): + @jax.vmap + def f_good(seed): + key = jax.random.key(seed) + return jax.random.bits(key) + self.check_key_reuse(f_good, jnp.arange(4)) + + @jax.vmap + def f_bad(seed): + key = jax.random.key(0) + return jax.random.bits(key) + jax.random.bits(key) + with self.assertRaisesRegex(KeyReuseError, self.random_bits_error): + self.check_key_reuse(f_bad, jnp.arange(4)) + + def test_while_simple(self): + def f(seed): + key = jax.random.key(seed) + def cond_fun(carry): + return carry[1] < 10 + def body_fun(carry): + key, subkey = jax.random.split(carry[0]) + return key, carry[1] + jax.random.uniform(subkey) + return jax.lax.while_loop(cond_fun, body_fun, (key, 0)) + self.check_key_reuse(f, 0) + + def test_while_bad_cond(self): + def f(seed): + key = jax.random.key(seed) + def cond_fun(carry): + i, key = carry + return i < jax.random.uniform(key) + def body_fun(carry): + i, key = carry + return i + 1, key + return jax.lax.while_loop(cond_fun, body_fun, (0, key)) + with self.assertRaisesRegex(KeyReuseError, "while_loop cond"): + self.check_key_reuse(f, 0) + + def test_while_bad_body(self): + def f(seed): + key = jax.random.key(seed) + def cond_fun(carry): + key, i = carry + return i < 5 + def body_fun(carry): + key, i = carry + return key, i + jax.random.randint(key, (), 1, 3) + return jax.lax.while_loop(cond_fun, body_fun, (key, 0)) + with self.assertRaisesRegex(KeyReuseError, "while_loop body function leads to key reuse"): + self.check_key_reuse(f, 0) + + def test_while_sink_on_body_consts(self): + def f(seed): + key = jax.random.key(seed) + def cond_fun(i): + return i < 5 + def body_fun(i): + return i + jax.random.randint(key, (), 1, 3) + return jax.lax.while_loop(cond_fun, body_fun, 0) + with self.assertRaisesRegex(KeyReuseError, "while_loop body function leads to key reuse"): + self.check_key_reuse(f, 0) + + def test_while_sink_on_cond_consts(self): + def f(seed): + key = jax.random.key(seed) + def cond_fun(i): + return i < jax.random.uniform(key) + def body_fun(i): + return i + 1 + return jax.lax.while_loop(cond_fun, body_fun, 0) + with self.assertRaisesRegex(KeyReuseError, "while_loop cond function leads to key reuse"): + self.check_key_reuse(f, 0) + + +class KeyReuseIntegrationTestSimple(KeyReuseIntegrationTest): + use_forwarding = False + + +@jtu.with_config(jax_enable_checks=False) +class KeyReuseGlobalFlags(KeyReuseIntegrationTest): + def test_key_reuse_flag(self): + + @jax.jit + def f_bad(key): + return jax.random.bits(key) + jax.random.bits(key) + + @jax.jit + def f_good(key): + return jax.random.bits(key) + + key = jax.random.key(0) + + with jax.enable_key_reuse_checks(False): + f_good(key) + f_bad(key) # No failure + + f_bad.clear_cache() + f_good.clear_cache() + + with jax.enable_key_reuse_checks(True): + f_good(key) + with self.assertRaisesRegex(KeyReuseError, "In random_bits.*"): + f_bad(key) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index 29093a854c90..846dcb230fd9 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -1043,8 +1043,11 @@ def f(): return random.uniform( self.make_key(3), (308000000, 128), dtype=jnp.bfloat16) - # just lower, don't run, takes too long - jax.jit(f).lower() + # TODO(jakevdp): key reuse checks for this OOM because of slice masking. + # Can we fix this? + with jax.enable_key_reuse_checks(False): + # just lower, don't run, takes too long + jax.jit(f).lower() @jtu.sample_product(shape=[(3, 4)], logits_shape_base=[(3, 4), (3, 1), (1, 4)], diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 9db1741c017c..c23bbbd44667 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -60,6 +60,7 @@ ) +@jtu.with_config(jax_enable_key_reuse_checks=False) class DimExprTest(jtu.JaxTestCase): class AssertionType(enum.Enum): @@ -622,6 +623,7 @@ def check_shape_poly(tst, f_jax: Callable, *, return h.run_test(tst) +@jtu.with_config(jax_enable_key_reuse_checks=False) class ShapePolyTest(jtu.JaxTestCase): def test_simple_unary(self): @@ -2329,6 +2331,7 @@ def _flatten_harnesses(harnesses): return res +@jtu.with_config(jax_enable_key_reuse_checks=False) class ShapePolyHarnessesTest(jtu.JaxTestCase): """This test runs for all _POLY_SHAPE_PRIMITIVE_HARNESSES."""