Skip to content

Commit

Permalink
In python 3.11 async.run() always tries to convert repr of the result…
Browse files Browse the repository at this point in the history
… of a coroutine as integer while fetching sigint handler. This makes the test materialize the whole tensor in memory. This changes the test co-routine to return nothing to avoid triggering this bug.

python/cpython#112559

PiperOrigin-RevId: 586756112
  • Loading branch information
marksandler2 authored and jax authors committed Nov 30, 2023
1 parent 4de07b3 commit 569f06c
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions jax/experimental/array_serialization/serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ class CheckpointTest(jtu.JaxTestCase):
def _on_commit_callback(self, temp_ckpt_dir, final_ckpt_dir):
os.rename(temp_ckpt_dir, final_ckpt_dir)

@unittest.skip('Broken at HEAD. b/313958844')
@jtu.skip_on_devices('cpu')
def test_memory_consumption(self):
global_mesh = jtu.create_global_mesh((2, 4), ('x', 'y'))
Expand All @@ -73,16 +72,18 @@ def test_memory_consumption(self):
on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir))
manager.wait_until_finished()

deserialize_with_byte_limit = serialization.async_deserialize(
async def deserialize_with_byte_limit():
r = await serialization.async_deserialize(
sharding, tspec, inp_shape,
byte_limiter=serialization._LimitInFlightBytes(4_200_000))
r.block_until_ready()

tm.start()
asyncio.run(deserialize_with_byte_limit).block_until_ready()
asyncio.run(deserialize_with_byte_limit())
unused_current, peak = tm.get_traced_memory()
# NB: some padding + tensorstore overhead. It should always be
# less than array size (2048 * 4096 * 4 = 32M)
self.assertLess(peak, 10_000_000)

deserialize_wo_limit = serialization.async_deserialize(
sharding, tspec, inp_shape)
tm.clear_traces()
Expand Down

0 comments on commit 569f06c

Please sign in to comment.