[hotfix]: modify create_ep_hierarchical_group and add test #5032

Merged
merged 4 commits on Nov 17, 2023
10 changes: 6 additions & 4 deletions colossalai/moe/_operation.py
@@ -150,7 +150,8 @@ class HierarchicalAllToAll(torch.autograd.Function):
     def forward(
         ctx: Any,
         inputs: Tensor,
-        groups: Tuple[ProcessGroup],
+        groups: Tuple[ProcessGroup, ProcessGroup],
+        src_rank: int
     ) -> Tensor:
         """
         Returns:
@@ -159,12 +160,12 @@ def forward(
         # TODO: we can reduce comm volume by removing empty capacity
         if ctx is not None:
             ctx.comm_grps = groups
+            ctx.src_rank = src_rank
         intra_node_group, inter_node_group = groups

         local_world_size = dist.get_world_size(intra_node_group)
         num_group = dist.get_world_size(inter_node_group) if inter_node_group is not None else 1
         world_size = local_world_size * num_group
-        src_rank = dist.get_process_group_ranks(intra_node_group)[0]
         outputs = torch.empty_like(inputs)

         if dist.get_rank() == src_rank:
@@ -196,9 +197,10 @@ def forward(
         return outputs

     @staticmethod
-    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]:
+    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
         return (
-            HierarchicalAllToAll.forward(None, grad_outputs[0], ctx.comm_grps),
+            HierarchicalAllToAll.forward(None, grad_outputs[0], ctx.comm_grps, ctx.src_rank),
             None,
+            None,
         )
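For reference, this is how the new signature is consumed once the groups exist: the intra-node source rank is passed in explicitly and stashed on ctx, so backward can replay the same collective instead of recomputing the rank inside forward. The helper name hierarchical_a2a and the assumption of an already-launched torchrun job are illustrative; only the imports and the call signature come from this PR.

    import torch

    from colossalai.moe._operation import HierarchicalAllToAll
    from colossalai.moe.utils import create_ep_hierarchical_group
    from colossalai.tensor.moe_tensor.api import get_ep_group_ranks

    def hierarchical_a2a(dispatch_data: torch.Tensor, experts) -> torch.Tensor:
        # create_ep_hierarchical_group now takes the EP rank list and returns the
        # intra-node source rank alongside the intra-/inter-node process groups.
        src_rank, intra_group, inter_group = create_ep_hierarchical_group(get_ep_group_ranks(experts))
        # src_rank is forwarded explicitly; backward reuses it via ctx.src_rank.
        return HierarchicalAllToAll.apply(dispatch_data, (intra_group, inter_group), src_rank)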
13 changes: 8 additions & 5 deletions colossalai/moe/layers.py
@@ -13,7 +13,7 @@
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.routers import MoeRouter, get_router_cls
 from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
-from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size
+from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size


 class SparseMLP(nn.Module):
@@ -105,8 +105,11 @@ def __init__(
         if self.expert_parallel is not None:
             self.ep_group = get_ep_group(self.experts)
             self.ep_size = get_ep_size(self.experts)
-            self.ep_hierarchical_group = create_ep_hierarchical_group(
-                self.ep_group) if enable_hierarchical_comm else None
+            self.ep_hierarchical_group = None
+            if enable_hierarchical_comm:
+                self.ep_intra_src_rank, *self.ep_hierarchical_group = create_ep_hierarchical_group(
+                    get_ep_group_ranks(self.experts)
+                )
             self.dp_group = get_dp_group(self.experts)
         else:
             self.ep_group = None
@@ -225,10 +228,10 @@ def _ep_process(
         """
         if not overlap or dist.get_world_size(self.ep_group) == 1:
             if self.ep_hierarchical_group is not None:
-                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group)
+                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
                 expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
                 expert_output = self.experts(expert_input)
-                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group)
+                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
                 return expert_output
             else:
                 expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
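A small point that is easy to miss in the __init__ change: because create_ep_hierarchical_group now returns a 3-tuple, the star-unpacking leaves self.ep_hierarchical_group as a two-element list holding the intra- and inter-node groups, which is exactly the groups argument HierarchicalAllToAll.apply expects. A self-contained illustration of the unpacking with placeholder values (no distributed setup involved):

    # Placeholders stand in for (intra_src_rank, intra_node_group, inter_node_group).
    ret = (3, "intra_pg", "inter_pg")
    ep_intra_src_rank, *ep_hierarchical_group = ret
    assert ep_intra_src_rank == 3
    assert ep_hierarchical_group == ["intra_pg", "inter_pg"]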
17 changes: 8 additions & 9 deletions colossalai/moe/utils.py
@@ -179,15 +179,15 @@ def set_moe_args(config: Any, args: dict):


 def create_ep_hierarchical_group(
-    ep_group: dist.ProcessGroup,
+    ep_group_ranks: List[int],
     nproc_per_node: Optional[int] = None,
-) -> Tuple[Optional[dist.ProcessGroup],
-           Optional[dist.ProcessGroup]]:
+) -> Tuple[int, dist.ProcessGroup, Optional[dist.ProcessGroup]]:
     """
     e.g., If ep_group = [1, 2, 5, 6], and nproc_per_node = 4
     Then, ep_intra_group = [1, 2] & [5, 6], ep_inter_group = [1, 5] & None
     """
     assert dist.is_initialized(), "Please initialize torch.distributed first."
+    rank = dist.get_rank()
     if nproc_per_node is None:
         nproc_per_node = os.environ.get("LOCAL_WORLD_SIZE")
         assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
@@ -197,29 +197,28 @@ def create_ep_hierarchical_group(
         "nproc_per_node should be a divisor of world_size."
     num_node = dist.get_world_size() // nproc_per_node

-    rank = dist.get_rank()
-    ep_ranks = dist.get_process_group_ranks(ep_group)
-
+    intra_src_rank = None
     ep_intra_node_group = None
     for i in range(num_node):
         ep_intra_ranks = [
             i * nproc_per_node + j
             for j in range(nproc_per_node)
-            if j in ep_ranks
+            if j in ep_group_ranks
         ]
         group = dist.new_group(ep_intra_ranks)
         if rank in ep_intra_ranks:
             assert ep_intra_node_group is None
             ep_intra_node_group = group
+            intra_src_rank = ep_intra_ranks[0]

     ep_inter_node_group = None
     ep_inter_ranks = [
-        ep_ranks[0] + i * nproc_per_node
+        ep_group_ranks[0] + i * nproc_per_node
         for i in range(num_node)
     ]
     if len(ep_inter_ranks) > 1:
         group = dist.new_group(ep_inter_ranks)
         if rank in ep_inter_ranks:
             ep_inter_node_group = group

-    return ep_intra_node_group, ep_inter_node_group
+    return intra_src_rank, ep_intra_node_group, ep_inter_node_group
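To make the docstring example concrete, here is the rank arithmetic the rewritten helper performs, with the torch.distributed calls stripped out; world_size = 8 is an assumed value implied by two nodes of four processes each.

    ep_group_ranks = [1, 2, 5, 6]
    nproc_per_node = 4
    world_size = 8
    num_node = world_size // nproc_per_node

    # Intra-node groups: local slot j is kept if j appears in the EP rank list.
    ep_intra_groups = [
        [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
        for i in range(num_node)
    ]
    # Inter-node group: the first EP rank of each node.
    ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]

    print(ep_intra_groups)  # [[1, 2], [5, 6]]
    print(ep_inter_ranks)   # [1, 5]
    # Each participating rank also records intra_src_rank, the first rank of its own
    # intra-node group: 1 on the first node, 5 on the second.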
6 changes: 4 additions & 2 deletions colossalai/tensor/moe_tensor/api.py
@@ -1,3 +1,5 @@
+from typing import List
+
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -124,7 +126,7 @@ def get_dp_rank(tensor: torch.Tensor) -> int:
     return dist.get_rank(get_dp_group(tensor))


-def get_ep_group_ranks(tensor: torch.Tensor) -> int:
+def get_ep_group_ranks(tensor: torch.Tensor) -> List[int]:
     """
     Get the expert parallel group ranks of the given tensor.

@@ -137,7 +139,7 @@ def get_ep_group_ranks(tensor: torch.Tensor) -> int:
     return tensor.moe_info.ep_group_ranks


-def get_dp_group_ranks(tensor: torch.Tensor) -> int:
+def get_dp_group_ranks(tensor: torch.Tensor) -> List[int]:
     """
     Get the data parallel group ranks of the given tensor.

20 changes: 12 additions & 8 deletions tests/test_moe/test_moe_ep_tp.py
@@ -1,5 +1,6 @@
 import os
 import warnings
+from typing import Dict

 import pytest
 import torch
@@ -123,7 +124,7 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_
             local_param.data.copy_(all_param.data)


-def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int):
+def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, config: Dict):
     assert batch_size % world_size == 0

     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
@@ -133,8 +134,9 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
     MOE_MANAGER.__init__()
     MOE_MANAGER.setup(parallel="EP")
-    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
-    enable_hierarchical_comm = torch.__version__ >= "1.13.1"
+    enable_hierarchical_comm = config.get("enable_hierarchical_comm", False)
+    if enable_hierarchical_comm:
+        os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
     ep_model = SparseMLP(
         num_experts=num_experts,
         hidden_size=dim,
@@ -161,7 +163,6 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     tp_grad_handler = MoeGradientHandler(tp_model)

     rank = dist.get_rank()
-    torch.cuda.manual_seed(seed)
     input_data = torch.randn(batch_size, dim, device=get_current_device())
     micro_batch_size = batch_size // world_size
     index = rank * micro_batch_size
@@ -218,11 +219,14 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
 @pytest.mark.parametrize("num_experts", [4, 64])
 @pytest.mark.parametrize("batch_size", [16])
 @pytest.mark.parametrize("dim", [64])
-@pytest.mark.parametrize("seed", [42, 127])
+@pytest.mark.parametrize("config", [
+    {"enable_hierarchical_comm": False},
+    {"enable_hierarchical_comm": True},
+])
 @rerun_if_address_is_in_use()
-def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int):
-    spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed)
+def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict):
+    spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config)


 if __name__ == '__main__':
-    test_moe_ep_tp(num_experts=8, batch_size=32, dim=32, seed=42)
+    test_moe_ep_tp(num_experts=8, batch_size=32, dim=32)
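The test now drives hierarchical communication through the parametrized config rather than the installed torch version, and only sets LOCAL_WORLD_SIZE on the hierarchical path, since create_ep_hierarchical_group falls back to that variable when nproc_per_node is not supplied. A minimal sketch of that gating in isolation; the function name and return value are illustrative, not part of the test.

    import os

    def configure_hierarchical_comm(config: dict, world_size: int) -> bool:
        enable = config.get("enable_hierarchical_comm", False)
        if enable:
            # torchrun sets LOCAL_WORLD_SIZE automatically; a spawn-based test must do it by hand.
            os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
        return enable

    assert configure_hierarchical_comm({"enable_hierarchical_comm": True}, 2) is True
    assert configure_hierarchical_comm({"enable_hierarchical_comm": False}, 2) is False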