From 384e29e30d000f1e9c7d7d4a52eadb0ef8a8141a Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Mon, 11 Dec 2023 12:05:10 -0800
Subject: [PATCH] [XLA:Python] Add support for explicitly creating the gloo tcp context.

Pass the context to the CPU client explicitly.

PiperOrigin-RevId: 589898821
---
 jax/_src/xla_bridge.py | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py
index 64da2b02094c..8835afe11e96 100644
--- a/jax/_src/xla_bridge.py
+++ b/jax/_src/xla_bridge.py
@@ -96,6 +96,12 @@
     help="Mock GPU client number of gpus.",
 )
 
+_CPU_ENABLE_GLOO_COLLECTIVES = config.DEFINE_bool(
+    name="jax_cpu_enable_gloo_collectives",
+    default=False,
+    help="If True, enable cross-process collectives on CPU using Gloo.",
+)
+
 
 # Backends
 
@@ -199,7 +205,19 @@ def register_backend_factory(name: str, factory: BackendFactory, *,
 
 
 def make_cpu_client() -> xla_client.Client:
-  if xla_extension_version >= 216:
+  if xla_extension_version >= 223:
+    collectives: xla_client._xla.CpuCollectives | None = None
+    if _CPU_ENABLE_GLOO_COLLECTIVES.value:
+      collectives = xla_client._xla.make_gloo_tcp_collectives(  # type: ignore
+        distributed_client=distributed.global_state.client,
+      )
+    return xla_client.make_cpu_client(  # type: ignore
+      distributed_client=distributed.global_state.client,
+      node_id=distributed.global_state.process_id,
+      num_nodes=distributed.global_state.num_processes,
+      collectives=collectives,
+    )
+  elif xla_extension_version >= 216:
     # TODO(phawkins): remove type: ignore after updating jaxlib version used for
     # mypy checks.
     return xla_client.make_cpu_client(  # type: ignore
@@ -207,7 +225,8 @@ def make_cpu_client() -> xla_client.Client:
       node_id=distributed.global_state.process_id,
       num_nodes=distributed.global_state.num_processes,
     )
-  return xla_client.make_cpu_client()
+  else:
+    return xla_client.make_cpu_client()
 
 
 register_backend_factory(