Skip to content

Commit

Permalink
Add experimental static key reuse checking
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Dec 11, 2023
1 parent 3651d4c commit a52d187
Show file tree
Hide file tree
Showing 25 changed files with 1,617 additions and 9 deletions.
13 changes: 13 additions & 0 deletions docs/jax.experimental.key_reuse.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
``jax.experimental.key_reuse`` module
=====================================

.. automodule:: jax.experimental.key_reuse

API
---

.. autosummary::
:toctree: _autosummary

unconsumed_copy
KeyReuseError
1 change: 1 addition & 0 deletions docs/jax.experimental.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Experimental Modules
jax.experimental.custom_partitioning
jax.experimental.multihost_utils
jax.experimental.compilation_cache
jax.experimental.key_reuse

Experimental APIs
-----------------
Expand Down
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ py_library_providing_imports_info(
"_src/scipy/**/*.py",
"_src/state/**/*.py",
"_src/third_party/**/*.py",
"experimental/key_reuse/**/*.py",
"image/**/*.py",
"interpreters/**/*.py",
"lax/**/*.py",
Expand Down
1 change: 1 addition & 0 deletions jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from jax._src.config import (
config as config,
enable_checks as enable_checks,
enable_key_reuse_checks as enable_key_reuse_checks,
check_tracer_leaks as check_tracer_leaks,
checking_leaks as checking_leaks,
enable_custom_prng as enable_custom_prng,
Expand Down
6 changes: 6 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,12 @@ def update_thread_local_jit_state(**kw):
default=False,
help='Turn on invariant checking for JAX internals. Makes things slower.')

enable_key_reuse_checks = define_bool_state(
name='jax_enable_key_reuse_checks',
default=False,
help="Turn on experimental key reuse checking."
)

check_tracer_leaks = define_bool_state(
name='jax_check_tracer_leaks',
default=False,
Expand Down
7 changes: 7 additions & 0 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2766,6 +2766,13 @@ def ctx_factory():
msg = "\n\n".join([msg, "while checking jaxpr:", jaxpr_str])
raise JaxprTypeError(msg) from None

# Run key reuse checker after validating jaxpr:
if config.enable_key_reuse_checks.value:
# Import here to avoid circular imports
from jax.experimental.key_reuse._core import check_key_reuse_jaxpr # pytype: disable=import-error
check_key_reuse_jaxpr(jaxpr)


def _check_jaxpr(
ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]],
jaxpr: Jaxpr
Expand Down
9 changes: 8 additions & 1 deletion jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,12 +372,19 @@ def cond_fun(vals):
def body_fun(vals):
[i], carry, ys = split_list(vals, [1, num_carry])
i_ = length - i - 1 if reverse else i
x = _map(partial(_dynamic_index_array, i_), x_avals, xs)
# TODO(jakevdp)[key-reuse]: this key reuse logic is not quite right,
# because the scan body may consume any keys within it.
# Import here to avoid circular imports
from jax.experimental import key_reuse
xs_unconsumed = _map(key_reuse.unconsumed_copy, xs)
x = _map(partial(_dynamic_index_array, i_), x_avals, xs_unconsumed)
out_flat = f_impl(*consts, *carry, *x)
carry_out, y_updates = split_list(out_flat, [num_carry])
ys_out = _map(partial(_update_array, i_), y_avals, ys, y_updates)
return [i + 1] + carry_out + ys_out

# TODO(jakevdp)[key-reuse]: mark xs consumed here if f_impl consumes them.

ys_init = _map(partial(_empty_array, length), y_avals)
if length == 0:
return init + ys_init
Expand Down
5 changes: 5 additions & 0 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,11 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths):
if not config.dynamic_shapes.value:
jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths())

if config.enable_key_reuse_checks.value:
# Import here to avoid circular imports
from jax.experimental.key_reuse._core import check_key_reuse_jaxpr
check_key_reuse_jaxpr(jaxpr)

if any(isinstance(c, core.Tracer) for c in consts):
closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
final_consts = consts
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/state/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,13 @@ def join(self, other):
@property
def shape(self):
if not isinstance(self.inner_aval, core.ShapedArray):
raise ValueError(f"`Ref{{{self.inner_aval.str_short()}}} has no `shape`.")
raise AttributeError(f"`Ref{{{self.inner_aval.str_short()}}} has no `shape`.")
return self.inner_aval.shape

@property
def dtype(self):
if not isinstance(self.inner_aval, core.UnshapedArray):
raise ValueError(f"`Ref{{{self.inner_aval.str_short()}}} has no `dtype`.")
raise AttributeError(f"`Ref{{{self.inner_aval.str_short()}}} has no `dtype`.")
return self.inner_aval.dtype

@core.aval_property
Expand Down
1 change: 1 addition & 0 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,7 @@ class JaxTestCase(parameterized.TestCase):
"""Base class for JAX tests including numerical checks and boilerplate."""
_default_config = {
'jax_enable_checks': True,
'jax_enable_key_reuse_checks': True,
'jax_numpy_dtype_promotion': 'strict',
'jax_numpy_rank_promotion': 'raise',
'jax_traceback_filtering': 'off',
Expand Down
5 changes: 5 additions & 0 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from jax._src.lax import windowed_reductions as lax_windowed_reductions
from jax._src.lib import xla_client
from jax._src.numpy.ufuncs import logaddexp
from jax.experimental.key_reuse._common import unconsumed_copy_p

import tensorflow as tf # type: ignore[import]

Expand Down Expand Up @@ -1503,8 +1504,12 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
"tridiagonal",
"eigh_jacobi",
"platform_index",
"assert_consumed_value",
"consume",
]

tf_impl[unconsumed_copy_p] = lambda x: x

tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
tf_impl[ad_util.zeros_like_p] = tf.zeros_like

Expand Down
3 changes: 2 additions & 1 deletion jax/experimental/jax2tf/tests/back_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@
config.parse_flags_with_absl()


@jtu.with_config(jax_legacy_prng_key='allow')
@jtu.with_config(jax_legacy_prng_key='allow',
jax_enable_key_reuse_checks=False)
class CompatTest(bctu.CompatTestBase):
def test_dummy(self):
# Tests the testing mechanism. Let this test run on all platforms
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
from jax._src.internal_test_util import test_harnesses


@jtu.with_config(jax_legacy_prng_key='allow')
@jtu.with_config(jax_legacy_prng_key='allow',
jax_enable_key_reuse_checks=False)
class JaxPrimitiveTest(jtu.JaxTestCase):

# This test runs for all primitive harnesses. For each primitive "xxx" the
Expand Down
3 changes: 2 additions & 1 deletion jax/experimental/jax2tf/tests/tf_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence,
# TODO(necula): clean up the test harnesses to not require these flags
@jtu.with_config(jax_numpy_rank_promotion="allow",
jax_numpy_dtype_promotion='standard',
jax_legacy_prng_key="allow")
jax_legacy_prng_key="allow",
jax_enable_key_reuse_checks=False)
class JaxToTfTestCase(jtu.JaxTestCase):
# We want most tests to use the maximum available version, from the locally
# installed tfxla module and export.
Expand Down
45 changes: 45 additions & 0 deletions jax/experimental/key_reuse/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Experimental Key Reuse Checking
-------------------------------
This module contains **experimental** functionality for detecting re-use of random
keys within JAX programs. It is under active development and the APIs here are likely
to change.
Key reuse checking can be enabled on `jit`-compiled functions using the
:func:`jax.enable_key_reuse_checks` configuration::
>>> import jax
>>> @jax.jit
... def f(key):
... return jax.random.uniform(key) + jax.random.normal(key)
...
>>> key = jax.random.key(0)
>>> with jax.enable_key_reuse_checks():
... f(key) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
KeyReuseError: In random_bits, key values a are already consumed.
This flag can also be set globally if you wish to enagle key reuse checks in
every JIT-compiled function.
"""

from jax.experimental.key_reuse._common import (
unconsumed_copy as unconsumed_copy,
KeyReuseError as KeyReuseError,
)
77 changes: 77 additions & 0 deletions jax/experimental/key_reuse/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import NamedTuple, Union
from jax import core
from jax.interpreters import batching, mlir
import numpy as np


class Sink(NamedTuple):
idx: int
mask: Union[bool, np.ndarray] = True


class Source(NamedTuple):
idx: int
mask: Union[bool, np.ndarray] = True


class KeyReuseSignature(NamedTuple):
sinks: list[Sink]
sources: list[Source]


class KeyReuseError(RuntimeError):
pass

consume_p = core.Primitive("consume")
consume_p.def_impl(lambda x: x)
consume_p.def_abstract_eval(lambda x: x)
batching.defvectorized(consume_p)
mlir.register_lowering(
consume_p,
mlir.lower_fun(lambda x: x, multiple_results=False))

def consume(key):
"""Consume the key and return a consumed copy."""
return consume_p.bind(key)

unconsumed_copy_p = core.Primitive("unconsumed_copy")
unconsumed_copy_p.def_impl(lambda x: x)
unconsumed_copy_p.def_abstract_eval(lambda x: x)
batching.defvectorized(unconsumed_copy_p)
mlir.register_lowering(
unconsumed_copy_p,
mlir.lower_fun(lambda x: x, multiple_results=False))

def unconsumed_copy(key):
"""Return a copy of key marked as unconsumed."""
return unconsumed_copy_p.bind(key)

assert_consumed_value_p = core.Primitive("assert_consumed_value")
assert_consumed_value_p.def_impl(lambda x, *, value: x)
assert_consumed_value_p.def_abstract_eval(lambda x, *, value: x)
batching.defvectorized(assert_consumed_value_p)
mlir.register_lowering(
assert_consumed_value_p,
mlir.lower_fun(lambda x, *, value: x, multiple_results=False))

def assert_unconsumed(key):
"""Assert that a key is unconsumed"""
assert_consumed_value_p.bind(key, value=False)

def assert_consumed(key, value=True):
"""Assert that a key is consumed"""
assert_consumed_value_p.bind(key, value=value)
49 changes: 49 additions & 0 deletions jax/experimental/key_reuse/_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Optional, Union

from jax import core
from jax.experimental.key_reuse import _forwarding
from jax.experimental.key_reuse import _simple
import numpy as np

# TODO(jakevdp) fix this
KeyReuseSignature = Any


def check_key_reuse(fun: Callable[..., Any], /, *args: Any,
use_forwarding: bool = True) -> KeyReuseSignature:
"""Function to statically check key reuse."""
if use_forwarding:
return _forwarding.check_key_reuse(fun, *args)
else:
return _simple.check_key_reuse(fun, *args)


def check_key_reuse_jaxpr(jaxpr: core.Jaxpr, *, use_forwarding: bool = True):
"""Check the jaxpr for key reuse."""
get_jaxpr_type_signature(jaxpr, use_forwarding=use_forwarding)


def get_jaxpr_type_signature(
jaxpr: core.Jaxpr, *,
consumed_inputs: Optional[list[Union[bool, np.ndarray]]] = None,
use_forwarding: bool = True,
) -> KeyReuseSignature:
"""Parse the jaxpr to determine key reuse signature"""
if use_forwarding:
return _forwarding.get_jaxpr_type_signature(jaxpr, consumed_inputs)
else:
return _simple.get_jaxpr_type_signature(jaxpr, consumed_inputs)
Loading

0 comments on commit a52d187

Please sign in to comment.