Skip to content

Commit

Permalink
Replace torch.device(cuda) with torch.device(cuda:0) in devices initi…
Browse files Browse the repository at this point in the history
…alization (#3184)
  • Loading branch information
vblagoje authored and brandenchan committed Sep 21, 2022
1 parent a1bc553 commit 87781ee
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions haystack/modeling/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ def initialize_device_settings(
n_gpu = 1
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend="nccl")

# HF transformers v4.21.2 pipeline object doesn't accept torch.device("cuda"), it has to be an indexed cuda device
# TODO eventually remove once the limitation is fixed in HF transformers
device_to_replace = torch.device("cuda")
devices_to_use = [torch.device("cuda:0") if device == device_to_replace else device for device in devices_to_use]

logger.info(f"Using devices: {', '.join([str(device) for device in devices_to_use]).upper()}")
logger.info(f"Number of GPUs: {n_gpu}")
return devices_to_use, n_gpu
Expand Down

0 comments on commit 87781ee

Please sign in to comment.