[feature] modify user documentation;
duanjunwen committed Apr 10, 2024
1 parent 40a5528 commit 1c9bb93
Showing 2 changed files with 61 additions and 102 deletions.
68 changes: 0 additions & 68 deletions colossalai/nn/optimizer/README.md
@@ -81,71 +81,3 @@ If you wish to add an optimizer for a specific application, please follow the st


If your PR is accepted, we may invite you to put up a tutorial or blog in [ColossalAI Documentation](https://colossalai.org/).


## Optimizer

A series of optimizers have been optimized and integrated.

### Distributed Adafactor

Distributed Adafactor is an optimizer that supports hybrid optimization, including 1D tensor parallelism as well as ZeRO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces the memory pressure on each device. It has a wide range of applications and currently supports a range of Transformer-based models; see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details.

### API Reference

{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }}

### Sample: Init with booster

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor

# ==============================
# Model Init
# ==============================
# TPModel is a placeholder for a user-defined tensor-parallel model,
# e.g. one of the Transformer models in tests.kit.model_zoo.
tp_model = TPModel()

# ==============================
# Optimizer Init
# ==============================
dist_optim = DistributedAdaFactor([p for n, p in tp_model.named_parameters()])

# ==============================
# Booster Init
# ==============================
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
criterion = lambda x: x.mean()
tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion)
```

### Performance
| Version | iter | Float Precision | Device Nums | Weight Shape | Avg Runtime (ms) | Avg Speed Up Rate | Best Speed Up Rate |
| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: |
| AdaFactor | 50 | float32 | 2 | [4096, 4096] | 0.58 | - | - |
| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096, 4096] | 0.41 | 1.39 | 56.91 |
| DistAdaFactor(Row Parallel) | 50 | float32 | 2 | [4096, 4096] | 0.61 | 0.96 | 18.69 |
| AdaFactor | 50 | float16 | 2 | [4096, 4096] | 0.72 | - | - |
| DistAdaFactor(Col Parallel) | 50 | float16 | 2 | [4096, 4096] | 0.54 | 1.33 | 26.03 |
| DistAdaFactor(Row Parallel) | 50 | float16 | 2 | [4096, 4096] | 0.67 | 1.08 | 20.55 |
| AdaFactor | 50 | bfloat16 | 2 | [4096, 4096] | 0.72 | - | - |
| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 2 | [4096, 4096] | 0.55 | 1.31 | 26.11 |
| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 2 | [4096, 4096] | 0.67 | 1.07 | 21.86 |
| AdaFactor | 50 | float32 | 4 | [4096, 4096] | 0.57 | - | - |
| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096, 4096] | 0.38 | 1.48 | 53.99 |
| DistAdaFactor(Row Parallel) | 50 | float32 | 4 | [4096, 4096] | 0.60 | 0.95 | 16.53 |
| AdaFactor | 50 | float16 | 4 | [4096, 4096] | 0.70 | - | - |
| DistAdaFactor(Col Parallel) | 50 | float16 | 4 | [4096, 4096] | 0.50 | 1.44 | 21.98 |
| DistAdaFactor(Row Parallel) | 50 | float16 | 4 | [4096, 4096] | 0.64 | 1.12 | 15.35 |
| AdaFactor | 50 | bfloat16 | 4 | [4096, 4096] | 0.72 | - | - |
| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 4 | [4096, 4096] | 0.56 | 1.29 | 25.63 |
| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 4 | [4096, 4096] | 0.71 | 1.09 | 21.52 |
| AdaFactor | 50 | float32 | 8 | [4096, 4096] | 0.56 | - | - |
| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096, 4096] | 0.38 | 1.50 | 54.51 |
| DistAdaFactor(Row Parallel) | 50 | float32 | 8 | [4096, 4096] | 0.91 | 0.67 | 15.68 |
| AdaFactor | 50 | float16 | 8 | [4096, 4096] | 0.74 | - | - |
| DistAdaFactor(Col Parallel) | 50 | float16 | 8 | [4096, 4096] | 0.84 | 0.87 | 9.21 |
| DistAdaFactor(Row Parallel) | 50 | float16 | 8 | [4096, 4096] | 1.03 | 0.75 | 16.12 |
| AdaFactor | 50 | bfloat16 | 8 | [4096, 4096] | 0.71 | - | - |
| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 8 | [4096, 4096] | 0.54 | 1.31 | 27.28 |
| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 8 | [4096, 4096] | 0.73 | 1.03 | 25.01 |
95 changes: 61 additions & 34 deletions docs/source/en/features/distributed_adafactor.md
@@ -9,41 +9,9 @@ Author:

Distributed Adafactor is an optimizer that supports hybrid optimization, including 1D tensor parallelism as well as ZeRO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces the memory pressure on each device. It has a wide range of applications and currently supports a range of Transformer-based models; see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details.
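
Besides the tensor-parallel samples shown later in this document, the ZeRO side of this claim can be exercised through a booster plugin. The sketch below is not taken from the original documentation; the plugin choice, its `stage` argument, and the toy `nn.Linear` model are assumptions used only for illustration.

```python
import torch.nn as nn

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor

# Assumes the distributed environment has already been initialized on every rank
# (e.g. via colossalai.launch_from_torch).
model = nn.Linear(4096, 4096).cuda()                    # toy stand-in for a Transformer model
dist_optim = DistributedAdaFactor(model.parameters())   # same constructor as in the samples below
plugin = LowLevelZeroPlugin(stage=2)                    # ZeRO stage-2 partitioning (assumed setting)
booster = Booster(plugin=plugin)
model, dist_optim, *_ = booster.boost(model, dist_optim)
```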

### API Reference

## Performance

| Parallel strategy | iter | Float Precision | Device Nums | Weight Shape | Avg Runtime (ms) | Avg Speed Up Rate | Best Speed Up Rate |
| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: |
| AdaFactor | 50 | float32 | 2 | [4096, 4096] | 0.58 | - | - |
| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096, 4096] | 0.41 | 1.39 | 56.91 |
| DistAdaFactor(Row Parallel) | 50 | float32 | 2 | [4096, 4096] | 0.61 | 0.96 | 18.69 |
| AdaFactor | 50 | float16 | 2 | [4096, 4096] | 0.72 | - | - |
| DistAdaFactor(Col Parallel) | 50 | float16 | 2 | [4096, 4096] | 0.54 | 1.33 | 26.03 |
| DistAdaFactor(Row Parallel) | 50 | float16 | 2 | [4096, 4096] | 0.67 | 1.08 | 20.55 |
| AdaFactor | 50 | bfloat16 | 2 | [4096, 4096] | 0.72 | - | - |
| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 2 | [4096, 4096] | 0.55 | 1.31 | 26.11 |
| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 2 | [4096, 4096] | 0.67 | 1.07 | 21.86 |
| AdaFactor | 50 | float32 | 4 | [4096, 4096] | 0.57 | - | - |
| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096, 4096] | 0.38 | 1.48 | 53.99 |
| DistAdaFactor(Row Parallel) | 50 | float32 | 4 | [4096, 4096] | 0.60 | 0.95 | 16.53 |
| AdaFactor | 50 | float16 | 4 | [4096, 4096] | 0.70 | - | - |
| DistAdaFactor(Col Parallel) | 50 | float16 | 4 | [4096, 4096] | 0.50 | 1.44 | 21.98 |
| DistAdaFactor(Row Parallel) | 50 | float16 | 4 | [4096, 4096] | 0.64 | 1.12 | 15.35 |
| AdaFactor | 50 | bfloat16 | 4 | [4096, 4096] | 0.72 | - | - |
| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 4 | [4096, 4096] | 0.56 | 1.29 | 25.63 |
| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 4 | [4096, 4096] | 0.71 | 1.09 | 21.52 |
| AdaFactor | 50 | float32 | 8 | [4096, 4096] | 0.56 | - | - |
| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096, 4096] | 0.38 | 1.50 | 54.51 |
| DistAdaFactor(Row Parallel) | 50 | float32 | 8 | [4096, 4096] | 0.91 | 0.67 | 15.68 |
| AdaFactor | 50 | float16 | 8 | [4096, 4096] | 0.74 | - | - |
| DistAdaFactor(Col Parallel) | 50 | float16 | 8 | [4096, 4096] | 0.84 | 0.87 | 9.21 |
| DistAdaFactor(Row Parallel) | 50 | float16 | 8 | [4096, 4096] | 1.03 | 0.75 | 16.12 |
| AdaFactor | 50 | bfloat16 | 8 | [4096, 4096] | 0.71 | - | - |
| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 8 | [4096, 4096] | 0.54 | 1.31 | 27.28 |
| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 8 | [4096, 4096] | 0.73 | 1.03 | 25.01 |

{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }}

## Hands-On Practice
We now demonstrate how to use Distributed Adafactor.
@@ -135,4 +103,63 @@ else:
dist_optim.step()
dist_optim.zero_grad()
```

## Run with booster
We highly recommend that you use the booster, a simple, easy-to-use, and efficient interface. The code below launches Distributed Adafactor with the booster.

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor

# ==============================
# Model Init
# ==============================
# TPModel is a placeholder for a user-defined tensor-parallel model,
# e.g. one of the Transformer models in tests.kit.model_zoo.
tp_model = TPModel()

# ==============================
# Optimizer Init
# ==============================
dist_optim = DistributedAdaFactor([p for n, p in tp_model.named_parameters()])

# ==============================
# Booster Init
# ==============================
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
criterion = lambda x: x.mean()
tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion)
```
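
Once boosted, the returned objects are used like an ordinary model and optimizer. The training-step sketch below is not part of the original document: the dummy input shape and batch are assumptions, and `booster.backward` is used here as the standard booster-style backward pass.

```python
import torch

# A minimal training-step sketch; real code would iterate over a dataloader.
x = torch.randn(4, 4096).cuda()      # dummy batch (assumed input shape for TPModel)
out = tp_model(x)                    # forward through the boosted model
loss = criterion(out)                # the criterion above reduces the output to a scalar
booster.backward(loss, dist_optim)   # run backward via the booster so the plugin can hook in
dist_optim.step()                    # distributed Adafactor parameter update
dist_optim.zero_grad()
```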

## Performance

| Parallel strategy | iter | Float Precision | Device Nums | Weight Shape | Avg Runtime (ms) | Avg Speed Up Rate | Best Speed Up Rate |
| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-----------------: | :---------------: |
| AdaFactor | 50 | float32 | 2 | [4096, 4096] | 0.58 | - | - |
| DistAdaFactor(Col Parallel) | 50 | float32 | 2 | [4096, 4096] | 0.41 | 1.39 | 56.91 |
| DistAdaFactor(Row Parallel) | 50 | float32 | 2 | [4096, 4096] | 0.61 | 0.96 | 18.69 |
| AdaFactor | 50 | float16 | 2 | [4096, 4096] | 0.72 | - | - |
| DistAdaFactor(Col Parallel) | 50 | float16 | 2 | [4096, 4096] | 0.54 | 1.33 | 26.03 |
| DistAdaFactor(Row Parallel) | 50 | float16 | 2 | [4096, 4096] | 0.67 | 1.08 | 20.55 |
| AdaFactor | 50 | bfloat16 | 2 | [4096, 4096] | 0.72 | - | - |
| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 2 | [4096, 4096] | 0.55 | 1.31 | 26.11 |
| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 2 | [4096, 4096] | 0.67 | 1.07 | 21.86 |
| AdaFactor | 50 | float32 | 4 | [4096, 4096] | 0.57 | - | - |
| DistAdaFactor(Col Parallel) | 50 | float32 | 4 | [4096, 4096] | 0.38 | 1.48 | 53.99 |
| DistAdaFactor(Row Parallel) | 50 | float32 | 4 | [4096, 4096] | 0.60 | 0.95 | 16.53 |
| AdaFactor | 50 | float16 | 4 | [4096, 4096] | 0.70 | - | - |
| DistAdaFactor(Col Parallel) | 50 | float16 | 4 | [4096, 4096] | 0.50 | 1.44 | 21.98 |
| DistAdaFactor(Row Parallel) | 50 | float16 | 4 | [4096, 4096] | 0.64 | 1.12 | 15.35 |
| AdaFactor | 50 | bfloat16 | 4 | [4096, 4096] | 0.72 | - | - |
| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 4 | [4096, 4096] | 0.56 | 1.29 | 25.63 |
| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 4 | [4096, 4096] | 0.71 | 1.09 | 21.52 |
| AdaFactor | 50 | float32 | 8 | [4096, 4096] | 0.56 | - | - |
| DistAdaFactor(Col Parallel) | 50 | float32 | 8 | [4096, 4096] | 0.38 | 1.50 | 54.51 |
| DistAdaFactor(Row Parallel) | 50 | float32 | 8 | [4096, 4096] | 0.91 | 0.67 | 15.68 |
| AdaFactor | 50 | float16 | 8 | [4096, 4096] | 0.74 | - | - |
| DistAdaFactor(Col Parallel) | 50 | float16 | 8 | [4096, 4096] | 0.84 | 0.87 | 9.21 |
| DistAdaFactor(Row Parallel) | 50 | float16 | 8 | [4096, 4096] | 1.03 | 0.75 | 16.12 |
| AdaFactor | 50 | bfloat16 | 8 | [4096, 4096] | 0.71 | - | - |
| DistAdaFactor(Col Parallel) | 50 | bfloat16 | 8 | [4096, 4096] | 0.54 | 1.31 | 27.28 |
| DistAdaFactor(Row Parallel) | 50 | bfloat16 | 8 | [4096, 4096] | 0.73 | 1.03 | 25.01 |


<!-- doc-test-command: echo -->
