Try using contiguous rank to fix cuda_visible_devices #1926

Merged
merged 3 commits into from Oct 25, 2023
52 changes: 46 additions & 6 deletions python/raft-dask/raft_dask/common/comms.py
@@ -18,7 +18,8 @@
import time
import uuid
import warnings
from collections import Counter, OrderedDict
from collections import Counter, OrderedDict, defaultdict
from typing import Dict

from dask.distributed import default_client
from dask_cuda.utils import nvml_device_index
@@ -157,7 +158,7 @@ def worker_info(self, workers):
        Builds a dictionary of { (worker_address, worker_port) :
                                 (worker_rank, worker_port ) }
        """
        ranks = _func_worker_ranks(self.client)
        ranks = _func_worker_ranks(self.client, workers)
        ports = (
            _func_ucp_ports(self.client, workers) if self.comms_p2p else None
        )
@@ -688,7 +689,7 @@ def _func_ucp_ports(client, workers):
    return client.run(_func_ucp_listener_port, workers=workers)


def _func_worker_ranks(client):
def _func_worker_ranks(client, workers):
"""
For each worker connected to the client,
compute a global rank which is the sum
@@ -697,9 +698,16 @@ def _func_worker_ranks(client):
    Parameters
    ----------
    client (object): Dask client object.
    """
    ranks = client.run(_get_nvml_device_index)
    worker_ips = [_get_worker_ip(worker_address) for worker_address in ranks]
    workers (list): List of worker addresses.
    """
    # TODO: Add a test for this function.
    # Build issues are currently preventing testing.
    nvml_device_index_d = client.run(_get_nvml_device_index, workers=workers)
    worker_ips = [
        _get_worker_ip(worker_address)
        for worker_address in nvml_device_index_d
    ]
    ranks = _map_nvml_device_id_to_contiguous_range(nvml_device_index_d)
    worker_ip_offset_dict = _get_rank_offset_across_nodes(worker_ips)
    return _append_rank_offset(ranks, worker_ip_offset_dict)

Expand All @@ -724,6 +732,38 @@ def _get_worker_ip(worker_address):
return ":".join(worker_address.split(":")[0:2])


def _map_nvml_device_id_to_contiguous_range(nvml_device_index_d: dict) -> dict:
"""
For each worker address in nvml_device_index_d, map the corresponding
worker rank in the range(0, num_workers_per_node) where rank is decided
by the NVML device index. Worker with the lowest NVML device index gets
rank 0, and worker with the highest NVML device index gets rank
num_workers_per_node-1.

Parameters
----------
nvml_device_index_d : dict
Dictionary of worker addresses mapped to their nvml device index.

Returns
-------
dict
Updated dictionary with worker addresses mapped to their rank.
"""

    rank_per_ip: Dict[str, int] = defaultdict(int)

    # Sort by NVML index to ensure that the worker
    # with the lowest NVML index gets rank 0.
    for worker, _ in sorted(nvml_device_index_d.items(), key=lambda x: x[1]):
        ip = _get_worker_ip(worker)

        nvml_device_index_d[worker] = rank_per_ip[ip]
        rank_per_ip[ip] += 1

    return nvml_device_index_d


def _get_rank_offset_across_nodes(worker_ips):
"""
Get a dictionary of worker IP addresses mapped to the cumulative count of
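
The standalone sketch below illustrates how the change is intended to behave end to end: per node, workers are ranked 0..num_workers_per_node - 1 in order of their NVML device index, and a per-node offset then makes the ranks globally unique even when CUDA_VISIBLE_DEVICES leaves the NVML indices non-contiguous. The worker addresses, NVML indices, and the offset dictionary are made-up sample values, and append_rank_offset only approximates the behavior of _append_rank_offset, whose body is not shown in this diff.

from collections import defaultdict
from typing import Dict

# Hypothetical example: two nodes with two Dask-CUDA workers each, where
# CUDA_VISIBLE_DEVICES leaves each worker with a non-contiguous NVML index.
nvml_device_index_d = {
    "tcp://10.0.0.1:40001": 3,
    "tcp://10.0.0.1:40002": 1,
    "tcp://10.0.0.2:40001": 2,
    "tcp://10.0.0.2:40002": 0,
}


def get_worker_ip(worker_address: str) -> str:
    # Same idea as _get_worker_ip in comms.py: keep "protocol:ip".
    return ":".join(worker_address.split(":")[0:2])


def map_to_contiguous_range(nvml_device_index_d: dict) -> dict:
    # Mirrors _map_nvml_device_id_to_contiguous_range: per node (IP), sort
    # workers by NVML index and assign ranks 0..num_workers_per_node - 1.
    rank_per_ip: Dict[str, int] = defaultdict(int)
    for worker, _ in sorted(nvml_device_index_d.items(), key=lambda x: x[1]):
        ip = get_worker_ip(worker)
        nvml_device_index_d[worker] = rank_per_ip[ip]
        rank_per_ip[ip] += 1
    return nvml_device_index_d


def append_rank_offset(ranks: dict, worker_ip_offset_dict: dict) -> dict:
    # Assumed behavior of _append_rank_offset: add each node's offset so the
    # per-node ranks become globally unique and contiguous.
    return {
        worker: rank + worker_ip_offset_dict[get_worker_ip(worker)]
        for worker, rank in ranks.items()
    }


local_ranks = map_to_contiguous_range(dict(nvml_device_index_d))
# Assumed result of _get_rank_offset_across_nodes: cumulative worker count
# of the nodes preceding each IP (the first node contributes 2 workers).
worker_ip_offset_dict = {"tcp://10.0.0.1": 0, "tcp://10.0.0.2": 2}
print(append_rank_offset(local_ranks, worker_ip_offset_dict))
# {'tcp://10.0.0.1:40001': 1, 'tcp://10.0.0.1:40002': 0,
#  'tcp://10.0.0.2:40001': 3, 'tcp://10.0.0.2:40002': 2}

In this sample the raw NVML indices (3, 1, 2, 0) are neither contiguous nor zero-based per node, but the resulting ranks cover exactly 0 through 3 across the cluster, which is the property the PR relies on.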