Skip to content

Commit

Permalink
[XLA:Python] Add support for explicitly creating the gloo tcp context.
Browse files Browse the repository at this point in the history
Pass the context to the CPU client explicitly.

PiperOrigin-RevId: 589898821
  • Loading branch information
hawkinsp authored and jax authors committed Dec 11, 2023
1 parent 3651d4c commit 384e29e
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@
help="Mock GPU client number of gpus.",
)

_CPU_ENABLE_GLOO_COLLECTIVES = config.DEFINE_bool(
name="jax_cpu_enable_gloo_collectives",
default=False,
help="If True, enable cross-process collectives on CPU using Gloo.",
)


# Backends

Expand Down Expand Up @@ -199,15 +205,28 @@ def register_backend_factory(name: str, factory: BackendFactory, *,


def make_cpu_client() -> xla_client.Client:
if xla_extension_version >= 216:
if xla_extension_version >= 223:
collectives: xla_client._xla.CpuCollectives | None = None
if _CPU_ENABLE_GLOO_COLLECTIVES.value:
collectives = xla_client._xla.make_gloo_tcp_collectives( # type: ignore
distributed_client=distributed.global_state.client,
)
return xla_client.make_cpu_client( # type: ignore
distributed_client=distributed.global_state.client,
node_id=distributed.global_state.process_id,
num_nodes=distributed.global_state.num_processes,
collectives=collectives,
)
elif xla_extension_version >= 216:
# TODO(phawkins): remove type: ignore after updating jaxlib version used for
# mypy checks.
return xla_client.make_cpu_client( # type: ignore
distributed_client=distributed.global_state.client,
node_id=distributed.global_state.process_id,
num_nodes=distributed.global_state.num_processes,
)
return xla_client.make_cpu_client()
else:
return xla_client.make_cpu_client()


register_backend_factory(
Expand Down

0 comments on commit 384e29e

Please sign in to comment.