
Fix the bug where process groups were not being properly released.
littsk committed Oct 18, 2023
1 parent 1100910 commit 745a91a
Showing 2 changed files with 70 additions and 2 deletions.
22 changes: 22 additions & 0 deletions colossalai/cluster/process_group_mesh.py
@@ -1,3 +1,4 @@
import gc
import itertools
from functools import reduce
from operator import mul
@@ -44,6 +45,27 @@ def __init__(self, *size: int) -> None:
        self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {}
        self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}

    def __del__(self):
        r"""
        Destructor for the ProcessGroupMesh class.

        Called when the ProcessGroupMesh object is deleted or goes out of scope; it cleans up any
        process groups created during the object's lifetime.

        Note:
            PyTorch tracks process groups in module-level (global) state, so they are not destroyed
            automatically when the ProcessGroupMesh's lifetime ends. This method destroys them
            explicitly to release system resources.
        """
        for group in self._ranks_to_group.values():
            dist.destroy_process_group(group)

        # Run the garbage collector to reclaim memory still held by references to the destroyed groups.
        gc.collect()

    @property
    def shape(self) -> Tuple[int, ...]:
        return self._shape
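For readers unfamiliar with why the destructor above is needed: torch.distributed keeps every process group in module-level state, so a group created with new_group stays registered until destroy_process_group is called on it. The snippet below is a minimal, hypothetical sketch of that cleanup pattern (the GroupHolder class is illustrative and not part of this commit); it assumes torch.distributed has already been initialized, e.g. via init_process_group.

    import gc

    import torch.distributed as dist


    class GroupHolder:
        """Illustrative holder that mirrors the cleanup pattern added to ProcessGroupMesh."""

        def __init__(self, ranks):
            # new_group registers the group in torch.distributed's global state,
            # so it would outlive this object unless destroyed explicitly.
            self._group = dist.new_group(ranks=ranks)

        def __del__(self):
            # Unregister the group, then run the garbage collector to reclaim
            # any memory still held by references to it.
            dist.destroy_process_group(self._group)
            gc.collect()


    # Usage (single-process example): initialize the default group first.
    # dist.init_process_group(backend="gloo", rank=0, world_size=1)
    # holder = GroupHolder(ranks=[0])
    # del holder  # the process group is destroyed here, not merely dereferenced
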
50 changes: 48 additions & 2 deletions colossalai/tensor/d_tensor/layout_converter.py
@@ -4,6 +4,7 @@
from typing import Dict, List, Tuple

import torch
import torch.distributed as dist

from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.d_tensor.comm_spec import *
@@ -438,11 +439,56 @@ def layout_converting(
        MAX_TRANSFORM_STEPS = 20
        total_steps = 0
        transform_path = []
        comm_action_sequence = []
        comm_action_sequence: List[CommSpec] = []
        spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence))

        if spec_pairs in self.cached_solution:
            return self.cached_solution[spec_pairs]
            # Cache hit

            def _group_alive_check(cached_comm_action_sequence):
                """
                Check whether the process groups required by the cached sharding solution are still alive,
                i.e. have not been destroyed via torch.distributed.destroy_process_group.

                Args:
                    cached_comm_action_sequence (List[CommSpec]): The cached communication actions.

                Returns:
                    bool: True if all process groups are still registered, False if at least one has been destroyed.

                Raises:
                    RuntimeError: If checking the status of a process group fails for a reason other than the
                        group having been destroyed.
                """

                # Collect all process groups used in communication actions from the cached sequence
                used_process_groups = [
                    pg for comm_spec in cached_comm_action_sequence for pg in comm_spec.process_group_dict.values()
                ]

                # Check if each process group is still alive
                for process_group in used_process_groups:
                    try:
                        dist.get_rank(process_group)
                    except RuntimeError as e:
                        # If the group is not registered, it means it has been deleted
                        if str(e) == (
                            f"Group {process_group} is not registered, please create group with torch.distributed.new_group API"
                        ):
                            return False
                        else:
                            # Re-raise the exception if it's not related to group deletion
                            raise e

                # All process groups are alive
                return True

            cached_transform_path, cached_comm_action_sequence = self.cached_solution[spec_pairs]

            if _group_alive_check(cached_comm_action_sequence):
                # Every process group is still alive, so the cached solution is valid
                return cached_transform_path, cached_comm_action_sequence
            else:
                # At least one process group has been destroyed, so the cached solution is stale: drop it
                del self.cached_solution[spec_pairs]

        # We do nothing if the sharding spec is all the same.
        if source_spec.spec_diff(target_spec) == 0:
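The cache-invalidation logic above hinges on torch.distributed.get_rank raising a RuntimeError when it is handed a group that has already been destroyed. Below is a simplified, self-contained sketch of the same pattern (function and variable names are illustrative, not the converter's actual API); unlike the commit, it treats any RuntimeError from get_rank as a sign the group is gone rather than matching the exact error message.

    from typing import Dict, List, Optional, Tuple

    import torch.distributed as dist
    from torch.distributed import ProcessGroup


    def groups_alive(groups: List[ProcessGroup]) -> bool:
        """Return True only if every group is still registered with torch.distributed."""
        for group in groups:
            try:
                dist.get_rank(group)
            except RuntimeError:
                # Simplification: assume the error means the group was destroyed.
                return False
        return True


    def lookup(
        cache: Dict[Tuple[str, str], List[ProcessGroup]], key: Tuple[str, str]
    ) -> Optional[List[ProcessGroup]]:
        """Return the cached groups for `key`, dropping the entry if any group is stale."""
        if key not in cache:
            return None
        groups = cache[key]
        if groups_alive(groups):
            return groups
        # At least one group was destroyed, so the cached entry can no longer be used.
        del cache[key]
        return None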
