Skip to content

Commit

Permalink
Merge pull request #18741 from mattjj:shmap-test-fix
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 586710378
  • Loading branch information
jax authors committed Nov 30, 2023
2 parents fe237cd + 5c2635c commit 11d7a2b
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions tests/shard_map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,11 @@ def tearDownModule():
xla_bridge.get_backend.cache_clear()


@jtu.ignore_warning(category=DeprecationWarning,
message="arr.device_buffers? is deprecated")
class ShardMapTest(jtu.JaxTestCase):

def test_identity(self):
mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None))
assert a.device_buffers[0].shape == (4, 2)
assert a.addressable_data(0).shape == (4, 2)

def identity(x):
return x
Expand All @@ -111,11 +109,11 @@ def fwd(a):
return c

c = fwd(a)
self.assertEqual(c.device_buffers[0].shape, (4, 2))
self.assertEqual(c.addressable_data(0).shape, (4, 2))

def test_all_gather(self):
mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None))
assert a.device_buffers[0].shape == (4, 2)
assert a.addressable_data(0).shape == (4, 2)

# NOTE(mattjj): to use out_specs=P(None, ('x', 'y')), we need to use
# all_gather_invariant primitive, which differs in its output replication
Expand All @@ -127,13 +125,13 @@ def fwd(a):
return lax.all_gather(a, 'z', axis=0, tiled=True)

c = fwd(a)
self.assertEqual(c.device_buffers[0].shape, (8, 2))
self.assertEqual(c.addressable_data(0).shape, (8, 2))

def test_matmul_partial(self):
raise unittest.SkipTest("invalid replication asserted by out_spec?")

mesh, a, b = create_inputs(P('z', 'y'), P('y', None))
assert a.device_buffers[0].shape == (4, 4)
assert a.addressable_data(0).shape == (4, 4)

@jax.jit
@partial(shard_map, mesh=mesh,
Expand All @@ -143,11 +141,11 @@ def fwd(a):
return c

c = fwd(a)
self.assertEqual(c.device_buffers[0].shape, (4, 8))
self.assertEqual(c.addressable_data(0).shape, (4, 8))

def test_matmul_reduce_scatter(self):
mesh, a, b = create_inputs(P('z', 'y'), P('y', None))
assert a.device_buffers[0].shape == (4, 4)
assert a.addressable_data(0).shape == (4, 4)

@jax.jit
@partial(shard_map, mesh=mesh,
Expand All @@ -158,7 +156,7 @@ def fwd(a, b):
return lax.psum_scatter(c, 'y', scatter_dimension=0, tiled=True)

c = fwd(a, b)
self.assertEqual(c.device_buffers[0].shape, (2, 8))
self.assertEqual(c.addressable_data(0).shape, (2, 8))

def test_collective_permute(self):
devices = np.array(jax.devices()[:8]) # Take up to 8 devices
Expand Down

0 comments on commit 11d7a2b

Please sign in to comment.