Skip to content

Commit

Permalink
#sdy use updated axis context when skipping ManualComputationOp creation.
Browse files Browse the repository at this point in the history

Even when the total size of the manual axes is 1 and we can skip creating the `ManualComputationOp`, the body of what would have been the `shard_map` still needs to operate under the new axis context.

PiperOrigin-RevId: 684011448
  • Loading branch information
bartchr808 authored and Google-ML-Automation committed Oct 9, 2024
1 parent 9cf952a commit 823914a
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions jax/experimental/shard_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,22 +663,22 @@ def _shard_map_lowering_shardy(
# Nested `ManualComputationOp`s cannot refer to axes that are already
# manual. So figure out what axes are free thus far and get the new axis
# context.
free_axis = frozenset(mesh.axis_names) - ctx.module_context.axis_context.manual_axes
new_axis_context = sharding_impls.SPMDAxisContext(mesh, free_axis - auto)
free_axes = frozenset(mesh.axis_names) - ctx.module_context.axis_context.manual_axes
new_axis_context = sharding_impls.SPMDAxisContext(mesh, free_axes - auto)
else:
new_axis_context = sharding_impls.SPMDAxisContext(
mesh, frozenset(mesh.axis_names) - auto)
sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
args = (*ctx.dim_var_values, *in_nodes)

manual_axes = sub_ctx.axis_context.manual_axes
mesh_shape = mesh.shape
manual_axes_size = np.prod([mesh_shape[a] for a in manual_axes])
if manual_axes_size == 1:
# No need for a `ManualComputationOp` if all manual axes are size 1.
out_nodes, _ = mlir.jaxpr_subcomp(
sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *args,
dim_var_values=ctx.dim_var_values)
with core.extend_axis_env_nd(tuple(mesh.shape.items())):
out_nodes, _ = mlir.jaxpr_subcomp(
sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *args,
dim_var_values=ctx.dim_var_values)
return out_nodes

in_shardings = sdy.TensorShardingPerValueAttr.get(map(
Expand Down

0 comments on commit 823914a

Please sign in to comment.