Skip to content

Commit

Permalink
Add config option to log or fatal when jax.Arrays are GCed.
Browse files Browse the repository at this point in the history
Introduces `jax.config.array_garbage_collection_guard`, which is a tristate config for setting up a `jax.Array` garbage collection guard. The possible configs are:
* allow: `jax.Array`s are allowed to be garbage collected. This is the default value.
* log: whenever a `jax.Array` is GCed a log entry is generated with the array's traceback.
* fatal: fatal crash when a `jax.Array` is GCed. This is meant to be used for mature code bases that do tight memory management, and are reference cycle free.

PiperOrigin-RevId: 683828823
  • Loading branch information
ICGog authored and Google-ML-Automation committed Oct 9, 2024
1 parent 9cf952a commit b24c84c
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 24 deletions.
91 changes: 69 additions & 22 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
from typing import Any, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast

from jax._src import lib
from jax._src.lib import guard_lib
from jax._src.lib import jax_jit
from jax._src.lib import transfer_guard_lib
from jax._src.lib import xla_client
from jax._src import logging_config

Expand Down Expand Up @@ -1596,7 +1596,7 @@ def _update_disable_jit_thread_local(val):
@contextlib.contextmanager
def explicit_device_put_scope() -> Iterator[None]:
"""Indicates that the current context is an explicit device_put*() call."""
state = transfer_guard_lib.thread_local_state()
state = guard_lib.thread_local_state()
prev = state.explicit_device_put
state.explicit_device_put = True
try:
Expand All @@ -1607,7 +1607,7 @@ def explicit_device_put_scope() -> Iterator[None]:
@contextlib.contextmanager
def explicit_device_get_scope() -> Iterator[None]:
"""Indicates that the current context is an explicit device_get() call."""
state = transfer_guard_lib.thread_local_state()
state = guard_lib.thread_local_state()
prev = state.explicit_device_get
state.explicit_device_get = True
try:
Expand All @@ -1616,19 +1616,19 @@ def explicit_device_get_scope() -> Iterator[None]:
state.explicit_device_get = prev

def _update_transfer_guard(state, key, val):
"""Applies the transfer guard level within transfer_guard_lib."""
"""Applies the transfer guard level within guard_lib."""
if val is None:
setattr(state, key, None)
elif val == 'allow':
setattr(state, key, transfer_guard_lib.TransferGuardLevel.ALLOW)
setattr(state, key, guard_lib.TransferGuardLevel.ALLOW)
elif val == 'log':
setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG)
setattr(state, key, guard_lib.TransferGuardLevel.LOG)
elif val == 'disallow':
setattr(state, key, transfer_guard_lib.TransferGuardLevel.DISALLOW)
setattr(state, key, guard_lib.TransferGuardLevel.DISALLOW)
elif val == 'log_explicit':
setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG_EXPLICIT)
setattr(state, key, guard_lib.TransferGuardLevel.LOG_EXPLICIT)
elif val == 'disallow_explicit':
setattr(state, key, transfer_guard_lib.TransferGuardLevel.DISALLOW_EXPLICIT)
setattr(state, key, guard_lib.TransferGuardLevel.DISALLOW_EXPLICIT)
else:
assert False, f'Invalid transfer guard level {val}'

Expand All @@ -1637,45 +1637,46 @@ def _update_transfer_guard(state, key, val):
enum_values=[
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard.
# The default is applied by guard_lib. Use None here to avoid accidentally
# overriding --jax_transfer_guard.
default=None,
help=('Select the transfer guard level for host-to-device transfers. '
'Default is "allow".'),
update_global_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.global_state(), 'host_to_device', val),
guard_lib.global_state(), 'host_to_device', val),
update_thread_local_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.thread_local_state(), 'host_to_device', val))
guard_lib.thread_local_state(), 'host_to_device', val))

transfer_guard_device_to_device = optional_enum_state(
name='jax_transfer_guard_device_to_device',
enum_values=[
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard.
# The default is applied by guard_lib. Use None here to avoid accidentally
# overriding --jax_transfer_guard.
default=None,
help=('Select the transfer guard level for device-to-device transfers. '
'Default is "allow".'),
update_global_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.global_state(), 'device_to_device', val),
guard_lib.global_state(), 'device_to_device', val),
update_thread_local_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.thread_local_state(), 'device_to_device', val))
guard_lib.thread_local_state(), 'device_to_device', val))

transfer_guard_device_to_host = optional_enum_state(
name='jax_transfer_guard_device_to_host',
enum_values=[
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
],
# The default is applied by transfer_guard_lib. Use None here to avoid
# The default is applied by guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard.
default=None,
help=('Select the transfer guard level for device-to-host transfers. '
'Default is "allow".'),
update_global_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.global_state(), 'device_to_host', val),
guard_lib.global_state(), 'device_to_host', val
),
update_thread_local_hook=lambda val: _update_transfer_guard(
transfer_guard_lib.thread_local_state(), 'device_to_host', val))
guard_lib.thread_local_state(), 'device_to_host', val))

def _update_all_transfer_guard_global(val):
for name in ('jax_transfer_guard_host_to_device',
Expand All @@ -1688,8 +1689,8 @@ def _update_all_transfer_guard_global(val):
enum_values=[
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard_*.
# The default is applied by guard_lib. Use None here to avoid accidentally
# overriding --jax_transfer_guard_*.
default=None,
help=('Select the transfer guard level for all transfers. This option is '
'set-only; the transfer guard level for a specific direction should '
Expand Down Expand Up @@ -1718,6 +1719,52 @@ def transfer_guard(new_val: str) -> Iterator[None]:
yield


if lib.xla_extension_version < 291:

def array_garbage_collection_guard(_val):
  """Stand-in for the garbage collection guard on an unsupported jaxlib.

  This definition is selected when `lib.xla_extension_version < 291`; it
  replaces the real config state so that any attempt to use the guard fails
  loudly instead of being silently ignored.

  Args:
    _val: the requested guard level; ignored.

  Raises:
    NotImplementedError: always.
  """
  raise NotImplementedError(
      'jaxlib version is too low for garbage collection guard'
  )

else:
def _update_garbage_collection_guard(state, key, val):
"""Applies the transfer guard level within guard_lib."""
if val is None:
setattr(state, key, None)
elif val == 'allow':
setattr(state, key, guard_lib.GarbageCollectionGuardLevel.ALLOW)
elif val == 'log':
setattr(state, key, guard_lib.GarbageCollectionGuardLevel.LOG)
elif val == 'fatal':
setattr(state, key, guard_lib.GarbageCollectionGuardLevel.FATAL)
else:
assert False, f'Invalid garbage collection guard level {val}'

# Tristate ('allow' / 'log' / 'fatal') config controlling what happens when a
# jax.Array is reclaimed by the Python garbage collector. Updates are forwarded
# to guard_lib's global and thread-local state via
# _update_garbage_collection_guard.
array_garbage_collection_guard = optional_enum_state(
    name='jax_array_garbage_collection_guard',
    enum_values=['allow', 'log', 'fatal'],
    # The default is applied by guard_lib.
    default=None,
    help=(
        'Select garbage collection guard level for "jax.Array" objects.\nThis'
        ' option can be used to control what happens when a "jax.Array"'
        ' object is garbage collected. It is desirable for "jax.Array"'
        ' objects to be freed by Python reference counting rather than garbage'
        ' collection in order to avoid device memory being held by the arrays'
        ' until garbage collection occurs.\n\nValid values are:\n * "allow":'
        ' do not log garbage collection of "jax.Array" objects.\n * "log":'
        ' log an error when a "jax.Array" is garbage collected.\n * "fatal":'
        ' fatal error if a "jax.Array" is garbage collected.\nDefault is'
        ' "allow".'
    ),
    update_global_hook=lambda val: _update_garbage_collection_guard(
        guard_lib.global_state(), 'garbage_collect_array', val
    ),
    update_thread_local_hook=lambda val: _update_garbage_collection_guard(
        guard_lib.thread_local_state(), 'garbage_collect_array', val
    ),
)

def _update_debug_log_modules(module_names_str: str | None):
logging_config.disable_all_debug_logging()
if not module_names_str:
Expand Down
5 changes: 4 additions & 1 deletion jax/_src/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ def _try_cuda_nvcc_import() -> str | None:

cuda_path = _cuda_path()

transfer_guard_lib = xla_client._xla.transfer_guard_lib
if xla_extension_version >= 291:
guard_lib = xla_client._xla.guard_lib
else:
guard_lib = xla_client._xla.transfer_guard_lib

Device = xla_client._xla.Device
4 changes: 3 additions & 1 deletion jaxlib/tools/build_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def patch_copy_mlir_import(src_file, dst_dir):
"pytree.pyi",
"transfer_guard_lib.pyi",
]
_OPTIONAL_XLA_EXTENSION_STUBS = []
_OPTIONAL_XLA_EXTENSION_STUBS = [
"guard_lib.pyi", # Will be required on xla_extension_version >= 291
]


def patch_copy_xla_extension_stubs(dst_dir):
Expand Down
5 changes: 5 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1175,6 +1175,11 @@ jax_multiplatform_test(
srcs = ["transfer_guard_test.py"],
)

jax_multiplatform_test(
name = "garbage_collection_guard_test",
srcs = ["garbage_collection_guard_test.py"],
)

jax_multiplatform_test(
name = "name_stack_test",
srcs = ["name_stack_test.py"],
Expand Down
77 changes: 77 additions & 0 deletions tests/garbage_collection_guard_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for garbage allocation guard."""

import gc
import io
from unittest import mock

from absl.testing import absltest
import jax
from jax._src import config
import jax._src.test_util as jtu
import jax.numpy as jnp

jax.config.parse_flags_with_absl()


# Helper class used to create a reference cycle.
class GarbageCollectionGuardTestNodeHelper:
  """Node that carries a payload and a link to another node.

  Two nodes pointing at each other through `next` form a reference cycle, so
  their payloads can only be reclaimed by the cyclic garbage collector.
  """

  def __init__(self, data):
    # Payload kept alive by this node.
    self.data = data
    # Link to another node; unset until the caller wires it up.
    self.next = None


def _create_array_cycle():
  """Creates a reference cycle of two nodes, each holding a jax.Array.

  Nothing is returned, so once this function exits the two nodes (and the
  arrays they hold) are referenced only by each other.
  """
  n1 = GarbageCollectionGuardTestNodeHelper(jnp.ones((2, 2)))
  n2 = GarbageCollectionGuardTestNodeHelper(jnp.zeros((2, 2)))
  # Close the loop: n1 -> n2 -> n1.
  n1.next = n2
  n2.next = n1


if jax._src.lib.xla_extension_version < 291:

  class GarbageCollectionGuardTest(jtu.JaxTestCase):
    """Empty test case: this branch is taken when xla_extension_version < 291."""
    pass

else:

  class GarbageCollectionGuardTest(jtu.JaxTestCase):
    """Checks what is written to stderr when cyclic jax.Arrays are GCed."""

    def test_gced_array_is_not_logged_by_default(self):
      """With no guard configured, collecting the cycle logs nothing."""
      mock_stderr = io.StringIO()
      _create_array_cycle()
      # Capture stderr only around the collection itself.
      with mock.patch("sys.stderr", mock_stderr):
        gc.collect()
      self.assertNotIn(
          "`jax.Array` was deleted by the Python garbage collector",
          mock_stderr.getvalue(),
      )

    def test_gced_array_is_logged(self):
      """With the guard set to "log", collecting the cycle writes a message."""
      mock_stderr = io.StringIO()
      with config.array_garbage_collection_guard("log"):
        _create_array_cycle()
        # Collect while the "log" level is still in effect.
        with mock.patch("sys.stderr", mock_stderr):
          gc.collect()
      self.assertIn(
          "`jax.Array` was deleted by the Python garbage collector",
          mock_stderr.getvalue(),
      )


# Script entry point: run the tests in this module with the JAX test loader.
if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit b24c84c

Please sign in to comment.