diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index e4db2f7ec358..b6cf69d218bf 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -663,22 +663,22 @@ def _shard_map_lowering_shardy( # Nested `ManualComputationOp`s cannot refer to axes that are already # manual. So figure out what axes are free thus far and get the new axis # context. - free_axis = frozenset(mesh.axis_names) - ctx.module_context.axis_context.manual_axes - new_axis_context = sharding_impls.SPMDAxisContext(mesh, free_axis - auto) + free_axes = frozenset(mesh.axis_names) - ctx.module_context.axis_context.manual_axes + new_axis_context = sharding_impls.SPMDAxisContext(mesh, free_axes - auto) else: new_axis_context = sharding_impls.SPMDAxisContext( mesh, frozenset(mesh.axis_names) - auto) sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) args = (*ctx.dim_var_values, *in_nodes) - manual_axes = sub_ctx.axis_context.manual_axes mesh_shape = mesh.shape manual_axes_size = np.prod([mesh_shape[a] for a in manual_axes]) if manual_axes_size == 1: # No need for a `ManualComputationOp` if all manual axes are size 1. - out_nodes, _ = mlir.jaxpr_subcomp( - sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *args, - dim_var_values=ctx.dim_var_values) + with core.extend_axis_env_nd(tuple(mesh.shape.items())): + out_nodes, _ = mlir.jaxpr_subcomp( + sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *args, + dim_var_values=ctx.dim_var_values) return out_nodes in_shardings = sdy.TensorShardingPerValueAttr.get(map(