-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
853 additions
and
972 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,78 +1,107 @@ | ||
import pytest | ||
import torch | ||
import torch.distributed as dist | ||
from torch.nn.parallel import DistributedDataParallel as DDP | ||
|
||
import colossalai | ||
from colossalai.booster import Booster | ||
from colossalai.booster.plugin import LowLevelZeroPlugin | ||
from colossalai.moe.manager import MOE_MANAGER | ||
from colossalai.tensor.moe_tensor.api import is_moe_tensor | ||
from colossalai.testing import rerun_if_address_is_in_use, spawn | ||
from colossalai.testing.random import seed_all | ||
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, run_fwd_bwd, sync_local_from_ep | ||
from colossalai.zero.low_level.low_level_optim import LowLevelZeroOptimizer | ||
from colossalai.zero.low_level.low_level_strategy import LowLevelOptStrategy, MoeZeroStrategy | ||
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, sync_local_from_ep | ||
|
||
|
||
def run_zero_test(local_rank, stage=1): | ||
def run_zero_test(local_rank): | ||
dp_size = world_size = dist.get_world_size() | ||
assert world_size >= 4, f"{world_size=}: at least 4 processes are required for this test (ep=2, moe_dp=2)" | ||
criterion = torch.nn.CrossEntropyLoss() | ||
|
||
ep_size = 2 | ||
extra_dp_size = world_size // ep_size | ||
|
||
MOE_MANAGER.__init__() | ||
MOE_MANAGER.setup(parallel="EP") | ||
moe_model = MoeModel().bfloat16() | ||
moe_optimizer = torch.optim.Adam(moe_model.parameters()) | ||
moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") | ||
moe_booster = Booster(plugin=moe_plugin) | ||
moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer) | ||
MOE_MANAGER.setup(parallel="EP", mode="fixed", fixed_dp_size=extra_dp_size, fixed_ep_size=ep_size, fixed_pp_size=1) | ||
|
||
zero_model = MoeModel().bfloat16().cuda() | ||
|
||
dp_group = dist.group.WORLD | ||
ep_group = MOE_MANAGER.parallel_info_dict[ep_size].ep_group | ||
moe_extra_dp_group = MOE_MANAGER.parallel_info_dict[ep_size].dp_group | ||
|
||
zero_params = list(filter(lambda x: not is_moe_tensor(x), zero_model.parameters())) | ||
moe_params = list(filter(lambda x: is_moe_tensor(x), zero_model.parameters())) | ||
print(f"{len(zero_params)=}, {len(moe_params)=}") | ||
lr = 1e-3 | ||
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=lr) | ||
zero_optimizer.param_groups.clear() | ||
zero_optimizer.add_param_group({"params": zero_params}) | ||
zero_optimizer.add_param_group({"params": moe_params}) | ||
|
||
strategies = [ | ||
LowLevelOptStrategy( | ||
param_group=zero_optimizer.param_groups[0], | ||
process_group=dp_group, | ||
overlap_communication=False, | ||
partition_grad=True, | ||
), | ||
MoeZeroStrategy( | ||
param_group=zero_optimizer.param_groups[1], | ||
process_group=moe_extra_dp_group, | ||
overlap_communication=True, | ||
partition_grad=False, | ||
), | ||
] | ||
zero_optimizer = LowLevelZeroOptimizer( | ||
zero_optimizer, | ||
strategies, | ||
) | ||
|
||
MOE_MANAGER.__init__() | ||
MOE_MANAGER.setup(parallel=None) | ||
zero_model = MoeModel().bfloat16() | ||
delete_moe_info(zero_model) | ||
zero_optimizer = torch.optim.Adam(zero_model.parameters()) | ||
zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") | ||
zero_booster = Booster(plugin=zero_plugin) | ||
zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer) | ||
sync_local_from_ep(zero_model, moe_model) | ||
ddp_model = DDP(MoeModel().bfloat16().cuda(), static_graph=True) | ||
delete_moe_info(ddp_model) | ||
torch_optim = torch.optim.SGD(ddp_model.parameters(), lr=lr) | ||
sync_local_from_ep(ddp_model, zero_model) | ||
|
||
seed_all(42 + local_rank) | ||
data = torch.randn(16, 4).bfloat16().cuda() | ||
label = torch.randint(0, 4, (16,)).cuda() | ||
|
||
zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) | ||
moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer) | ||
assert torch.allclose(zero_out, moe_out) | ||
ddp_model.train() | ||
zero_model.train() | ||
ddp_out = criterion(ddp_model(data), label).float() | ||
zero_out = criterion(zero_model(data), label).float() | ||
assert torch.allclose(ddp_out, zero_out) | ||
print(f"{local_rank=} {ddp_out.mean()=}") | ||
|
||
ddp_out.backward() | ||
zero_optimizer.backward(zero_out) | ||
|
||
for (moe_name, moe_param), (zero_name, zero_param) in zip( | ||
moe_model.module.named_parameters(), zero_model.module.named_parameters() | ||
for (zero_name, zero_param), (ddp_name, ddp_param) in zip( | ||
zero_model.named_parameters(), ddp_model.named_parameters() | ||
): | ||
assert moe_name == zero_name | ||
moe_grad_list = moe_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(moe_param)) | ||
zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param)) | ||
if hasattr(moe_param, "moe_info"): | ||
assert len(moe_grad_list) == 0 | ||
if stage == 1: | ||
zero_grad = zero_grad_list[local_rank].view(moe_param.grad.shape) | ||
else: | ||
zero_grad = zero_grad_list[0].view(moe_param.grad.shape) | ||
assert torch.allclose( | ||
moe_param.grad, zero_grad, atol=1e-5 | ||
), f"zero grad:\n{moe_param.grad}\ntorch grad:\n{zero_grad}\nmax diff: {(moe_param.grad - zero_grad).abs().max()}, mean diff: {(moe_param.grad - zero_grad).abs().mean()}" | ||
else: | ||
assert len(moe_grad_list) > 0 | ||
assert len(moe_grad_list) == len(zero_grad_list) | ||
for moe_grad, zero_grad in zip(moe_grad_list, zero_grad_list): | ||
assert torch.allclose(moe_grad, zero_grad) | ||
torch_grad = ddp_param.grad | ||
zero_grad = zero_optimizer.get_param_grad(zero_param) | ||
if is_moe_tensor(zero_param): | ||
moe_grad_list = [torch.empty_like(zero_grad) for _ in range(ep_size)] | ||
dist.all_gather(moe_grad_list, zero_grad, group=ep_group) | ||
zero_grad = torch.cat(moe_grad_list, dim=0) | ||
loose_close(torch_grad, zero_grad, dtype=torch_grad.dtype) | ||
|
||
|
||
def run_dist(rank, world_size, port, stage): | ||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") | ||
seed_all(42 + rank) | ||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") | ||
run_zero_test(rank, stage=stage) | ||
|
||
|
||
@pytest.mark.dist | ||
@pytest.mark.parametrize("world_size", [2]) | ||
@pytest.mark.parametrize("stage", [1, 2]) | ||
@pytest.mark.parametrize("world_size", [4]) | ||
@rerun_if_address_is_in_use() | ||
def test_moe_zero_model(world_size, stage): | ||
spawn(run_dist, world_size, stage=stage) | ||
def test_moe_zero_model(world_size): | ||
spawn(run_dist, world_size) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_moe_zero_model(world_size=2, stage=1) | ||
test_moe_zero_model(world_size=4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,83 +1,125 @@ | ||
import pytest | ||
import torch | ||
import torch.distributed as dist | ||
from torch.nn.parallel import DistributedDataParallel as DDP | ||
|
||
import colossalai | ||
from colossalai.booster import Booster | ||
from colossalai.booster.plugin import LowLevelZeroPlugin | ||
from colossalai.moe.manager import MOE_MANAGER | ||
from colossalai.tensor.moe_tensor.api import is_moe_tensor | ||
from colossalai.testing import rerun_if_address_is_in_use, spawn | ||
from colossalai.testing.random import seed_all | ||
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, run_fwd_bwd, sync_local_from_ep | ||
from colossalai.zero.low_level.low_level_optim import LowLevelZeroOptimizer | ||
from colossalai.zero.low_level.low_level_strategy import LowLevelOptStrategy, MoeZeroStrategy | ||
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, sync_local_from_ep | ||
|
||
|
||
def run_zero_test(local_rank, stage=1): | ||
def run_zero_test(local_rank): | ||
dp_size = world_size = dist.get_world_size() | ||
assert world_size >= 4, f"{world_size=}: at least 4 processes are required for this test (ep=2, moe_dp=2)" | ||
criterion = torch.nn.CrossEntropyLoss() | ||
|
||
ep_size = 2 | ||
extra_dp_size = world_size // ep_size | ||
|
||
MOE_MANAGER.__init__() | ||
MOE_MANAGER.setup(parallel="EP") | ||
moe_model = MoeModel().bfloat16() | ||
moe_optimizer = torch.optim.Adam(moe_model.parameters(), lr=1.0) | ||
moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") | ||
moe_booster = Booster(plugin=moe_plugin) | ||
moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer) | ||
MOE_MANAGER.setup(parallel="EP", mode="fixed", fixed_dp_size=extra_dp_size, fixed_ep_size=ep_size, fixed_pp_size=1) | ||
|
||
zero_model = MoeModel().bfloat16().cuda() | ||
|
||
dp_group = dist.group.WORLD | ||
ep_group = MOE_MANAGER.parallel_info_dict[ep_size].ep_group | ||
moe_extra_dp_group = MOE_MANAGER.parallel_info_dict[ep_size].dp_group | ||
|
||
zero_params = list(filter(lambda x: not is_moe_tensor(x), zero_model.parameters())) | ||
moe_params = list(filter(lambda x: is_moe_tensor(x), zero_model.parameters())) | ||
print(f"{len(zero_params)=}, {len(moe_params)=}") | ||
lr = 1e-3 | ||
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=lr) | ||
zero_optimizer.param_groups.clear() | ||
zero_optimizer.add_param_group({"params": zero_params}) | ||
zero_optimizer.add_param_group({"params": moe_params}) | ||
|
||
strategies = [ | ||
LowLevelOptStrategy( | ||
param_group=zero_optimizer.param_groups[0], | ||
process_group=dp_group, | ||
overlap_communication=False, | ||
partition_grad=True, | ||
), | ||
MoeZeroStrategy( | ||
param_group=zero_optimizer.param_groups[1], | ||
process_group=moe_extra_dp_group, | ||
overlap_communication=True, | ||
partition_grad=False, | ||
), | ||
] | ||
zero_optimizer = LowLevelZeroOptimizer( | ||
zero_optimizer, | ||
strategies, | ||
) | ||
|
||
MOE_MANAGER.__init__() | ||
MOE_MANAGER.setup(parallel=None) | ||
zero_model = MoeModel().bfloat16() | ||
delete_moe_info(zero_model) | ||
sync_local_from_ep(zero_model, moe_model) | ||
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1.0) | ||
zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") | ||
zero_booster = Booster(plugin=zero_plugin) | ||
zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer) | ||
|
||
for (moe_name, moe_param), (zero_name, zero_param) in zip( | ||
moe_model.named_parameters(), zero_model.named_parameters() | ||
ddp_model = DDP(MoeModel().bfloat16().cuda(), static_graph=True) | ||
delete_moe_info(ddp_model) | ||
torch_optim = torch.optim.SGD(ddp_model.parameters(), lr=lr) | ||
sync_local_from_ep(ddp_model, zero_model) | ||
|
||
seed_all(42 + local_rank) | ||
data = torch.randn(16, 4).bfloat16().cuda() | ||
label = torch.randint(0, 4, (16,)).cuda() | ||
|
||
ddp_model.train() | ||
zero_model.train() | ||
ddp_out = criterion(ddp_model(data), label).float() | ||
zero_out = criterion(zero_model(data), label).float() | ||
assert torch.allclose(ddp_out, zero_out) | ||
print(f"{local_rank=} {ddp_out.mean()=}") | ||
|
||
ddp_out.backward() | ||
zero_optimizer.backward(zero_out) | ||
|
||
for (zero_name, zero_param), (ddp_name, ddp_param) in zip( | ||
zero_model.named_parameters(), ddp_model.named_parameters() | ||
): | ||
torch_grad = ddp_param.grad | ||
zero_grad = zero_optimizer.get_param_grad(zero_param) | ||
if is_moe_tensor(zero_param): | ||
moe_grad_list = [torch.empty_like(zero_grad) for _ in range(ep_size)] | ||
dist.all_gather(moe_grad_list, zero_grad, group=ep_group) | ||
zero_grad = torch.cat(moe_grad_list, dim=0) | ||
loose_close(torch_grad, zero_grad, dtype=torch_grad.dtype) | ||
|
||
torch_optim.step() | ||
zero_optimizer.step() | ||
|
||
for (zero_name, zero_param), (ddp_name, ddp_param) in zip( | ||
zero_model.named_parameters(), ddp_model.named_parameters() | ||
): | ||
if ".experts." in moe_name: | ||
continue | ||
assert moe_name == zero_name | ||
assert torch.allclose( | ||
moe_param.data, zero_param.data | ||
), f"{moe_name}\ntorch_param {moe_param.data}\nzero_param {zero_param.data}" | ||
|
||
for _ in range(1): | ||
data = torch.randn(2, 4).bfloat16().cuda() | ||
label = torch.randint(0, 4, (2,)).cuda() | ||
|
||
moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer) | ||
zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) | ||
assert torch.allclose(zero_out, moe_out) | ||
moe_optimizer.step() | ||
zero_optimizer.step() | ||
|
||
for (moe_name, moe_param), (zero_name, zero_param) in zip( | ||
moe_model.named_parameters(), zero_model.named_parameters() | ||
): | ||
assert moe_name == zero_name | ||
if is_moe_tensor(moe_param): | ||
param_size = moe_param.shape[0] | ||
zero_param = zero_param[local_rank * param_size : (local_rank + 1) * param_size] | ||
loose_close(moe_param.data, zero_param.data, dtype=moe_param.dtype) | ||
|
||
moe_optimizer.zero_grad() | ||
zero_optimizer.zero_grad() | ||
if is_moe_tensor(zero_param): | ||
moe_param_list = [torch.empty_like(zero_param) for _ in range(ep_size)] | ||
dist.all_gather(moe_param_list, zero_param, group=ep_group) | ||
zero_param = torch.cat(moe_param_list, dim=0) | ||
assert ddp_param.dtype == zero_param.dtype | ||
ddp_param.numel() // dp_size | ||
loose_close( | ||
ddp_param, | ||
zero_param, | ||
dtype=ddp_param.dtype, | ||
) | ||
|
||
|
||
def run_dist(rank, world_size, port, stage): | ||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") | ||
seed_all(42 + rank) | ||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") | ||
run_zero_test(rank, stage=stage) | ||
|
||
|
||
@pytest.mark.dist | ||
@pytest.mark.parametrize("world_size", [2]) | ||
@pytest.mark.parametrize("stage", [1, 2]) | ||
@pytest.mark.parametrize("world_size", [4]) | ||
@rerun_if_address_is_in_use() | ||
def test_moe_zero_optim(world_size, stage): | ||
spawn(run_dist, world_size, stage=stage) | ||
def test_moe_zero_model(world_size): | ||
spawn(run_dist, world_size) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_moe_zero_optim(world_size=2, stage=1) | ||
test_moe_zero_model(world_size=4) |
Oops, something went wrong.