
Commit

[sharding_in_types] Normalize partition specs when creating avals so that P(None, None) and P() are treated as replicated and equivalent. Shardings on avals are always normalized.

PiperOrigin-RevId: 684285486
yashk2810 authored and Google-ML-Automation committed Oct 10, 2024
1 parent 94abaf4 commit 6016694
Showing 4 changed files with 16 additions and 6 deletions.
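
A minimal sketch (not part of the commit) of the equivalence described in the commit message, using the internal NamedSharding.normalized_spec method as changed below. The mesh construction and the axis names ('x', 'y') are assumptions for illustration only; the sketch shows that P() and P(None, None) normalize to the same fully replicated spec for a rank-2 array.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 2-axis mesh over all available devices, purely for illustration.
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), ('x', 'y'))

s_empty = NamedSharding(mesh, P())            # no axes specified
s_nones = NamedSharding(mesh, P(None, None))  # explicitly replicated over both dims

# After this change both specs normalize identically for ndim=2, so avals
# built from either sharding are treated as equivalent (fully replicated).
assert s_empty.normalized_spec(2) == s_nones.normalized_spec(2)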
3 changes: 2 additions & 1 deletion jax/_src/array.py
@@ -1029,7 +1029,8 @@ def make_array_from_single_device_arrays(
 def _get_aval_array(self):
   if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding):
     return self.aval.update(sharding=NamedSharding(
-        self.sharding.mesh.abstract_mesh, self.sharding.spec))
+        self.sharding.mesh.abstract_mesh,
+        self.sharding.normalized_spec(self.ndim)))
   else:
     return self.aval
 api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
2 changes: 1 addition & 1 deletion jax/_src/lax/lax.py
@@ -2073,7 +2073,7 @@ def broadcasting_sharding_rule(name, *avals):
     msg = '{}: arrays must have same number of dimensions, got {}.'
     raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
 
-  specs = [a.sharding.normalized_spec for a in avals if a.shape]
+  specs = [a.sharding.spec for a in avals if a.shape]
 
   mesh = None
   for a in avals:
7 changes: 4 additions & 3 deletions jax/_src/sharding_impls.py
@@ -307,8 +307,7 @@ def is_fully_replicated(self) -> bool:
   def with_memory_kind(self, kind: str) -> NamedSharding:
     return NamedSharding(self.mesh, self.spec, memory_kind=kind)
 
-  @functools.cached_property
-  def normalized_spec(self):
+  def normalized_spec(self, ndim: int) -> PartitionSpec:
     out = []
     for p in self._parsed_pspec:
       if p is None:
@@ -319,7 +318,9 @@ def normalized_spec(self):
         out.append(p[0])
       else:
         out.append(p)
-    return tuple(out)
+    if len(out) < ndim:
+      out.extend([None] * (ndim - len(out)))
+    return PartitionSpec(*out)
 
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     return named_sharding_to_xla_hlo_sharding(self, num_dimensions)
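
As a follow-up sketch (again hypothetical, not part of the commit), the new ndim argument pads specs that are shorter than the array rank with trailing None entries; the mesh setup below is assumed, as in the earlier example.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Same assumed mesh setup as the earlier sketch.
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), ('x', 'y'))
s = NamedSharding(mesh, P('x'))

# P('x') on a rank-3 array normalizes to P('x', None, None): the unspecified
# trailing dimensions are filled in as replicated.
assert s.normalized_spec(3) == P('x', None, None)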
10 changes: 9 additions & 1 deletion tests/pjit_test.py
@@ -4652,10 +4652,14 @@ def f(x, y):
     self.assertEqual(out.sharding, s)
     self.assertArraysEqual(out, (np_inp1 * np_inp2))
 
-    out = f(arr1, arr1)
+    out = f(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P(('x',), ('y',)))))
     self.assertEqual(out.sharding, s)
     self.assertArraysEqual(out, (np_inp1 * np_inp1))
 
+    out = f(arr1, jax.device_put(np_inp2, NamedSharding(mesh, P())))
+    self.assertEqual(out.sharding, s)
+    self.assertArraysEqual(out, (np_inp1 * np_inp2))
+
     @jax.jit
     def g(x, y):
       return x * y
@@ -4664,6 +4668,10 @@ def g(x, y):
         TypeError, "mul got incompatible shardings for broadcasting"):
       g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P('y', 'x'))))
 
+    with self.assertRaisesRegex(
+        TypeError, "mul got incompatible shardings for broadcasting"):
+      g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P(('x', 'y')))))
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
