
Commit

[sharding_in_types] Add support for nary ops to propagate sharding when 1 input is sharded and all others are replicated.

PiperOrigin-RevId: 684289345
yashk2810 authored and Google-ML-Automation committed Oct 10, 2024
1 parent b65be4e commit 351187d
Showing 3 changed files with 64 additions and 10 deletions.
jax/_src/lax/lax.py (27 changes: 17 additions & 10 deletions)
@@ -66,7 +66,8 @@
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import chlo
 from jax._src.lib.mlir.dialects import hlo
-from jax._src.sharding_impls import PmapSharding, NamedSharding, PartitionSpec
+from jax._src.sharding_impls import (PmapSharding, NamedSharding,
+                                      PartitionSpec as P)
 from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape
 from jax._src.util import (cache, safe_zip, safe_map, canonicalize_axis,
                            split_list, NumpyComplexWarning)
@@ -2072,7 +2073,7 @@ def broadcasting_sharding_rule(name, *avals):
     msg = '{}: arrays must have same number of dimensions, got {}.'
     raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
 
-  specs = [a.sharding.spec for a in avals if a.shape]
+  specs = [a.sharding.normalized_spec for a in avals if a.shape]
 
   mesh = None
   for a in avals:
@@ -2084,23 +2085,29 @@ def broadcasting_sharding_rule(name, *avals):
             f' another mesh: {a.sharding.mesh}')
   assert mesh is not None
 
-  result_specs = []
-  for ss, ds in zip(zip(*specs), zip(*shapes)):
+  result_specs = [None] * len(shapes[0])
+  for i, (ss, ds) in enumerate(zip(zip(*specs), zip(*shapes))):
     if all(s == ss[0] for s in ss[1:]):
       # if all dimension shardings are same, the resulting dimension sharding is
       # the same.
-      result_specs.append(ss[0])
+      result_specs[i] = ss[0]
     else:
       non_trivial_s = [s for s, d in zip(ss, ds)
                        if not (core.definitely_equal(d, 1) and s is None)]
       if not non_trivial_s:
-        result_specs.append(None)
+        result_specs[i] = None
       elif all(non_trivial_s[0] == s for s in non_trivial_s[1:]):
-        result_specs.append(non_trivial_s[0])
+        result_specs[i] = non_trivial_s[0]
       else:
-        raise TypeError(f'{name} got incompatible shardings for broadcasting: '
-                        f'{", ".join(map(str, map(tuple, specs)))}.')
-  return NamedSharding(mesh, PartitionSpec(*result_specs))
+        for s in ss:
+          if result_specs[i] is None and s is not None:
+            result_specs[i] = s
+          elif (result_specs[i] is not None and s is not None and
+                result_specs[i] != s):
+            raise TypeError(
+                f'{name} got incompatible shardings for broadcasting: '
+                f'{", ".join(map(str, map(tuple, specs)))}.')
+  return NamedSharding(mesh, P(*result_specs))
 
 
 def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False,
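A minimal standalone sketch of the per-dimension resolution the new else branch performs: a concrete mesh axis name wins over None (replicated), and two different concrete names are a conflict. The helper name resolve_dim and the bare axis-name strings are illustrative, not part of the change.

```python
# Sketch of the new per-dimension rule in broadcasting_sharding_rule's else branch.
def resolve_dim(dim_shardings):
  result = None
  for s in dim_shardings:
    if result is None and s is not None:
      result = s                # the sharded operand dictates this dimension
    elif result is not None and s is not None and result != s:
      raise TypeError(f'incompatible shardings for broadcasting: {dim_shardings}')
  return result

assert resolve_dim(['x', None]) == 'x'     # sharded * replicated -> sharded
assert resolve_dim([None, None]) is None   # all replicated stays replicated
# resolve_dim(['x', 'y']) would raise TypeError (conflicting axis names)
```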
jax/_src/sharding_impls.py (14 changes: 14 additions & 0 deletions)
@@ -307,6 +307,20 @@ def is_fully_replicated(self) -> bool:
   def with_memory_kind(self, kind: str) -> NamedSharding:
     return NamedSharding(self.mesh, self.spec, memory_kind=kind)
 
+  @functools.cached_property
+  def normalized_spec(self):
+    out = []
+    for p in self._parsed_pspec:
+      if p is None:
+        raise ValueError("UNCONSTRAINED is not supported yet.")
+      if not p:
+        out.append(None)
+      elif isinstance(p, tuple) and len(p) == 1:
+        out.append(p[0])
+      else:
+        out.append(p)
+    return tuple(out)
+
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     return named_sharding_to_xla_hlo_sharding(self, num_dimensions)
 
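A rough picture of what normalized_spec yields. The sketch below mirrors the property's logic over a plain tuple of parsed entries; the assumption that each entry is a tuple of mesh axis names (empty for a replicated dimension, None for UNCONSTRAINED) is inferred from the code above, which appears to be why broadcasting_sharding_rule now compares per-dimension shardings via normalized_spec with ==.

```python
# Mirrors NamedSharding.normalized_spec over a plain tuple of parsed entries
# (illustrative stand-in for the private _parsed_pspec attribute).
def normalize_spec(parsed_entries):
  out = []
  for p in parsed_entries:
    if p is None:
      raise ValueError("UNCONSTRAINED is not supported yet.")
    if not p:                                # empty entry -> replicated dim
      out.append(None)
    elif isinstance(p, tuple) and len(p) == 1:
      out.append(p[0])                       # unwrap a single-axis tuple
    else:
      out.append(p)                          # keep multi-axis tuples as-is
  return tuple(out)

# The single-axis entry is unwrapped and the empty entry becomes None:
assert normalize_spec((('x',), ())) == ('x', None)
```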
tests/pjit_test.py (33 changes: 33 additions & 0 deletions)
@@ -4631,6 +4631,39 @@ def f(x):
     else:
       self.assertEqual(lowered_text.count('@Sharding'), 2)
 
+  @config.sharding_in_types(True)
+  def test_fully_replicated_array_mul(self):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    np_inp1 = np.arange(16).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    arr1 = jax.device_put(np_inp1, s)
+
+    np_inp2 = np.arange(2).reshape(1, 2)
+    arr2 = jax.device_put(np_inp2, NamedSharding(mesh, P(None, None)))
+
+    @jax.jit
+    def f(x, y):
+      self.assertEqual(x.sharding.spec, s.spec)
+      out = x * y
+      self.assertEqual(out.sharding.spec, s.spec)
+      return out
+
+    out = f(arr1, arr2)
+    self.assertEqual(out.sharding, s)
+    self.assertArraysEqual(out, (np_inp1 * np_inp2))
+
+    out = f(arr1, arr1)
+    self.assertEqual(out.sharding, s)
+    self.assertArraysEqual(out, (np_inp1 * np_inp1))
+
+    @jax.jit
+    def g(x, y):
+      return x * y
+
+    with self.assertRaisesRegex(
+        TypeError, "mul got incompatible shardings for broadcasting"):
+      g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P('y', 'x'))))
+
+
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
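In terms of the dimension-wise rule sketched earlier, the three calls in the new test cover the interesting cases. This summary is a reading aid, not part of the commit:

```python
# f(arr1, arr2): per-dim specs ('x', 'y') vs (None, None) -> ('x', 'y')
#   (the sharded operand wins over the fully replicated one)
# f(arr1, arr1): ('x', 'y') vs ('x', 'y')                 -> ('x', 'y')
#   (identical shardings pass through unchanged)
# g(arr1, <P('y', 'x') input>): 'x' vs 'y' in dim 0        -> TypeError
#   ("mul got incompatible shardings for broadcasting")
```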
