[hotfix]: modify create_ep_hierarchical_group and add test (#5032)
* feat: modify create_ep_hierarchical_group args

* test: add ep tests

* fix: remove get_process_group_ranks

* fix: fix src_rank
CWHer authored Nov 17, 2023
1 parent 97cd0cd commit 3c08f17
Showing 5 changed files with 38 additions and 28 deletions.
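In short: create_ep_hierarchical_group now takes the expert-parallel ranks as a plain list instead of a ProcessGroup, and returns the intra-node source rank alongside the two process groups, which callers must thread into HierarchicalAllToAll. A minimal before/after sketch of the call pattern (names taken from the diff below, not a verbatim excerpt):

    # Before this commit (src_rank was derived inside the op):
    intra_group, inter_group = create_ep_hierarchical_group(ep_group)

    # After this commit (caller supplies the rank list, receives the broadcast root):
    src_rank, intra_group, inter_group = create_ep_hierarchical_group(ep_group_ranks)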
10 changes: 6 additions & 4 deletions colossalai/moe/_operation.py
@@ -150,7 +150,8 @@ class HierarchicalAllToAll(torch.autograd.Function):
     def forward(
         ctx: Any,
         inputs: Tensor,
-        groups: Tuple[ProcessGroup],
+        groups: Tuple[ProcessGroup, ProcessGroup],
+        src_rank: int
     ) -> Tensor:
         """
         Returns:
@@ -159,12 +160,12 @@ def forward(
         # TODO: we can reduce comm volume by removing empty capacity
         if ctx is not None:
             ctx.comm_grps = groups
+            ctx.src_rank = src_rank
         intra_node_group, inter_node_group = groups
 
         local_world_size = dist.get_world_size(intra_node_group)
         num_group = dist.get_world_size(inter_node_group) if inter_node_group is not None else 1
         world_size = local_world_size * num_group
-        src_rank = dist.get_process_group_ranks(intra_node_group)[0]
         outputs = torch.empty_like(inputs)
 
         if dist.get_rank() == src_rank:
@@ -196,9 +197,10 @@ def forward(
         return outputs
 
     @staticmethod
-    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]:
+    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
         return (
-            HierarchicalAllToAll.forward(None, grad_outputs[0], ctx.comm_grps),
+            HierarchicalAllToAll.forward(None, grad_outputs[0], ctx.comm_grps, ctx.src_rank),
             None,
+            None,
         )

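Note on the backward change: autograd requires backward to return one gradient per forward argument after ctx, so adding src_rank to forward(ctx, inputs, groups, src_rank) forces a third return value, None for the non-differentiable src_rank. A hypothetical call site, assuming an initialized process group and groups built by create_ep_hierarchical_group:

    # Sketch only: intra_node_group, inter_node_group and src_rank are assumed
    # to come from create_ep_hierarchical_group; inputs is a CUDA tensor.
    outputs = HierarchicalAllToAll.apply(inputs, (intra_node_group, inter_node_group), src_rank)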
13 changes: 8 additions & 5 deletions colossalai/moe/layers.py
@@ -13,7 +13,7 @@
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.routers import MoeRouter, get_router_cls
 from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
-from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size
+from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size
 
 
 class SparseMLP(nn.Module):
@@ -105,8 +105,11 @@ def __init__(
         if self.expert_parallel is not None:
             self.ep_group = get_ep_group(self.experts)
             self.ep_size = get_ep_size(self.experts)
-            self.ep_hierarchical_group = create_ep_hierarchical_group(
-                self.ep_group) if enable_hierarchical_comm else None
+            self.ep_hierarchical_group = None
+            if enable_hierarchical_comm:
+                self.ep_intra_src_rank, *self.ep_hierarchical_group = create_ep_hierarchical_group(
+                    get_ep_group_ranks(self.experts)
+                )
             self.dp_group = get_dp_group(self.experts)
         else:
             self.ep_group = None
@@ -225,10 +228,10 @@ def _ep_process(
         """
         if not overlap or dist.get_world_size(self.ep_group) == 1:
             if self.ep_hierarchical_group is not None:
-                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group)
+                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
                 expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
                 expert_output = self.experts(expert_input)
-                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group)
+                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
                 return expert_output
             else:
                 expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
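The self.ep_intra_src_rank, *self.ep_hierarchical_group = ... line above uses starred assignment: the first returned value becomes the source rank and the remaining two are collected into a list, so ep_hierarchical_group ends up as [intra_group, inter_group] rather than a tuple. A toy illustration of the unpacking (values are placeholders):

    src, *grp = (0, "intra", "inter")  # star target collects the tail into a list
    assert src == 0 and grp == ["intra", "inter"]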
17 changes: 8 additions & 9 deletions colossalai/moe/utils.py
@@ -179,15 +179,15 @@ def set_moe_args(config: Any, args: dict):
 
 
 def create_ep_hierarchical_group(
-    ep_group: dist.ProcessGroup,
+    ep_group_ranks: List[int],
     nproc_per_node: Optional[int] = None,
-) -> Tuple[Optional[dist.ProcessGroup],
-           Optional[dist.ProcessGroup]]:
+) -> Tuple[int, dist.ProcessGroup, Optional[dist.ProcessGroup]]:
     """
     e.g., If ep_group = [1, 2, 5, 6], and nproc_per_node = 4
     Then, ep_intra_group = [1, 2] & [5, 6], ep_inter_group = [1, 5] & None
     """
     assert dist.is_initialized(), "Please initialize torch.distributed first."
+    rank = dist.get_rank()
     if nproc_per_node is None:
         nproc_per_node = os.environ.get("LOCAL_WORLD_SIZE")
         assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
@@ -197,29 +197,28 @@ def create_ep_hierarchical_group(
         "nproc_per_node should be a divisor of world_size."
     num_node = dist.get_world_size() // nproc_per_node
 
-    rank = dist.get_rank()
-    ep_ranks = dist.get_process_group_ranks(ep_group)
-
+    intra_src_rank = None
     ep_intra_node_group = None
     for i in range(num_node):
         ep_intra_ranks = [
             i * nproc_per_node + j
             for j in range(nproc_per_node)
-            if j in ep_ranks
+            if j in ep_group_ranks
         ]
         group = dist.new_group(ep_intra_ranks)
         if rank in ep_intra_ranks:
             assert ep_intra_node_group is None
             ep_intra_node_group = group
+            intra_src_rank = ep_intra_ranks[0]
 
     ep_inter_node_group = None
     ep_inter_ranks = [
-        ep_ranks[0] + i * nproc_per_node
+        ep_group_ranks[0] + i * nproc_per_node
         for i in range(num_node)
     ]
     if len(ep_inter_ranks) > 1:
         group = dist.new_group(ep_inter_ranks)
         if rank in ep_inter_ranks:
             ep_inter_node_group = group
 
-    return ep_intra_node_group, ep_inter_node_group
+    return intra_src_rank, ep_intra_node_group, ep_inter_node_group
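The docstring example can be verified with plain arithmetic, no torch.distributed required. A minimal sketch mirroring the rank math above (the helper name is invented for illustration):

    def hierarchical_ranks(ep_group_ranks, nproc_per_node, world_size):
        # Same arithmetic as the loops above: local slot j on node i joins the
        # intra-node group if j appears in ep_group_ranks.
        num_node = world_size // nproc_per_node
        intra = [[i * nproc_per_node + j
                  for j in range(nproc_per_node)
                  if j in ep_group_ranks]
                 for i in range(num_node)]
        inter = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]
        return intra, inter

    intra, inter = hierarchical_ranks([1, 2, 5, 6], nproc_per_node=4, world_size=8)
    assert intra == [[1, 2], [5, 6]]
    assert inter == [1, 5]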
6 changes: 4 additions & 2 deletions colossalai/tensor/moe_tensor/api.py
@@ -1,3 +1,5 @@
+from typing import List
+
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -124,7 +126,7 @@ def get_dp_rank(tensor: torch.Tensor) -> int:
     return dist.get_rank(get_dp_group(tensor))
 
 
-def get_ep_group_ranks(tensor: torch.Tensor) -> int:
+def get_ep_group_ranks(tensor: torch.Tensor) -> List[int]:
     """
     Get the expert parallel group ranks of the given tensor.
@@ -137,7 +139,7 @@ def get_ep_group_ranks(tensor: torch.Tensor) -> int:
     return tensor.moe_info.ep_group_ranks
 
 
-def get_dp_group_ranks(tensor: torch.Tensor) -> int:
+def get_dp_group_ranks(tensor: torch.Tensor) -> List[int]:
     """
     Get the data parallel group ranks of the given tensor.
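These two hunks only correct the declared return type: the functions already returned rank lists, so -> int was simply wrong. Hypothetical usage with the fixed annotation (moe_tensor stands in for any tensor carrying moe_info):

    ranks = get_ep_group_ranks(moe_tensor)  # e.g. [1, 2, 5, 6], now typed List[int]
    src_rank = ranks[0]                     # first rank doubles as the broadcast root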
20 changes: 12 additions & 8 deletions tests/test_moe/test_moe_ep_tp.py
@@ -1,5 +1,6 @@
 import os
 import warnings
+from typing import Dict
 
 import pytest
 import torch
@@ -123,7 +124,7 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_
         local_param.data.copy_(all_param.data)
 
 
-def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int):
+def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, config: Dict):
     assert batch_size % world_size == 0
 
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
@@ -133,8 +134,9 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
     MOE_MANAGER.__init__()
     MOE_MANAGER.setup(parallel="EP")
-    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
-    enable_hierarchical_comm = torch.__version__ >= "1.13.1"
+    enable_hierarchical_comm = config.get("enable_hierarchical_comm", False)
+    if enable_hierarchical_comm:
+        os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
     ep_model = SparseMLP(
         num_experts=num_experts,
         hidden_size=dim,
@@ -161,7 +163,6 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     tp_grad_handler = MoeGradientHandler(tp_model)
 
     rank = dist.get_rank()
-    torch.cuda.manual_seed(seed)
     input_data = torch.randn(batch_size, dim, device=get_current_device())
     micro_batch_size = batch_size // world_size
     index = rank * micro_batch_size
@@ -218,11 +219,14 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
 @pytest.mark.parametrize("num_experts", [4, 64])
 @pytest.mark.parametrize("batch_size", [16])
 @pytest.mark.parametrize("dim", [64])
-@pytest.mark.parametrize("seed", [42, 127])
+@pytest.mark.parametrize("config", [
+    {"enable_hierarchical_comm": False},
+    {"enable_hierarchical_comm": True},
+])
 @rerun_if_address_is_in_use()
-def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int):
-    spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed)
+def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict):
+    spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config)
 
 
 if __name__ == '__main__':
-    test_moe_ep_tp(num_experts=8, batch_size=32, dim=32, seed=42)
+    test_moe_ep_tp(num_experts=8, batch_size=32, dim=32)
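With config parametrized instead of seed, the suite now exercises both the flat and hierarchical all-to-all paths, and LOCAL_WORLD_SIZE is only set when hierarchical comm is actually enabled. The parametrize stack expands to four cases, which can be checked offline:

    from itertools import product

    # Illustrative: mirrors the @pytest.mark.parametrize matrix above.
    cases = list(product([4, 64], [16], [64],
                         [{"enable_hierarchical_comm": False},
                          {"enable_hierarchical_comm": True}]))
    assert len(cases) == 4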
