Skip to content

Commit

Permalink
[zero] refactor low level optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
botbw committed Jun 11, 2024
1 parent df6826d commit 480c8ea
Show file tree
Hide file tree
Showing 9 changed files with 853 additions and 972 deletions.
3 changes: 2 additions & 1 deletion colossalai/booster/plugin/low_level_zero_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from colossalai.interface.optimizer import DistributedOptim
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero import LowLevelZeroOptimizer

from .dp_plugin_base import DPPluginBase
Expand Down Expand Up @@ -448,7 +449,7 @@ def configure(

if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
optimizer, **zero_optim_kwargs, verbose=self.verbose
optimizer, **self.zero_optim_kwargs, verbose=self.verbose
)
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
Expand Down
17 changes: 17 additions & 0 deletions colossalai/zero/low_level/bookkeeping/gradient_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,23 @@ def get_working_grad_by_param_id(self, param_id) -> Tensor:

raise KeyError(f"Working gradient for param_id {param_id} not found.")

def get_master_grad_by_param_id(self, param_id) -> List[Tensor]:
"""
Return the working gradient for the specified parameter.
Args:
param_id (int): The index of the parameter.
Returns:
List[Tensor]: The the working gradient slices for the specified param_id.
"""

for group in self._grads_of_params.values():
if param_id in group.keys():
return group[param_id]

raise KeyError(f"Working gradient for param_id {param_id} not found.")

def reset_grads_by_group_id(self, group_id: int):
self._grads_of_params[group_id] = dict()

Expand Down
972 changes: 114 additions & 858 deletions colossalai/zero/low_level/low_level_optim.py

Large diffs are not rendered by default.

533 changes: 533 additions & 0 deletions colossalai/zero/low_level/low_level_strategy.py

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion tests/test_moe/moe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_
for (local_name, local_param), (ep_name, ep_param) in zip(
local_model.named_parameters(), ep_model.named_parameters()
):
assert local_name in ep_name, print(f"{local_name} != {ep_name}")
if "experts" not in local_name:
if assert_grad_flag:
assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"
Expand Down
1 change: 1 addition & 0 deletions tests/test_moe/test_moe_load_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def run_dist(rank, world_size, port):
run_hybrid_zero_optim_test(rank, world_size, stage=2)


@pytest.mark.skip(reason="moe need to be refactored")
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
Expand Down
121 changes: 75 additions & 46 deletions tests/test_moe/test_moe_zero_fwd_bwd.py
Original file line number Diff line number Diff line change
@@ -1,78 +1,107 @@
import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, run_fwd_bwd, sync_local_from_ep
from colossalai.zero.low_level.low_level_optim import LowLevelZeroOptimizer
from colossalai.zero.low_level.low_level_strategy import LowLevelOptStrategy, MoeZeroStrategy
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, sync_local_from_ep


def run_zero_test(local_rank, stage=1):
def run_zero_test(local_rank):
dp_size = world_size = dist.get_world_size()
assert world_size >= 4, f"{world_size=}: at least 4 processes are required for this test (ep=2, moe_dp=2)"
criterion = torch.nn.CrossEntropyLoss()

ep_size = 2
extra_dp_size = world_size // ep_size

MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel="EP")
moe_model = MoeModel().bfloat16()
moe_optimizer = torch.optim.Adam(moe_model.parameters())
moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
moe_booster = Booster(plugin=moe_plugin)
moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)
MOE_MANAGER.setup(parallel="EP", mode="fixed", fixed_dp_size=extra_dp_size, fixed_ep_size=ep_size, fixed_pp_size=1)

zero_model = MoeModel().bfloat16().cuda()

dp_group = dist.group.WORLD
ep_group = MOE_MANAGER.parallel_info_dict[ep_size].ep_group
moe_extra_dp_group = MOE_MANAGER.parallel_info_dict[ep_size].dp_group

zero_params = list(filter(lambda x: not is_moe_tensor(x), zero_model.parameters()))
moe_params = list(filter(lambda x: is_moe_tensor(x), zero_model.parameters()))
print(f"{len(zero_params)=}, {len(moe_params)=}")
lr = 1e-3
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=lr)
zero_optimizer.param_groups.clear()
zero_optimizer.add_param_group({"params": zero_params})
zero_optimizer.add_param_group({"params": moe_params})

strategies = [
LowLevelOptStrategy(
param_group=zero_optimizer.param_groups[0],
process_group=dp_group,
overlap_communication=False,
partition_grad=True,
),
MoeZeroStrategy(
param_group=zero_optimizer.param_groups[1],
process_group=moe_extra_dp_group,
overlap_communication=True,
partition_grad=False,
),
]
zero_optimizer = LowLevelZeroOptimizer(
zero_optimizer,
strategies,
)

MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel=None)
zero_model = MoeModel().bfloat16()
delete_moe_info(zero_model)
zero_optimizer = torch.optim.Adam(zero_model.parameters())
zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
zero_booster = Booster(plugin=zero_plugin)
zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
sync_local_from_ep(zero_model, moe_model)
ddp_model = DDP(MoeModel().bfloat16().cuda(), static_graph=True)
delete_moe_info(ddp_model)
torch_optim = torch.optim.SGD(ddp_model.parameters(), lr=lr)
sync_local_from_ep(ddp_model, zero_model)

seed_all(42 + local_rank)
data = torch.randn(16, 4).bfloat16().cuda()
label = torch.randint(0, 4, (16,)).cuda()

zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
assert torch.allclose(zero_out, moe_out)
ddp_model.train()
zero_model.train()
ddp_out = criterion(ddp_model(data), label).float()
zero_out = criterion(zero_model(data), label).float()
assert torch.allclose(ddp_out, zero_out)
print(f"{local_rank=} {ddp_out.mean()=}")

ddp_out.backward()
zero_optimizer.backward(zero_out)

for (moe_name, moe_param), (zero_name, zero_param) in zip(
moe_model.module.named_parameters(), zero_model.module.named_parameters()
for (zero_name, zero_param), (ddp_name, ddp_param) in zip(
zero_model.named_parameters(), ddp_model.named_parameters()
):
assert moe_name == zero_name
moe_grad_list = moe_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(moe_param))
zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
if hasattr(moe_param, "moe_info"):
assert len(moe_grad_list) == 0
if stage == 1:
zero_grad = zero_grad_list[local_rank].view(moe_param.grad.shape)
else:
zero_grad = zero_grad_list[0].view(moe_param.grad.shape)
assert torch.allclose(
moe_param.grad, zero_grad, atol=1e-5
), f"zero grad:\n{moe_param.grad}\ntorch grad:\n{zero_grad}\nmax diff: {(moe_param.grad - zero_grad).abs().max()}, mean diff: {(moe_param.grad - zero_grad).abs().mean()}"
else:
assert len(moe_grad_list) > 0
assert len(moe_grad_list) == len(zero_grad_list)
for moe_grad, zero_grad in zip(moe_grad_list, zero_grad_list):
assert torch.allclose(moe_grad, zero_grad)
torch_grad = ddp_param.grad
zero_grad = zero_optimizer.get_param_grad(zero_param)
if is_moe_tensor(zero_param):
moe_grad_list = [torch.empty_like(zero_grad) for _ in range(ep_size)]
dist.all_gather(moe_grad_list, zero_grad, group=ep_group)
zero_grad = torch.cat(moe_grad_list, dim=0)
loose_close(torch_grad, zero_grad, dtype=torch_grad.dtype)


def run_dist(rank, world_size, port, stage):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
seed_all(42 + rank)
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_zero_test(rank, stage=stage)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("stage", [1, 2])
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_moe_zero_model(world_size, stage):
spawn(run_dist, world_size, stage=stage)
def test_moe_zero_model(world_size):
spawn(run_dist, world_size)


if __name__ == "__main__":
test_moe_zero_model(world_size=2, stage=1)
test_moe_zero_model(world_size=4)
152 changes: 97 additions & 55 deletions tests/test_moe/test_moe_zero_optim.py
Original file line number Diff line number Diff line change
@@ -1,83 +1,125 @@
import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, run_fwd_bwd, sync_local_from_ep
from colossalai.zero.low_level.low_level_optim import LowLevelZeroOptimizer
from colossalai.zero.low_level.low_level_strategy import LowLevelOptStrategy, MoeZeroStrategy
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, sync_local_from_ep


def run_zero_test(local_rank, stage=1):
def run_zero_test(local_rank):
dp_size = world_size = dist.get_world_size()
assert world_size >= 4, f"{world_size=}: at least 4 processes are required for this test (ep=2, moe_dp=2)"
criterion = torch.nn.CrossEntropyLoss()

ep_size = 2
extra_dp_size = world_size // ep_size

MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel="EP")
moe_model = MoeModel().bfloat16()
moe_optimizer = torch.optim.Adam(moe_model.parameters(), lr=1.0)
moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
moe_booster = Booster(plugin=moe_plugin)
moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)
MOE_MANAGER.setup(parallel="EP", mode="fixed", fixed_dp_size=extra_dp_size, fixed_ep_size=ep_size, fixed_pp_size=1)

zero_model = MoeModel().bfloat16().cuda()

dp_group = dist.group.WORLD
ep_group = MOE_MANAGER.parallel_info_dict[ep_size].ep_group
moe_extra_dp_group = MOE_MANAGER.parallel_info_dict[ep_size].dp_group

zero_params = list(filter(lambda x: not is_moe_tensor(x), zero_model.parameters()))
moe_params = list(filter(lambda x: is_moe_tensor(x), zero_model.parameters()))
print(f"{len(zero_params)=}, {len(moe_params)=}")
lr = 1e-3
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=lr)
zero_optimizer.param_groups.clear()
zero_optimizer.add_param_group({"params": zero_params})
zero_optimizer.add_param_group({"params": moe_params})

strategies = [
LowLevelOptStrategy(
param_group=zero_optimizer.param_groups[0],
process_group=dp_group,
overlap_communication=False,
partition_grad=True,
),
MoeZeroStrategy(
param_group=zero_optimizer.param_groups[1],
process_group=moe_extra_dp_group,
overlap_communication=True,
partition_grad=False,
),
]
zero_optimizer = LowLevelZeroOptimizer(
zero_optimizer,
strategies,
)

MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel=None)
zero_model = MoeModel().bfloat16()
delete_moe_info(zero_model)
sync_local_from_ep(zero_model, moe_model)
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1.0)
zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
zero_booster = Booster(plugin=zero_plugin)
zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)

for (moe_name, moe_param), (zero_name, zero_param) in zip(
moe_model.named_parameters(), zero_model.named_parameters()
ddp_model = DDP(MoeModel().bfloat16().cuda(), static_graph=True)
delete_moe_info(ddp_model)
torch_optim = torch.optim.SGD(ddp_model.parameters(), lr=lr)
sync_local_from_ep(ddp_model, zero_model)

seed_all(42 + local_rank)
data = torch.randn(16, 4).bfloat16().cuda()
label = torch.randint(0, 4, (16,)).cuda()

ddp_model.train()
zero_model.train()
ddp_out = criterion(ddp_model(data), label).float()
zero_out = criterion(zero_model(data), label).float()
assert torch.allclose(ddp_out, zero_out)
print(f"{local_rank=} {ddp_out.mean()=}")

ddp_out.backward()
zero_optimizer.backward(zero_out)

for (zero_name, zero_param), (ddp_name, ddp_param) in zip(
zero_model.named_parameters(), ddp_model.named_parameters()
):
torch_grad = ddp_param.grad
zero_grad = zero_optimizer.get_param_grad(zero_param)
if is_moe_tensor(zero_param):
moe_grad_list = [torch.empty_like(zero_grad) for _ in range(ep_size)]
dist.all_gather(moe_grad_list, zero_grad, group=ep_group)
zero_grad = torch.cat(moe_grad_list, dim=0)
loose_close(torch_grad, zero_grad, dtype=torch_grad.dtype)

torch_optim.step()
zero_optimizer.step()

for (zero_name, zero_param), (ddp_name, ddp_param) in zip(
zero_model.named_parameters(), ddp_model.named_parameters()
):
if ".experts." in moe_name:
continue
assert moe_name == zero_name
assert torch.allclose(
moe_param.data, zero_param.data
), f"{moe_name}\ntorch_param {moe_param.data}\nzero_param {zero_param.data}"

for _ in range(1):
data = torch.randn(2, 4).bfloat16().cuda()
label = torch.randint(0, 4, (2,)).cuda()

moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
assert torch.allclose(zero_out, moe_out)
moe_optimizer.step()
zero_optimizer.step()

for (moe_name, moe_param), (zero_name, zero_param) in zip(
moe_model.named_parameters(), zero_model.named_parameters()
):
assert moe_name == zero_name
if is_moe_tensor(moe_param):
param_size = moe_param.shape[0]
zero_param = zero_param[local_rank * param_size : (local_rank + 1) * param_size]
loose_close(moe_param.data, zero_param.data, dtype=moe_param.dtype)

moe_optimizer.zero_grad()
zero_optimizer.zero_grad()
if is_moe_tensor(zero_param):
moe_param_list = [torch.empty_like(zero_param) for _ in range(ep_size)]
dist.all_gather(moe_param_list, zero_param, group=ep_group)
zero_param = torch.cat(moe_param_list, dim=0)
assert ddp_param.dtype == zero_param.dtype
ddp_param.numel() // dp_size
loose_close(
ddp_param,
zero_param,
dtype=ddp_param.dtype,
)


def run_dist(rank, world_size, port, stage):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
seed_all(42 + rank)
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_zero_test(rank, stage=stage)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("stage", [1, 2])
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_moe_zero_optim(world_size, stage):
spawn(run_dist, world_size, stage=stage)
def test_moe_zero_model(world_size):
spawn(run_dist, world_size)


if __name__ == "__main__":
test_moe_zero_optim(world_size=2, stage=1)
test_moe_zero_model(world_size=4)
Loading

0 comments on commit 480c8ea

Please sign in to comment.