From 601669419f28f27beac7e6d80b2f189349d03546 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Wed, 9 Oct 2024 21:03:57 -0700
Subject: [PATCH] [sharding_in_types] Normalize partition specs when creating
 avals so that P(None, None) and P() are treated as replicated and equivalent.

Shardings on avals are always normalized.

PiperOrigin-RevId: 684285486
---
 jax/_src/array.py          |  3 ++-
 jax/_src/lax/lax.py        |  2 +-
 jax/_src/sharding_impls.py |  7 ++++---
 tests/pjit_test.py         | 10 +++++++++-
 4 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/jax/_src/array.py b/jax/_src/array.py
index 83be3d418c50..4e0cd3d16875 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -1029,7 +1029,8 @@ def make_array_from_single_device_arrays(
 def _get_aval_array(self):
   if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding):
     return self.aval.update(sharding=NamedSharding(
-        self.sharding.mesh.abstract_mesh, self.sharding.spec))
+        self.sharding.mesh.abstract_mesh,
+        self.sharding.normalized_spec(self.ndim)))
   else:
     return self.aval
 api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array

diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index fce19b319692..32e723f31172 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -2073,7 +2073,7 @@ def broadcasting_sharding_rule(name, *avals):
     msg = '{}: arrays must have same number of dimensions, got {}.'
     raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
 
-  specs = [a.sharding.normalized_spec for a in avals if a.shape]
+  specs = [a.sharding.spec for a in avals if a.shape]
 
   mesh = None
   for a in avals:

diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index 3aa8fafdd40a..73aa27bde41d 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -307,8 +307,7 @@ def is_fully_replicated(self) -> bool:
   def with_memory_kind(self, kind: str) -> NamedSharding:
     return NamedSharding(self.mesh, self.spec, memory_kind=kind)
 
-  @functools.cached_property
-  def normalized_spec(self):
+  def normalized_spec(self, ndim: int) -> PartitionSpec:
     out = []
     for p in self._parsed_pspec:
       if p is None:
@@ -319,7 +318,9 @@ def normalized_spec(self):
         out.append(p[0])
       else:
         out.append(p)
-    return tuple(out)
+    if len(out) < ndim:
+      out.extend([None] * (ndim - len(out)))
+    return PartitionSpec(*out)
 
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     return named_sharding_to_xla_hlo_sharding(self, num_dimensions)

diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index af9f55333be7..9d0389a4799f 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -4652,10 +4652,14 @@ def f(x, y):
     self.assertEqual(out.sharding, s)
     self.assertArraysEqual(out, (np_inp1 * np_inp2))
 
-    out = f(arr1, arr1)
+    out = f(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P(('x',), ('y',)))))
     self.assertEqual(out.sharding, s)
     self.assertArraysEqual(out, (np_inp1 * np_inp1))
 
+    out = f(arr1, jax.device_put(np_inp2, NamedSharding(mesh, P())))
+    self.assertEqual(out.sharding, s)
+    self.assertArraysEqual(out, (np_inp1 * np_inp2))
+
     @jax.jit
     def g(x, y):
       return x * y
@@ -4664,6 +4668,10 @@ def g(x, y):
         TypeError, "mul got incompatible shardings for broadcasting"):
       g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P('y', 'x'))))
 
+    with self.assertRaisesRegex(
+        TypeError, "mul got incompatible shardings for broadcasting"):
+      g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P(('x', 'y')))))
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
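
Note (not part of the patch): the standalone sketch below mirrors what the new
NamedSharding.normalized_spec(ndim) does, to show why P() and P(None, None)
become equivalent after this change. The helper name pad_spec_to_ndim is
hypothetical and operates on a raw PartitionSpec, whereas the real method
works on the internally parsed spec.

from jax.sharding import PartitionSpec as P

def pad_spec_to_ndim(spec: P, ndim: int) -> P:
  # Unwrap single-element tuple entries such as ('x',) -> 'x' and pad missing
  # trailing dimensions with None, so specs of the same rank compare equal.
  out = []
  for p in spec:
    if p is None:
      out.append(None)
    elif isinstance(p, tuple) and len(p) == 1:
      out.append(p[0])
    else:
      out.append(p)
  if len(out) < ndim:
    out.extend([None] * (ndim - len(out)))
  return P(*out)

# P() and P(None, None) both normalize to a fully replicated rank-2 spec.
assert pad_spec_to_ndim(P(), 2) == pad_spec_to_ndim(P(None, None), 2)
# Single-element tuple axes collapse to the bare axis name.
assert pad_spec_to_ndim(P(('x',), 'y'), 2) == P('x', 'y')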