[gemini] support gradient accumulation (hpcaitech#4869)
* add test

* fix no_sync bug in low level zero plugin

* fix test

* add argument for grad accum

* add grad accum in backward hook for gemini

* finish implementation, rewrite tests

* fix test

* skip stuck model in low level zero test

* update doc

* optimize communication & fix gradient checkpoint

* modify doc

* cleaning codes

* update cpu adam fp16 case
Fridge003 authored and flybird11111 committed Oct 18, 2023
1 parent b47dfb3 commit d8e3f1a
Showing 11 changed files with 283 additions and 10 deletions.
5 changes: 4 additions & 1 deletion colossalai/booster/plugin/gemini_plugin.py
@@ -245,6 +245,7 @@ class GeminiPlugin(DPPluginBase):
chunk_config_dict (dict, optional): chunk configuration dictionary.
chunk_init_device (torch.device, optional): device to initialize the chunk.
placement_policy (str, optional): "static" and "auto". Defaults to "static".
enable_gradient_accumulation (bool, optional): Whether to enable gradient accumulation. When set to True, gradients will be stored after the backward pass so that they can be accumulated across iterations. Defaults to False.
shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement.
If `shard_param_frac` is 1.0, it's equal to zero-3. If `shard_param_frac` is 0.0, it's equal to zero-2. Defaults to 1.0.
offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement.
@@ -257,7 +258,7 @@ class GeminiPlugin(DPPluginBase):
warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8.
steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9.
precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
master_weights (bool, optional): master weights. Defaults to True.
master_weights (bool, optional): Whether to keep fp32 master parameter weights in optimizer. Defaults to True.
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
@@ -291,6 +292,7 @@ def __init__(
chunk_config_dict: Optional[dict] = None,
chunk_init_device: Optional[torch.device] = None,
placement_policy: str = "static",
enable_gradient_accumulation: bool = False,
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
@@ -323,6 +325,7 @@ def __init__(
chunk_config_dict=chunk_config_dict,
chunk_init_device=(chunk_init_device or get_current_device()),
placement_policy=placement_policy,
enable_gradient_accumulation=enable_gradient_accumulation,
shard_param_frac=shard_param_frac,
offload_optim_frac=offload_optim_frac,
offload_param_frac=offload_param_frac,
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/low_level_zero_plugin.py
@@ -335,4 +335,4 @@ def get_checkpoint_io(self) -> CheckpointIO:

def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert isinstance(optimizer, LowLevelZeroOptimizer)
return optimizer.optim.no_sync()
return optimizer.no_sync()
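
For context, here is a minimal sketch of how the fixed `no_sync()` path is typically exercised through the booster API with `LowLevelZeroPlugin` at stage 1. It assumes `booster`, `model`, `optimizer`, `criterion` and `train_dataloader` come from a completed `booster.boost(...)` setup, and `GRADIENT_ACCUMULATION` is an assumed accumulation window; treat it as a sketch under those assumptions rather than the plugin's documented API.

```python
for idx, (data, label) in enumerate(train_dataloader):
    loss = criterion(model(data.cuda()), label.cuda()) / GRADIENT_ACCUMULATION
    if (idx + 1) % GRADIENT_ACCUMULATION != 0:
        # Non-boundary micro-batch: skip gradient synchronization.
        with booster.no_sync(model, optimizer):
            booster.backward(loss, optimizer)
    else:
        # Boundary micro-batch: synchronize gradients, then step.
        booster.backward(loss, optimizer)
        optimizer.step()
        optimizer.zero_grad()
```
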
15 changes: 15 additions & 0 deletions colossalai/zero/gemini/chunk/chunk.py
@@ -434,6 +434,21 @@ def copy_tensor_to_chunk_slice(
if update_ptr:
tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape)

def add_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
"""
Add a data slice to the memory space indexed by the input tensor in the chunk.
Only used when accumulating gradients.
Args:
tensor (torch.Tensor): the tensor used to retrieve meta information
data_slice (torch.Tensor): the tensor to be added to the chunk
"""
# sanity check
assert self.is_gathered

tensor_info = self.tensors_info[tensor]
self.cuda_global_chunk[tensor_info.offset : tensor_info.end].add_(data_slice.data.flatten())

def get_valid_length(self) -> int:
"""Get the valid length of the chunk's payload."""
if self.keep_gathered:
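
To make the new slice-accumulation semantics concrete, here is a small self-contained toy (not the real `Chunk` class): a chunk is modeled as one flat buffer holding several tensors back to back, and a gradient is accumulated in place at the `(offset, end)` slice recorded for its tensor. The layout below is an assumption for illustration only.

```python
import torch

chunk_payload = torch.zeros(8)                       # toy chunk payload
tensors_info = {"weight": (0, 6), "bias": (6, 8)}    # name -> (offset, end)

def add_tensor_to_chunk_slice(name: str, data_slice: torch.Tensor) -> None:
    # Accumulate the flattened slice into the region owned by this tensor.
    offset, end = tensors_info[name]
    chunk_payload[offset:end].add_(data_slice.flatten())

add_tensor_to_chunk_slice("weight", torch.ones(2, 3))  # first micro-batch
add_tensor_to_chunk_slice("weight", torch.ones(2, 3))  # second micro-batch accumulates
print(chunk_payload)  # first six entries are now 2.0, the bias region stays 0.0
```
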
36 changes: 35 additions & 1 deletion colossalai/zero/gemini/chunk/manager.py
@@ -5,7 +5,7 @@
import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.utils import get_current_device
from colossalai.utils import free_storage, get_current_device

from .chunk import Chunk, ChunkFullError, TensorState

@@ -255,3 +255,37 @@ def init_grad_chunk(self, chunk: Chunk) -> Chunk:
self.accessed_chunks.add(grad_chunk)
self.accessed_mem += grad_chunk.chunk_mem
return grad_chunk

def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk:
"""Rearrange gradients accumulated in chunk.grad_chunk, and getP prepared for gradient reduction."""

assert chunk.grad_chunk is not None

# Make a backup for gradient accumulated before.
# Here backup gradients should be multiplied, since it will be divided after gradient reduction.
if chunk.grad_chunk.is_gathered:
accumulated_grad = chunk.grad_chunk.cuda_global_chunk.clone().detach().mul_(chunk.pg_size)
accumulated_grad_gathered = True
else:
if chunk.grad_chunk.cuda_shard is not None:
accumulated_grad = chunk.grad_chunk.cuda_shard.clone().detach().mul_(chunk.pg_size)
else:
accumulated_grad = (
chunk.grad_chunk.cpu_shard.to(get_current_device()).clone().detach().mul_(chunk.pg_size)
)
accumulated_grad_gathered = False

# Reset grad_chunk; after this call chunk.grad_chunk is re-initialized and marked as accessed.
grad_chunk = self.init_grad_chunk(chunk)
grad_chunk.cuda_global_chunk.zero_()

# Add backup gradients to grad_chunk.
if accumulated_grad_gathered:
grad_chunk.cuda_global_chunk.add_(accumulated_grad)
else:
grad_chunk.cuda_global_chunk[grad_chunk.shard_begin : grad_chunk.shard_end].add_(accumulated_grad)

# Release accumulated_grad
free_storage(accumulated_grad)

return grad_chunk
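
The `mul_(chunk.pg_size)` backup step is the subtle part. Here is a minimal numeric sketch of why the pre-scaling is needed, under the assumptions (not stated in the diff) that the chunk reduction averages gradients over the process group and that only the rank owning a shard re-injects its already-averaged accumulated gradient:

```python
pg_size = 4
accumulated = 2.0                   # averaged gradient this rank kept for its shard
new_grads = [1.0, 3.0, 5.0, 7.0]    # fresh local gradients for this shard, one per rank

# Only the owning rank adds the backup, pre-multiplied by pg_size.
contributions = list(new_grads)
contributions[0] += accumulated * pg_size   # rank 0 owns this shard

reduced = sum(contributions) / pg_size      # what the next averaging reduction produces
assert reduced == sum(new_grads) / pg_size + accumulated
print(reduced)  # 6.0 = 4.0 (mean of the new grads) + 2.0 (previously accumulated)
```
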
25 changes: 22 additions & 3 deletions colossalai/zero/gemini/gemini_ddp.py
@@ -59,6 +59,7 @@ def __init__(
chunk_config_dict: Optional[dict] = None,
chunk_init_device: torch.device = torch.device("cpu"),
placement_policy: str = "static",
enable_gradient_accumulation: bool = False,
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
@@ -119,6 +120,11 @@ def __init__(
self.reuse_fp16_chunk = master_weights
self.master_weights = master_weights

self.enable_gradient_accumulation = enable_gradient_accumulation
if self.enable_gradient_accumulation:
self.reuse_fp16_chunk = False
self.accumulating_grads = False # Whether model is accumulating gradients

self._logger = get_dist_logger()

if self.gemini_manager._premade_memstats_:
@@ -298,6 +304,8 @@ def _post_backward(self):
f"{error_str}",
)
self._setup_grads_ptr()
if self.enable_gradient_accumulation and not self.accumulating_grads:
self.accumulating_grads = True # Turn on the state of gradient accumulation.
self._logger.debug(
f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}"
)
@@ -327,7 +335,15 @@ def grad_handle(self, p, grad):
)
grad_chunk = chunk
if not self.reuse_fp16_chunk:
grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
if not self.accumulating_grads:
grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
else:
assert chunk.grad_chunk is not None
if chunk.grad_chunk not in self.chunk_manager.accessed_chunks:
grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk)
else:
grad_chunk = chunk.grad_chunk

# hold -> compute -> hold after bwd
grad_chunk.tensor_trans_state(p, TensorState.COMPUTE)
grad_chunk.tensor_trans_state(p, TensorState.HOLD_AFTER_BWD)
@@ -336,7 +352,10 @@ def grad_handle(self, p, grad):
chunk.tensor_trans_state(p, TensorState.HOLD)

grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
if not self.accumulating_grads:
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
else:
grad_chunk.add_tensor_to_chunk_slice(p, grad)
reduced = self.chunk_manager.reduce_chunk(grad_chunk)
if reduced:
if not self.reuse_fp16_chunk:
@@ -354,7 +373,7 @@ def grad_handle(self, p, grad):
if chunk.l2_norm_flag:
grad_chunk.set_l2_norm()
self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True)
if not self.master_weights:
if not (self.master_weights) or (self.enable_gradient_accumulation):
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
return empty_grad

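Since the new branches in `grad_handle` are spread over two hunks, the following self-contained mock (none of these classes are the real Gemini types) traces which chunk-manager call the hook takes on the first backward pass versus the later, accumulating ones:

```python
class MockChunkManager:
    """Stand-in for ChunkManager, just enough to trace the branching."""

    def __init__(self):
        self.accessed_chunks = set()

    def init_grad_chunk(self, chunk):
        chunk.grad_chunk = "grad chunk"             # first backward: allocate a grad chunk
        self.accessed_chunks.add(chunk.grad_chunk)
        return chunk.grad_chunk

    def rearrange_accumulated_grad_chunk(self, chunk):
        self.accessed_chunks.add(chunk.grad_chunk)  # later backward: reopen for accumulation
        return chunk.grad_chunk


class MockChunk:
    grad_chunk = None


def pick_grad_chunk(manager, chunk, accumulating_grads):
    # Mirrors the branch added above, for the reuse_fp16_chunk == False path.
    if not accumulating_grads:
        return manager.init_grad_chunk(chunk)
    if chunk.grad_chunk not in manager.accessed_chunks:
        return manager.rearrange_accumulated_grad_chunk(chunk)
    return chunk.grad_chunk


manager, chunk = MockChunkManager(), MockChunk()
pick_grad_chunk(manager, chunk, accumulating_grads=False)  # micro-batch 1: init
manager.accessed_chunks.clear()                            # reduction releases the chunk
pick_grad_chunk(manager, chunk, accumulating_grads=True)   # micro-batch 2: rearrange
```
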
1 change: 1 addition & 0 deletions colossalai/zero/gemini/gemini_optimizer.py
@@ -263,6 +263,7 @@ def step(self, *args, **kwargs):
self.zero_grad()
if self.module.master_weights:
self._update_fp16_params()
self.module.accumulating_grads = False
return ret

def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
29 changes: 28 additions & 1 deletion docs/source/en/features/gradient_accumulation_with_booster.md
@@ -1,6 +1,6 @@
# Gradient Accumulation

Author: [Mingyan Jiang](https://github.com/jiangmingyan)
Author: [Mingyan Jiang](https://github.com/jiangmingyan), [Baizhou Zhang](https://github.com/Fridge003)

**Prerequisite**
- [Training Booster](../basics/booster_api.md)
@@ -126,6 +126,7 @@ for idx, (img, label) in enumerate(train_dataloader):

```


### Step 6. Invoke Training Scripts
To verify gradient accumulation, we can simply check the change of parameter values: with gradient accumulation enabled, parameters are only updated in the last step. You can run the script using this command:
```shell
@@ -142,4 +143,30 @@ iteration 2, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0
iteration 3, first 10 elements of param: tensor([-0.0141, 0.0464, 0.0507, 0.0321, 0.0356, -0.0150, 0.0172, -0.0118, 0.0222, 0.0473], device='cuda:0', grad_fn=<SliceBackward0>)
```


## Gradient Accumulation on GeminiPlugin

Currently, the plugins that support the `no_sync()` method are `TorchDDPPlugin` and `LowLevelZeroPlugin` (when set to stage 1). `GeminiPlugin` doesn't support the `no_sync()` method, but it can enable synchronized gradient accumulation in a PyTorch-like way.

To enable gradient accumulation, set the argument `enable_gradient_accumulation` to `True` when initializing `GeminiPlugin`. The following is a pseudocode snippet of enabling gradient accumulation for `GeminiPlugin`:
<!--- doc-test-ignore-start -->
```python
...
plugin = GeminiPlugin(..., enable_gradient_accumulation=True)
booster = Booster(plugin=plugin)
...

...
for idx, (input, label) in enumerate(train_dataloader):
output = gemini_model(input.cuda())
train_loss = criterion(output, label.cuda())
train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, gemini_optimizer)

if (idx + 1) % GRADIENT_ACCUMULATION == 0:
gemini_optimizer.step() # zero_grad is automatically done
...
```
<!--- doc-test-ignore-end -->

<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 gradient_accumulation_with_booster.py -->
docs/source/zh-Hans/features/gradient_accumulation_with_booster.md
@@ -1,6 +1,6 @@
# Gradient Accumulation

Author: [Mingyan Jiang](https://github.com/jiangmingyan)
Author: [Mingyan Jiang](https://github.com/jiangmingyan), [Baizhou Zhang](https://github.com/Fridge003)

**Prerequisite**
- [Training Booster](../basics/booster_api.md)
@@ -93,6 +93,7 @@ model, optimizer, criterion, train_dataloader, _ = booster.boost(model=model,
dataloader=train_dataloader)
```


### Step 5. Train with Booster
Build an ordinary training loop with booster to verify gradient accumulation. `param_by_iter` records the information of distributed training.
```python
@@ -144,4 +145,29 @@ iteration 2, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0
iteration 3, first 10 elements of param: tensor([-0.0141, 0.0464, 0.0507, 0.0321, 0.0356, -0.0150, 0.0172, -0.0118, 0.0222, 0.0473], device='cuda:0', grad_fn=<SliceBackward0>)
```

## Gradient Accumulation on GeminiPlugin

Currently, the plugins that support the `no_sync()` method are `TorchDDPPlugin` and `LowLevelZeroPlugin` (with the argument `stage` set to 1). `GeminiPlugin` doesn't support the `no_sync()` method, but it can enable synchronized gradient accumulation in a way similar to `pytorch`.

To enable gradient accumulation, set the argument `enable_gradient_accumulation` to `True` when initializing `GeminiPlugin`. The following is a pseudocode snippet of gradient accumulation with `GeminiPlugin`:
<!--- doc-test-ignore-start -->
```python
...
plugin = GeminiPlugin(..., enable_gradient_accumulation=True)
booster = Booster(plugin=plugin)
...

...
for idx, (input, label) in enumerate(train_dataloader):
output = gemini_model(input.cuda())
train_loss = criterion(output, label.cuda())
train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, gemini_optimizer)

if (idx + 1) % GRADIENT_ACCUMULATION == 0:
gemini_optimizer.step() # zero_grad is automatically done
...
```
<!--- doc-test-ignore-end -->

<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 gradient_accumulation_with_booster.py -->
1 change: 0 additions & 1 deletion tests/components_to_test/bert.py
@@ -52,7 +52,6 @@ def bert_model_builder(checkpoint: bool = False):
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
)
print("building BertForSequenceClassification model")

# adapting huggingface BertForSequenceClassification for single unittest calling interface
class ModelAdaptor(BertForSequenceClassification):
4 changes: 3 additions & 1 deletion tests/test_booster/test_plugin/test_low_level_zero_plugin.py
@@ -14,6 +14,8 @@
_AMP_ERR_MODELS = ["timm_convit", "deepfm_interactionarch"]
# These models have no parameters
_LOW_LEVEL_ZERO_ERR_MODELS = ["dlrm_interactionarch"]
# These models cause the test to get stuck, to be fixed
_STUCK_MODELS = ["transformers_albert_for_multiple_choice"]


def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
@@ -53,7 +55,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
"""
passed_models = []
failed_info = {} # (model_name, error) pair
ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS
ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS
skipped_models = []

for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():