[hotfix] Solve the compatibility issue of zero refactor (#5823)

* [moe refactor] update the unit tests for the refactored ZeRO and remove the obsolete test

* move the MoE checkpoint into the checkpoint folder and make the global axis a class member

* update the MoE hybrid parallel plugin to the newest version of ZeRO and fix the ZeRO working/master params bug

* fix the ZeRO unit test

* Add an assertion to prevent users from using it incorrectly

* Modify function parameter names to resolve compatibility issues
Hz188 authored Jun 17, 2024
1 parent ba0115a commit a10802e
Showing 9 changed files with 29 additions and 33 deletions.
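The compatibility-critical change for downstream code is the keyword rename in the low-level ZeRO strategies below: process_group becomes dp_process_group. A minimal caller-side sketch, assuming an already-launched torch.distributed run, a plain optimizer, a data-parallel group dp_group (all placeholders), and that the class is importable from the module path shown in the diff:

from colossalai.zero.low_level.low_level_strategy import LowLevelOptStrategy

# Before this commit (keyword no longer matches the refactored signature):
# strategy = LowLevelOptStrategy(param_group=optimizer.param_groups[0], process_group=dp_group)

# After this commit:
strategy = LowLevelOptStrategy(
    param_group=optimizer.param_groups[0],  # one optimizer param group per strategy
    dp_process_group=dp_group,              # renamed keyword: the data-parallel group used for ZeRO collectives
    overlap_communication=False,
    partition_grad=False,                   # False -> ZeRO stage 1, True -> stage 2
)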
2 changes: 1 addition & 1 deletion .github/workflows/build_on_schedule.yml
@@ -13,7 +13,7 @@ jobs:
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
- options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
+ options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 90
steps:
- name: Check GPU Availability # ensure all GPUs have enough memory
2 changes: 1 addition & 1 deletion .github/workflows/compatiblity_test_on_dispatch.yml
@@ -50,7 +50,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
- options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
+ options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 120
steps:
- name: Install dependencies
2 changes: 1 addition & 1 deletion .github/workflows/compatiblity_test_on_pr.yml
@@ -41,7 +41,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
- options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
+ options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 120
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-test-${{ matrix.container }}
2 changes: 1 addition & 1 deletion .github/workflows/compatiblity_test_on_schedule.yml
@@ -38,7 +38,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
- options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
+ options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 120
steps:
- name: Install dependencies
2 changes: 0 additions & 2 deletions applications/ColossalMoE/infer.py
@@ -9,7 +9,6 @@
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
- from colossalai.moe.checkpoint import MoECheckpointIO


def parse_args():
@@ -69,7 +68,6 @@ def main():
ep_size=ep_size,
zero_stage=1,
precision=args.precision,
- checkpoint_io=MoECheckpointIO,
enable_fused_normalization=args.use_layernorm_kernel,
enable_jit_fused=args.use_kernel,
)
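With MoECheckpointIO moved out of colossalai.moe, the example scripts stop importing it and stop passing checkpoint_io to the plugin. A hedged sketch of the updated construction, assuming the plugin now selects its MoE-aware checkpoint IO internally, and with illustrative parallel sizes (tp_size/pp_size are assumptions, not taken from the original script):

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

colossalai.launch_from_torch()  # launched via torchrun; config-free launch assumed for this version

plugin = MoeHybridParallelPlugin(
    tp_size=1,        # assumption: no tensor parallelism
    pp_size=1,        # assumption: no pipeline parallelism
    ep_size=2,        # expert-parallel size, as in the script's visible arguments
    zero_stage=1,
    precision="bf16",
    # checkpoint_io=MoECheckpointIO  <- no longer passed after this commit
)
booster = Booster(plugin=plugin)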
2 changes: 0 additions & 2 deletions applications/ColossalMoE/train.py
@@ -12,7 +12,6 @@
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
- from colossalai.moe.checkpoint import MoECheckpointIO
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
@@ -158,7 +157,6 @@ def main():
enable_jit_fused=args.use_kernel,
precision=args.precision,
zero_stage=args.zero_stage,
- checkpoint_io=MoECheckpointIO,
)

else:
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/low_level_zero_plugin.py
@@ -448,7 +448,7 @@ def configure(

if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
- optimizer, **self.zero_optim_kwargs, verbose=self.verbose
+ optimizer, **zero_optim_kwargs, verbose=self.verbose
)
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
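The one-keyword fix above sits inside LowLevelZeroPlugin.configure(), which users normally reach through Booster.boost() rather than calling directly. A minimal sketch of that path, with a toy model and optimizer as placeholders and a torchrun launch assumed:

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch()  # distributed init; config-free launch assumed

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

plugin = LowLevelZeroPlugin(stage=1, precision="fp16")
booster = Booster(plugin=plugin)

# boost() calls plugin.configure(), which wraps the raw optimizer in
# LowLevelZeroOptimizer using the plugin's zero optimizer kwargs -- the line fixed above.
model, optimizer, *_ = booster.boost(model, optimizer)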
44 changes: 22 additions & 22 deletions colossalai/zero/low_level/low_level_strategy.py
@@ -34,7 +34,7 @@ class LowLevelOptStrategyBase(ABC):
def __init__(
self,
param_group,
- process_group,
+ dp_process_group,
master_weights,
partition_grad,
cpu_offload,
@@ -46,14 +46,14 @@ def __init__(
self.param_group = param_group
self._dtype = self.param_group["params"][0].dtype

- if process_group is None: # if process_group is none, convert to default explicitly
- process_group = dist.group.WORLD
+ if dp_process_group is None: # if dp_process_group is none, convert to default explicitly
+ dp_process_group = dist.group.WORLD

- self.process_group = process_group
+ self.dp_process_group = dp_process_group

- # if process_group is none, will use the default one
- self._local_rank = dist.get_rank(group=self.process_group)
- self._world_size = dist.get_world_size(group=self.process_group)
+ # if dp_process_group is none, will use the default one
+ self._local_rank = dist.get_rank(group=self.dp_process_group)
+ self._world_size = dist.get_world_size(group=self.dp_process_group)

# master weights copy
self._master_weights = master_weights
@@ -65,9 +65,9 @@ def __init__(

# ParameterStore will manage the tensor buffers used for zero
# it will not manage the tensors used by mixed precision training
- self._param_store = ParameterStore(process_group)
- self._grad_store = GradientStore(process_group, partition_grad=partition_grad)
- self._bucket_store = BucketStore(process_group, reduce_bucket_size=reduce_bucket_size)
+ self._param_store = ParameterStore(dp_process_group)
+ self._grad_store = GradientStore(dp_process_group, partition_grad=partition_grad)
+ self._bucket_store = BucketStore(dp_process_group, reduce_bucket_size=reduce_bucket_size)

# working and master params for mixed precision training
group_params = []
@@ -224,7 +224,7 @@ def _run_reduction(self):
flat_grads = flat_grads.to(self._communication_dtype)

if not self._partition_grad:
- dist.all_reduce(flat_grads, group=self.process_group)
+ dist.all_reduce(flat_grads, group=self.dp_process_group)
if flat_grads.dtype != grad_dtype:
flat_grads = flat_grads.to(grad_dtype)

@@ -234,7 +234,7 @@
else:
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
recieved_grad = torch.zeros_like(flat_grads_list[0])
- dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.process_group)
+ dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_process_group)

if recieved_grad.dtype != grad_dtype:
recieved_grad = recieved_grad.to(grad_dtype)
@@ -294,7 +294,7 @@ def state_dict(self, optim: torch.optim.Optimizer) -> Dict:
gather_tensor = [
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
]
- dist.all_gather(gather_tensor, v, group=self.process_group)
+ dist.all_gather(gather_tensor, v, group=self.dp_process_group)
param_state = (
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
@@ -328,7 +328,7 @@ def get_grad_norm(self, norm_type: int = 2) -> float:
total_norm_cuda = torch.tensor(
[float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
)
- dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.process_group)
+ dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_process_group)
total_norm = total_norm_cuda.item()

else:
@@ -342,7 +342,7 @@ def get_grad_norm(self, norm_type: int = 2) -> float:
[float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
)
torch.distributed.all_reduce(
- total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.process_group
+ total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_process_group
)
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)

@@ -381,7 +381,7 @@ def get_param_grad(self, param):
return None
if self._partition_grad:
tensor_list = [torch.empty_like(grad_maybe_partial[0]) for _ in range(self._world_size)]
- dist.all_gather(tensor_list, grad_maybe_partial[0], group=self.process_group)
+ dist.all_gather(tensor_list, grad_maybe_partial[0], group=self.dp_process_group)
grad_flat = torch.cat(tensor_list, dim=0)
else:
grad_flat = torch.cat(grad_maybe_partial, dim=0)
@@ -420,7 +420,7 @@ class LowLevelOptStrategy(LowLevelOptStrategyBase):
def __init__(
self,
param_group: Dict[str, Any], # from optimizer.param_groups
- process_group: Optional[ProcessGroup] = None, # the dp pg for comm
+ dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = False,
@@ -430,7 +430,7 @@ def __init__(
):
super().__init__(
param_group=param_group,
- process_group=process_group,
+ dp_process_group=dp_process_group,
cpu_offload=cpu_offload,
partition_grad=partition_grad,
master_weights=master_weights,
@@ -516,7 +516,7 @@ def post_step(self):
all_splited_param = [
torch.zeros(master_param.shape, device=device, dtype=self._dtype) for _ in range(self._world_size)
]
- dist.all_gather(all_splited_param, master_param.to(device).to(self._dtype), group=self.process_group)
+ dist.all_gather(all_splited_param, master_param.to(device).to(self._dtype), group=self.dp_process_group)
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))

# restore tmp values
@@ -535,7 +535,7 @@ def __init__(
overlap_communication: bool = False,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
- process_group: Optional[ProcessGroup] = None, # the dp pg for comm
+ dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
master_weights: bool = True, # master weights
):
for param in param_group["params"]:
@@ -544,7 +544,7 @@ def __init__(

super().__init__(
param_group=param_group,
- process_group=process_group,
+ dp_process_group=dp_process_group,
cpu_offload=cpu_offload,
partition_grad=partition_grad,
master_weights=master_weights,
@@ -556,6 +556,6 @@ def __init__(
# def get_param_grad(self, param): # TODO @botbw: discuss whether it's intuitive to return grad of divided of full moe tensor
# moe_partial_grad = super().get_param_grad(param)
# moe_grad_list = [torch.empty_like(moe_partial_grad) for _ in range(self._world_size)]
- # dist.all_gather(moe_grad_list, moe_partial_grad, group=self.process_group)
+ # dist.all_gather(moe_grad_list, moe_partial_grad, group=self.dp_process_group)
# moe_grad = torch.cat(moe_grad_list, dim=0).reshape(param.shape[0] * self._world_size, *param.shape[1:])
# return moe_grad
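As the renamed constructors above show, dp_process_group stays optional: when it is None, the strategy falls back to dist.group.WORLD and derives its rank and world size from that group. A hedged sketch of supplying an explicit group instead, for a hypothetical 4-rank job split into two 2-way data-parallel groups (reusing the LowLevelOptStrategy import and placeholder optimizer from the earlier sketch):

import torch.distributed as dist

# Every rank must create both groups (torch.distributed requires all ranks to
# call new_group for each group); each rank then keeps the one it belongs to.
group_a = dist.new_group(ranks=[0, 1])
group_b = dist.new_group(ranks=[2, 3])
dp_group = group_a if dist.get_rank() in (0, 1) else group_b

strategy = LowLevelOptStrategy(
    param_group=optimizer.param_groups[0],
    dp_process_group=dp_group,   # omit or pass None to fall back to dist.group.WORLD
    partition_grad=False,
)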
4 changes: 2 additions & 2 deletions tests/test_moe/test_moe_zero_fwd_bwd_optim.py
@@ -68,13 +68,13 @@ def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.
strategies = [
LowLevelOptStrategy(
param_group=zero_optimizer.param_groups[0],
- process_group=plugin.global_dp_group,
+ dp_process_group=plugin.global_dp_group,
overlap_communication=False,
partition_grad=(stage == 2),
),
MoeZeroStrategy(
param_group=zero_optimizer.param_groups[1],
- process_group=plugin.moe_dp_group,
+ dp_process_group=plugin.moe_dp_group,
overlap_communication=True,
partition_grad=(stage == 2),
),
