[WIP] further refactor strategy: move grad&comm to base class
botbw committed Jun 10, 2024
1 parent 599a007 commit e4197a4
Showing 4 changed files with 221 additions and 165 deletions.
30 changes: 29 additions & 1 deletion colossalai/booster/plugin/low_level_zero_plugin.py
@@ -25,7 +25,9 @@
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero import LowLevelZeroOptimizer
from colossalai.zero.low_level.low_level_strategy import LowLevelOptStrategy, MoeZeroStrategy

from .dp_plugin_base import DPPluginBase
from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -320,8 +322,34 @@ def configure(
model = LowLevelZeroModel(model, self.precision)

if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
# TODO @botbw: better way of doing this
# separate params into two groups: normal and extra_moe_dp
assert (
len(optimizer.param_groups) == 1 and len(optimizer.state) == 0
), "This requires the optimizer to be uninitialized"
non_moe_params, moe_params = [], []
for param in model.parameters():
if is_moe_tensor(param):
moe_params.append(param)
else:
non_moe_params.append(param)
strategies = None
if len(moe_params) != 0:
print(f"{len(moe_params)=}, {len(non_moe_params)=}")
prev_param_group = optimizer.param_groups[0]
prev_param_group.pop("params")
optimizer.param_groups = [
{"params": non_moe_params, **prev_param_group},
{"params": moe_params, **prev_param_group},
]
strategies = [
LowLevelOptStrategy(param_group=optimizer.param_groups[0], **self.zero_optim_kwargs),
MoeZeroStrategy(param_group=optimizer.param_groups[1], **self.zero_optim_kwargs),
]
if "moe_extra_dp_process_group" in self.zero_optim_kwargs:
self.zero_optim_kwargs.pop("moe_extra_dp_process_group")
optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
optimizer, **self.zero_optim_kwargs, verbose=self.verbose
optimizer, group_strategies=strategies, **self.zero_optim_kwargs, verbose=self.verbose
)
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
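For orientation, the new configure() path boils down to the pattern below. This is a minimal usage sketch, not part of the diff: model, optimizer, common_group_args, and zero_kwargs stand in for the plugin's own state, and the strategy constructors are assumed to accept the same keyword arguments the plugin forwards via self.zero_optim_kwargs.

# Hedged usage sketch of the new grouping: one param group (and strategy) for
# dense params, one for MoE params, both handed to LowLevelZeroOptimizer.
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero import LowLevelZeroOptimizer
from colossalai.zero.low_level.low_level_strategy import LowLevelOptStrategy, MoeZeroStrategy

non_moe_params = [p for p in model.parameters() if not is_moe_tensor(p)]
moe_params = [p for p in model.parameters() if is_moe_tensor(p)]

optimizer.param_groups = [
    {"params": non_moe_params, **common_group_args},
    {"params": moe_params, **common_group_args},
]
strategies = [
    LowLevelOptStrategy(param_group=optimizer.param_groups[0], **zero_kwargs),
    MoeZeroStrategy(param_group=optimizer.param_groups[1], **zero_kwargs),
]
zero_optimizer = LowLevelZeroOptimizer(optimizer, group_strategies=strategies, **zero_kwargs)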
17 changes: 17 additions & 0 deletions colossalai/zero/low_level/bookkeeping/gradient_store.py
@@ -104,6 +104,23 @@ def get_working_grad_by_param_id(self, param_id) -> Tensor:

raise KeyError(f"Working gradient for param_id {param_id} not found.")

def get_master_grad_by_param_id(self, param_id) -> List[Tensor]:
"""
Return the master gradient slices for the specified parameter.
Args:
param_id (int): The index of the parameter.
Returns:
List[Tensor]: The master gradient slices for the specified param_id.
"""

for group in self._grads_of_params.values():
if param_id in group.keys():
return group[param_id]

raise KeyError(f"Working gradient for param_id {param_id} not found.")

def reset_grads_by_group_id(self, group_id: int):
self._grads_of_params[group_id] = dict()
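A short, hypothetical usage sketch of the new accessor (grad_store and param_id are assumed context; the method mirrors get_working_grad_by_param_id but returns the list of master-gradient slices):

# Hypothetical call: collect the master-gradient slices tracked for one parameter.
slices = grad_store.get_master_grad_by_param_id(param_id)
total_numel = sum(s.numel() for s in slices)  # e.g. for sanity checks or norm computation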

69 changes: 31 additions & 38 deletions colossalai/zero/low_level/low_level_optim.py
@@ -84,6 +84,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def __init__(
self,
optimizer: Optimizer,
group_strategies: List[LowLevelOptStrategyBase] = None,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.0,
@@ -107,10 +108,6 @@ def __init__(
self._logger = get_dist_logger()
self._verbose = verbose

assert len(self.optim.param_groups) == 1
self.moe_extra_dp_pg = None # TODO @botbw: refactor this
self._world_size = dist.get_world_size(group=dp_process_group) # TODO @botbw: refactor this
self._local_rank = dist.get_rank(group=dp_process_group) # TODO @botbw: refactor this
# gradient clipping
self._clip_grad_norm = clip_grad_norm

@@ -124,18 +121,23 @@ def __init__(
# check argument conflict
self._sanity_checks()

self._group_strategies = [
LowLevelOptStrategy(
param_group=self.optim.param_groups[0],
reduce_bucket_size=reduce_bucket_size,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
partition_grads=partition_grad,
cpu_offload=cpu_offload,
dp_process_group=dp_process_group,
master_weights=master_weights,
)
]
if len(self.optim.param_groups) == 1 and group_strategies is None:
group_strategies = [
LowLevelOptStrategy(
param_group=self.optim.param_groups[0],
reduce_bucket_size=reduce_bucket_size,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
partition_grads=partition_grad,
cpu_offload=cpu_offload,
dp_process_group=dp_process_group,
master_weights=master_weights,
)
]
elif len(self.optim.param_groups) > 1 and group_strategies is None:
raise ValueError("group_strategies must be provided when the optimizer has multiple param groups")

self._group_strategies = group_strategies

# initialize mixed precision mixin
self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None
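Put differently, a single-group optimizer keeps the old behavior (a default LowLevelOptStrategy is built internally), while a multi-group optimizer must be paired with explicit strategies. A minimal sketch of the two call patterns, with torch_optim, dp_group, and the strategy objects assumed to exist:

# Default path: one param group, strategies constructed internally as before.
zero_optim = LowLevelZeroOptimizer(torch_optim, dp_process_group=dp_group)

# Multi-group path: one strategy per param group must be supplied,
# otherwise __init__ raises ValueError.
zero_optim = LowLevelZeroOptimizer(
    torch_optim,
    group_strategies=[dense_strategy, moe_strategy],
)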
@@ -198,34 +200,24 @@ def load_state_dict(self, state_dict: Dict):
for strategy in self._group_strategies:
strategy.scatter_optim_state(self.optim.state)

# def update_master_params(self, model: nn.Module) -> None:
# """Update master params from working params

# Args:
# model (nn.Module): The model to update master params
# """
# for p in model.parameters():
# p_id = id(p)
# if p_id in self._group_strategies[0]._param_store.working_to_master_param:
# master_param = self._group_strategies[0]._param_store.working_to_master_param[p_id]
# padding_size = self._group_strategies[0]._param_store.get_param_padding_size(p)
# working_param = p.data.view(-1)
# if padding_size > 0:
# working_param = torch.nn.functional.pad(working_param, [0, padding_size])
# if self.moe_extra_dp_pg is not None and is_moe_tensor(p):
# master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
# else:
# master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
# if hasattr(self, "master_moe_params"):
# for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
# master_moe_param.copy_(working_moe_param)
def update_master_params(self, model: nn.Module) -> None:
"""Update master params from working params
Args:
model (nn.Module): The model to update master params
"""
# TODO @botbw: not sure if we should access params directly
all_working_params_strategy = []
for strategy in self._group_strategies:
all_working_params_strategy.extend(strategy.working_params)
strategy.update_master_params()
assert set(all_working_params_strategy) == set(model.parameters()), "model parameters should be the same"
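For context, each strategy's update_master_params() is expected to do per group what the removed block above did globally: flatten and pad the working parameter, then copy this rank's chunk into the master copy. A hedged per-group sketch that reuses the removed code's ParamStore attribute names (not necessarily the new strategy API):

# Sketch of a per-group master-param refresh, mirroring the removed code above.
import torch

def refresh_master_params(param_store, working_params, world_size, local_rank):
    for p in working_params:
        master_param = param_store.working_to_master_param[id(p)]
        padding = param_store.get_param_padding_size(p)
        flat = p.data.view(-1)
        if padding > 0:
            flat = torch.nn.functional.pad(flat, [0, padding])
        # each rank keeps only its own shard of the padded working parameter
        master_param.copy_(flat.chunk(world_size)[local_rank])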

def step(self, closure=None):
assert closure is None, "closure is not supported by step()"
if not self.require_grad_sync:
return

# TODO @botbw: implement this
if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
if self._verbose:
self._logger.info(f"Found overflow. Skip step")
@@ -234,6 +226,7 @@ def step(self, closure=None):
strategy.zero_grad()
return

# TODO @botbw can be further refactored
grad_partition_groups = []
norm_groups = []
for strategy in self._group_strategies: