[ColossalChat] Add PP support (hpcaitech#6001)
* support pp training
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* update rm
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* refactor
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* update test case
* fix
* change to 4
* fix eval
* test
* add pp
* hotfix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* support pp training
* update rm
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* refactor
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* update test case
* fix
* change to 4
* fix eval
* test
* add pp
* hotfix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* update
* skip pp eval
* update all reduce
* update sft
* update ignore
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* update no cache
* add eval
* remove fi
* remove debug
* remove parentheses to avoid warning
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Revert "add eval" (this reverts commit 3ab2f6f)
* add all reduce
---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
TongLi3701 and pre-commit-ci[bot] authored Aug 21, 2024
1 parent 0d3b0bd commit 39e2597
Showing 16 changed files with 243 additions and 117 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/run_chatgpt_examples.yml
@@ -31,18 +31,18 @@ jobs:
       - name: Install Colossal-AI
         run: |
-          BUILD_EXT=1 pip install -v -e .
+          BUILD_EXT=1 pip install --no-cache-dir -v -e .
       - name: Install ChatGPT
         run: |
           cd applications/ColossalChat
-          pip install -v .
+          pip install --no-cache-dir -v .
           export BUILD_EXT=1
-          pip install -r examples/requirements.txt
+          pip install --no-cache-dir -r examples/requirements.txt
       - name: Install Transformers
         run: |
-          pip install transformers==4.36.2
+          pip install --no-cache-dir transformers==4.36.2
       - name: Execute Examples
         run: |
6 changes: 6 additions & 0 deletions applications/ColossalChat/.gitignore
@@ -161,3 +161,9 @@ applications/ColossalChat/sft_data
 applications/ColossalChat/prompt_data
 applications/ColossalChat/preference_data
 applications/ColossalChat/temp
+
+# Testing data
+/kto_data/
+/preference_data/
+/prompt_data/
+/sft_data/
4 changes: 3 additions & 1 deletion applications/ColossalChat/coati/trainer/base.py
@@ -16,7 +16,7 @@
 from coati.experience_maker import Experience
 from torch.optim import Optimizer
 
-from colossalai.booster import Booster
+from colossalai.booster import Booster, Plugin
 
 from .utils import is_rank_0
 
@@ -38,13 +38,15 @@ def __init__(
         max_epochs: int,
         model: nn.Module,
         optimizer: Optimizer,
+        plugin: Plugin,
         start_epoch: int = 0,
     ) -> None:
         super().__init__()
         self.booster = booster
         self.max_epochs = max_epochs
         self.model = model
         self.optimizer = optimizer
+        self.plugin = plugin
         self.start_epoch = start_epoch
 
     @abstractmethod
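The diff above threads a new `plugin` argument through the trainer base class so each trainer can inspect the parallelism setup at runtime. Below is a minimal sketch of how a caller might construct a pipeline-parallel plugin and pass it along, assuming a `HybridParallelPlugin` with two pipeline stages; the checkpoint name, optimizer, and hyperparameters are illustrative and not taken from this commit.

```python
# Illustrative sketch: wiring a pipeline-parallel plugin into a trainer that
# follows the new signature (booster, max_epochs, model, optimizer, plugin, ...).
import torch
from transformers import AutoModelForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # assumes the process was started with torchrun

# pp_size=2 splits the model into two pipeline stages; a microbatch size is
# required whenever pp_size > 1.
plugin = HybridParallelPlugin(tp_size=1, pp_size=2, microbatch_size=1)
booster = Booster(plugin=plugin)

model = AutoModelForCausalLM.from_pretrained("my-org/my-base-model")  # hypothetical checkpoint
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# boost() shards the model across pipeline stages and wraps the optimizer.
model, optimizer, _, _, _ = booster.boost(model, optimizer)

# The same plugin object is then handed to the trainer alongside the booster,
# matching the updated base-class signature above.
```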
7 changes: 5 additions & 2 deletions applications/ColossalChat/coati/trainer/dpo.py
@@ -16,7 +16,7 @@
 from tqdm import trange
 from transformers import PreTrainedTokenizerBase
 
-from colossalai.booster import Booster
+from colossalai.booster import Booster, Plugin
 from colossalai.cluster import DistCoordinator
 from colossalai.utils import get_current_device
 
@@ -50,6 +50,7 @@ def __init__(
         ref_model: Any,
         booster: Booster,
         actor_optim: Optimizer,
+        plugin: Plugin,
         actor_lr_scheduler: _LRScheduler,
         tokenizer: PreTrainedTokenizerBase,
         max_epochs: int = 1,
@@ -63,7 +64,9 @@ def __init__(
         save_dir: str = None,
         coordinator: DistCoordinator = None,
     ) -> None:
-        super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
+        super().__init__(
+            booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch
+        )
         self.ref_model = ref_model
         self.actor_scheduler = actor_lr_scheduler
         self.tokenizer = tokenizer
7 changes: 5 additions & 2 deletions applications/ColossalChat/coati/trainer/kto.py
@@ -17,7 +17,7 @@
 from tqdm import trange
 from transformers import PreTrainedTokenizerBase
 
-from colossalai.booster import Booster
+from colossalai.booster import Booster, Plugin
 from colossalai.cluster import DistCoordinator
 from colossalai.utils import get_current_device
 
@@ -53,6 +53,7 @@ def __init__(
         ref_model: Any,
         booster: Booster,
         actor_optim: Optimizer,
+        plugin: Plugin,
         actor_lr_scheduler: _LRScheduler,
         tokenizer: PreTrainedTokenizerBase,
         max_epochs: int = 1,
@@ -66,7 +67,9 @@ def __init__(
         save_dir: str = None,
         coordinator: DistCoordinator = None,
     ) -> None:
-        super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
+        super().__init__(
+            booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch
+        )
         self.ref_model = ref_model
         self.actor_scheduler = actor_lr_scheduler
         self.tokenizer = tokenizer
7 changes: 5 additions & 2 deletions applications/ColossalChat/coati/trainer/orpo.py
@@ -16,7 +16,7 @@
 from tqdm import trange
 from transformers import PreTrainedTokenizerBase
 
-from colossalai.booster import Booster
+from colossalai.booster import Booster, Plugin
 from colossalai.cluster import DistCoordinator
 from colossalai.utils import get_current_device
 
@@ -48,6 +48,7 @@ def __init__(
         actor: Any,
         booster: Booster,
         actor_optim: Optimizer,
+        plugin: Plugin,
         actor_lr_scheduler: _LRScheduler,
         tokenizer: PreTrainedTokenizerBase,
         max_epochs: int = 1,
@@ -59,7 +60,9 @@ def __init__(
         save_dir: str = None,
         coordinator: DistCoordinator = None,
     ) -> None:
-        super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
+        super().__init__(
+            booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch
+        )
         self.actor_scheduler = actor_lr_scheduler
         self.tokenizer = tokenizer
         self.odds_ratio_loss_fn = OddsRatioLoss()
7 changes: 5 additions & 2 deletions applications/ColossalChat/coati/trainer/rm.py
@@ -15,7 +15,7 @@
 from torch.utils.data import DataLoader
 from transformers import PreTrainedTokenizerBase
 
-from colossalai.booster import Booster
+from colossalai.booster import Booster, Plugin
 from colossalai.cluster import DistCoordinator
 from colossalai.utils import get_current_device
 
@@ -48,6 +48,7 @@ def __init__(
         model: Any,
         booster: Booster,
         optimizer: Optimizer,
+        plugin: Plugin,
         lr_scheduler: _LRScheduler,
         tokenizer: PreTrainedTokenizerBase,
         loss_fn: Optional[Callable] = None,
@@ -59,7 +60,9 @@ def __init__(
         save_dir: str = None,
         coordinator: DistCoordinator = None,
     ) -> None:
-        super().__init__(booster, max_epochs=max_epochs, model=model, optimizer=optimizer, start_epoch=start_epoch)
+        super().__init__(
+            booster, max_epochs=max_epochs, model=model, optimizer=optimizer, plugin=plugin, start_epoch=start_epoch
+        )
         self.actor_scheduler = lr_scheduler
         self.tokenizer = tokenizer
         self.loss_fn = loss_fn if loss_fn is not None else LogSigLoss(beta=beta)
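Once every trainer carries the plugin, the training loop can branch on it to drive a pipeline schedule instead of a plain forward/backward pass. The sketch below shows one way such a step could look; it is not the exact code from this commit, and the `trainer` attribute names and the `.loss` field on the model output are assumptions based on the signatures shown above.

```python
# Sketch: a pipeline-aware training step for a trainer that stores booster,
# plugin, model, and optimizer (attribute names assumed, not from this diff).
from colossalai.booster.plugin import HybridParallelPlugin


def training_step(trainer, batch):
    use_pipeline = isinstance(trainer.plugin, HybridParallelPlugin) and trainer.plugin.pp_size > 1
    if use_pipeline:
        # execute_pipeline runs the forward/backward schedule across stages;
        # the returned loss is only meaningful on the last pipeline stage.
        outputs = trainer.booster.execute_pipeline(
            iter([batch]),
            trainer.model,
            criterion=lambda out, inputs: out.loss,  # assumes a HF-style output with .loss
            optimizer=trainer.optimizer,
            return_loss=True,
        )
        loss = outputs["loss"]
    else:
        out = trainer.model(**batch)
        loss = out.loss
        trainer.booster.backward(loss, trainer.optimizer)
    trainer.optimizer.step()
    trainer.optimizer.zero_grad()
    return loss
```

This is presumably also what the "skip pp eval" and "update all reduce" entries in the commit log refer to: under pipeline parallelism the loss is materialized on only some ranks, so evaluation metrics either need an all-reduce across the group or have to be skipped.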