[fp8] merge #6023

Merged on Aug 22, 2024 (152 commits).

Commits
82aecd6
add SimPO
YeAnbang Jun 24, 2024
4b59d87
Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into main
YeAnbang Jun 24, 2024
0b2d627
fix dataloader
YeAnbang Jun 24, 2024
f3de5a0
remove debug code
YeAnbang Jun 24, 2024
c8d1b4a
add orpo
YeAnbang Jun 27, 2024
8aad064
fix style
YeAnbang Jun 27, 2024
384c640
fix colossalai, transformers version
YeAnbang Jun 27, 2024
afa5306
fix colossalai, transformers version
YeAnbang Jun 27, 2024
b117274
fix colossalai, transformers version
YeAnbang Jun 27, 2024
e752776
Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into r…
YeAnbang Jun 28, 2024
a8af6cc
fix torch colossalai version
YeAnbang Jun 28, 2024
ff53520
update transformers version
YeAnbang Jun 28, 2024
3420921
[shardformer] DeepseekMoE support (#5871)
Hz188 Jul 5, 2024
8ec24b6
[Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Edenzzzz Jul 5, 2024
cba2052
[Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
LRY89757 Jul 8, 2024
66abf1c
[HotFix] CI,import,requirements-test for #5838 (#5892)
LRY89757 Jul 8, 2024
fbf33ec
[Feature] Enable PP + SP for llama (#5868)
Edenzzzz Jul 9, 2024
16f3451
Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into r…
YeAnbang Jul 10, 2024
669849d
[ShardFormer] Add Ulysses Sequence Parallelism support for Command-R,…
GuangyaoZhang Jul 10, 2024
d888c37
add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Sup…
YeAnbang Jul 10, 2024
f6ef5c3
fix style
YeAnbang Jul 10, 2024
33f1520
Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into r…
YeAnbang Jul 10, 2024
8a9721b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 10, 2024
dd9e1cd
Merge pull request #5850 from hpcaitech/rlhf_SimPO
YeAnbang Jul 11, 2024
e7a8634
fix eval
YeAnbang Jul 11, 2024
115c4cc
hotfix citation
YeAnbang Jul 11, 2024
c068ef0
[zero] support all-gather overlap (#5898)
ver217 Jul 11, 2024
b3594d4
fix orpo cross entropy loss
YeAnbang Jul 15, 2024
45c49dd
[Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
stephankoe Jul 15, 2024
1c961b2
[ShardFormer] fix qwen2 sp (#5903)
GuangyaoZhang Jul 15, 2024
d8bf7e0
Merge pull request #5901 from hpcaitech/colossalchat
YeAnbang Jul 16, 2024
2e28c79
[compatibility] support torch 2.2 (#5875)
GuangyaoZhang Jul 4, 2024
530283d
fix object_to_tensor usage when torch>=2.3.0 (#5820)
kurisusnowdeng Jul 4, 2024
27a72f0
[misc] support torch2.3 (#5893)
ver217 Jul 11, 2024
73494de
[release] update version (#5912)
ver217 Jul 17, 2024
e861279
[plugin] support all-gather overlap for hybrid parallel (#5919)
ver217 Jul 18, 2024
09d5ffc
add kto
YeAnbang Jul 18, 2024
845ea72
Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into kto
YeAnbang Jul 18, 2024
544b7a3
fix style, add kto data sample
YeAnbang Jul 18, 2024
8cc8f64
[Examples] Add lazy init to OPT and GPT examples (#5924)
Edenzzzz Jul 19, 2024
f585d4e
[ColossalChat] Hotfix for ColossalChat (#5910)
TongLi3701 Jul 19, 2024
d08c99b
Merge branch 'main' into kto
TongLi3701 Jul 19, 2024
d49550f
refactor tokenization
YeAnbang Jul 19, 2024
150505c
Merge branch 'kto' of https://github.com/hpcaitech/ColossalAI into kto
YeAnbang Jul 19, 2024
4ec17a7
[FIX BUG] UnboundLocalError: cannot access local variable 'default_co…
zhurunhua Jul 21, 2024
c5f582f
fix test data
YeAnbang Jul 22, 2024
12fe8b5
refactor evaluation
YeAnbang Jul 22, 2024
b0e15d5
remove real data path
YeAnbang Jul 22, 2024
9688e19
remove real data path
YeAnbang Jul 22, 2024
a521ffc
Add n_fused as an input from native_module (#5894)
insujang Jul 23, 2024
5fb958c
[FIX BUG] convert env param to int in (#5934)
flymin Jul 24, 2024
2069472
[Hotfix] Fix ZeRO typo #5936
Edenzzzz Jul 25, 2024
ad35a98
[Feature] Add a switch to control whether the model checkpoint needs …
zhurunhua Jul 26, 2024
8a3ff4f
fix style
YeAnbang Jul 26, 2024
de1bf08
fix style
YeAnbang Jul 26, 2024
6fd9e86
fix style
YeAnbang Jul 29, 2024
c8332b9
Merge pull request #5922 from hpcaitech/kto
YeAnbang Jul 29, 2024
9664b1b
[shardformer] hotfix attn mask (#5945)
ver217 Jul 29, 2024
7b38964
[shardformer] hotfix attn mask (#5947)
ver217 Jul 29, 2024
bcf0181
[Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
LRY89757 Jul 30, 2024
0608921
[zero] hotfix update master params (#5951)
ver217 Jul 30, 2024
09c5f72
[release] update version (#5952)
ver217 Jul 31, 2024
30f4e31
[Chat] Fix lora (#5946)
YeAnbang Jul 31, 2024
66fbf2e
Update README.md (#5958)
YeAnbang Jul 31, 2024
1aeb5e8
[hotfix] Remove unused plan section (#5957)
TongLi3701 Jul 31, 2024
f9b6fcf
[test] add mixtral for sequence classification
botbw Jul 2, 2024
0b76b57
[test] add mixtral transformer test
botbw Jul 2, 2024
8ae8525
[moe] fix plugin
botbw Jul 2, 2024
a249e71
[test] mixtra pp shard test
botbw Jul 4, 2024
0fad23c
[chore] handle non member group
botbw Jul 5, 2024
46c069b
[zero] solve hang
botbw Jul 5, 2024
37443cc
[test] pass mixtral shardformer test
botbw Jul 8, 2024
b5bfeb2
[moe] implement transit between non moe tp and ep
botbw Jul 8, 2024
13b48ac
[zero] solve hang
botbw Jul 9, 2024
fe24789
[misc] solve booster hang by rename the variable
Hz188 Jul 9, 2024
5ed5e8c
solve hang when parallel mode = pp + dp
Hz188 Jul 11, 2024
e28e053
[moe] implement submesh initialization
botbw Jul 11, 2024
9b9b76b
[moe] add mixtral dp grad scaling when not all experts are activated
botbw Jul 12, 2024
014faf6
[chore] manually revert unintended commit
botbw Jul 12, 2024
8dbb868
[chore] trivial fix
botbw Jul 12, 2024
102b784
[chore] arg pass & remove drop token
botbw Jul 12, 2024
0b5bbe9
[test] add mixtral modelling test
botbw Jul 15, 2024
dc583aa
[moe] implement tp
botbw Jul 16, 2024
74eccac
[moe] test deepseek
botbw Jul 16, 2024
3e2b613
[moe] clean legacy code
botbw Jul 16, 2024
404b16f
[Feature] MoE Ulysses Support (#5918)
Hz188 Jul 18, 2024
09d6280
[chore] minor fix
botbw Jul 18, 2024
877d94b
[moe] init moe plugin comm setting with sp
botbw Jul 18, 2024
2cddeac
moe sp + ep bug fix
Hz188 Jul 18, 2024
7077d38
[moe] finalize test (no pp)
botbw Jul 18, 2024
803878b
[moe] full test for deepseek and mixtral (pp + sp to fix)
botbw Jul 19, 2024
46037c2
[chore] minor fix after rebase
botbw Jul 19, 2024
52d346f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2024
70c9924
[chore] solve moe ckpt test failure and some other arg pass failure
botbw Jul 22, 2024
74b03de
[moe] remove ops
botbw Jul 22, 2024
067e18f
[test] fix test: test_zero1_2
botbw Jul 22, 2024
96d0fbc
[bug] fix: somehow logger hangs the program
botbw Jul 23, 2024
b2952a5
[moe] deepseek moe sp support
Hz188 Jul 23, 2024
6c39f0b
[test] add check
botbw Jul 23, 2024
c3dc9b4
[deepseek] replace attn (a workaround for bug in transformers)
botbw Jul 23, 2024
59bcf56
[misc] skip redunant test
Hz188 Jul 24, 2024
034020b
[misc] remove debug/print code
Hz188 Jul 24, 2024
cb01c0d
[moe] refactor mesh assignment
botbw Jul 25, 2024
5b4c123
Revert "[moe] implement submesh initialization"
botbw Jul 25, 2024
606b089
[chore] change moe_pg_mesh to private
botbw Jul 25, 2024
12d043c
[misc] remove incompatible test config
Hz188 Jul 25, 2024
70793ce
[misc] fix ci failure: change default value to false in moe plugin
Hz188 Jul 25, 2024
7e737df
[misc] remove useless condition
Hz188 Jul 25, 2024
f7c5485
[chore] docstring
botbw Jul 25, 2024
7bedd03
[moe] remove force_overlap_comm flag and add warning instead
botbw Jul 25, 2024
65daa87
[doc] add MoeHybridParallelPlugin docstring
botbw Jul 26, 2024
d1d1ab8
[moe] solve dp axis issue
botbw Jul 26, 2024
62cdac6
[chore] remove redundant test case, print string & reduce test tokens
botbw Jul 30, 2024
19d1510
[feat] Dist Loader for Eval (#5950)
TongLi3701 Aug 2, 2024
75c9636
[lora] lora support hybrid parallel plugin (#5956)
wangbluo Aug 2, 2024
0b2d55c
Support overall loss, update KTO logging
YeAnbang Aug 2, 2024
fe71917
Merge pull request #5962 from hpcaitech/colossalchat
YeAnbang Aug 2, 2024
9179d40
[Docs] clarify launch port
Edenzzzz Aug 7, 2024
ad3fa4f
[Hotfix] README link (#5966)
TongLi3701 Aug 8, 2024
b4d2377
[Hotfix] Avoid fused RMSnorm import error without apex (#5985)
Edenzzzz Aug 9, 2024
ed97d3a
[Chat] fix readme (#5989)
YeAnbang Aug 12, 2024
ceb1e26
fix sync condition (#6000)
TongLi3701 Aug 14, 2024
406f984
[plugin] add cast inputs option for zero (#6003)
ver217 Aug 15, 2024
4dd0399
[pre-commit.ci] pre-commit autoupdate (#5995)
pre-commit-ci[bot] Aug 15, 2024
887d2d5
[misc] Bypass the huggingface bug to solve the mask mismatch problem …
Hz188 Aug 15, 2024
f5c84af
[Feature] Zigzag Ring attention (#5905)
Edenzzzz Aug 16, 2024
26493b9
[misc] update compatibility (#6008)
ver217 Aug 16, 2024
4cf79fa
merge
wangbluo Aug 17, 2024
81272e9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 17, 2024
02636c5
fix the merge
wangbluo Aug 19, 2024
52289e4
Merge branch 'fp8_merge' of https://github.com/wangbluo/ColossalAI in…
wangbluo Aug 19, 2024
1a5847e
fix the merge
wangbluo Aug 19, 2024
f1c3266
overlap kv comm with output rescale (#6017)
Edenzzzz Aug 19, 2024
3353042
fix the merge
wangbluo Aug 19, 2024
64aad96
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2024
4c82bfc
fix the merge
wangbluo Aug 19, 2024
0d8e82a
Merge branch 'fp8_merge' of https://github.com/wangbluo/ColossalAI in…
wangbluo Aug 19, 2024
12b4401
fix
wangbluo Aug 19, 2024
2eb3683
fix
wangbluo Aug 19, 2024
88b3f06
fix the merge
wangbluo Aug 19, 2024
1f703e0
fix
wangbluo Aug 19, 2024
dcc44aa
[misc] Use dist logger in plugins (#6011)
Edenzzzz Aug 20, 2024
5382311
fix
wangbluo Aug 20, 2024
f7acfa1
fix
wangbluo Aug 20, 2024
2ee6235
fix
wangbluo Aug 20, 2024
2e4cbe3
fix
wangbluo Aug 20, 2024
2d362ac
fix merge
wangbluo Aug 20, 2024
eb5ba40
fix the merge
wangbluo Aug 21, 2024
193030f
fix
wangbluo Aug 21, 2024
6aface9
fix
wangbluo Aug 21, 2024
698c8b9
fix
wangbluo Aug 21, 2024
8b8e282
fix
wangbluo Aug 21, 2024

Files changed
1 change: 1 addition & 0 deletions .compatibility
@@ -1,3 +1,4 @@
2.1.0-12.1.0
2.2.2-12.1.0
2.3.0-12.1.0
2.4.0-12.4.1
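For context, each line of `.compatibility` pairs a PyTorch version with a CUDA version (`<torch>-<cuda>`), and the added `2.4.0-12.4.1` entry extends the matrix to torch 2.4; the matrix is presumably consumed by the compatibility-test workflow. A minimal parsing sketch under that assumption:

```python
# Hedged sketch: parse ".compatibility" entries of the form "<torch>-<cuda>".
from pathlib import Path

def read_compat_matrix(path: str = ".compatibility") -> list[tuple[str, str]]:
    pairs = []
    for line in Path(path).read_text().splitlines():
        line = line.strip()
        if not line:
            continue
        torch_ver, cuda_ver = line.split("-", 1)  # e.g. "2.4.0-12.4.1"
        pairs.append((torch_ver, cuda_ver))
    return pairs

print(read_compat_matrix())  # expected: [..., ("2.3.0", "12.1.0"), ("2.4.0", "12.4.1")]
```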
4 changes: 2 additions & 2 deletions .cuda_ext.json
@@ -5,8 +5,8 @@
"cuda_image": "hpcaitech/cuda-conda:12.1"
},
{
"torch_command": "pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118",
"cuda_image": "hpcaitech/cuda-conda:11.8"
"torch_command": "pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124",
"cuda_image": "hpcaitech/cuda-conda:12.4"
}
]
}
2 changes: 1 addition & 1 deletion .github/workflows/build_on_pr.yml
@@ -141,7 +141,7 @@ jobs:
- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install -v -e .
pip install -r requirements/requirements-test.txt
pip install --no-cache-dir -r requirements/requirements-test.txt

- name: Store Colossal-AI Cache
run: |
2 changes: 1 addition & 1 deletion .github/workflows/build_on_schedule.yml
@@ -57,7 +57,7 @@ jobs:
[ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
BUILD_EXT=1 pip install -v -e .
cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
pip install -r requirements/requirements-test.txt
pip install --no-cache-dir -r requirements/requirements-test.txt

- name: Unit Testing
if: steps.check-avai.outputs.avai == 'true'
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
@@ -12,9 +12,10 @@ repos:
hooks:
- id: isort
name: sort all imports (python)
args: ["--profile", "black"] # avoid conflict with black

- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.4.2
rev: 24.8.0
hooks:
- id: black
name: black formatter
1 change: 1 addition & 0 deletions applications/ColossalChat/.gitignore
@@ -151,6 +151,7 @@ examples/training_scripts/wandb
examples/training_scripts/output

examples/awesome-chatgpt-prompts/
examples/inference/round.txt
temp/

# ColossalChat
2 changes: 1 addition & 1 deletion applications/ColossalChat/README.md
@@ -121,7 +121,7 @@ cd $COLOSSAL_AI_ROOT
BUILD_EXT=1 pip install .

# Install ColossalChat
cd $COLOSSAL_AI_ROOT/applications/Chat
cd $COLOSSAL_AI_ROOT/applications/ColossalChat
pip install .
```

21 changes: 18 additions & 3 deletions applications/ColossalChat/coati/dataset/tokenization_utils.py
@@ -49,6 +49,10 @@ def tokenize_sft(

messages = data_point["messages"]
template = deepcopy(conversation_template)

if messages[0]["from"] == "system":
template.system_message = str(messages[0]["content"])
messages.pop(0)
template.messages = []
for idx, mess in enumerate(messages):
if mess["from"] != template.roles[idx % 2]:
@@ -148,11 +152,14 @@ def tokenize_prompt(
template = deepcopy(conversation_template)
template.messages = []

if messages[0]["from"] == "system":
template.system_message = str(messages[0]["content"])
messages.pop(0)

for idx, mess in enumerate(messages):
if mess["from"] != template.roles[idx % 2]:
raise ValueError(
f"Message should iterate between user and assistant and starts with a \
line from the user. Got the following data:\n{messages}"
f"Message should iterate between user and assistant and starts with a line from the user. Got the following data:\n{messages}"
)
template.append_message(mess["from"], mess["content"])

@@ -162,7 +169,7 @@ def tokenize_prompt(
template.messages = template.messages[:-1]

# Prepare data
prompt = template.get_prompt(length=len(template.messages) - 1, add_generation_prompt=True)
prompt = template.get_prompt(length=len(template.messages), add_generation_prompt=True)
tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]

if tokenizer.bos_token_id is not None:
@@ -225,6 +232,10 @@ def tokenize_rlhf(
template = deepcopy(conversation_template)
template.clear()

if context[0]["from"] == "system":
template.system_message = str(context[0]["content"])
context.pop(0)

for idx, mess in enumerate(context):
if mess["from"] != template.roles[idx % 2]:
raise ValueError(
@@ -345,6 +356,10 @@ def tokenize_kto(
template = deepcopy(conversation_template)
template.clear()

if prompt[0]["from"] == "system":
template.system_message = str(prompt[0]["content"])
prompt.pop(0)

if prompt[0].get("from", None) != "user":
raise ValueError("conversation should start with user")
if completion.get("from", None) != "assistant":
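The tokenization changes above all follow one pattern: when a conversation starts with a system turn, its content becomes the template's system message and the turn is removed before checking that the remaining turns alternate between the template roles. A minimal standalone sketch of that pattern (the template object and message schema are assumed from the diff context):

```python
# Hedged sketch of the system-message handling added in tokenization_utils.py.
from copy import deepcopy

def apply_leading_system_message(conversation_template, messages: list) -> tuple:
    template = deepcopy(conversation_template)
    template.messages = []
    if messages and messages[0]["from"] == "system":
        # Move the system turn into the template, then drop it from the turn list.
        template.system_message = str(messages[0]["content"])
        messages.pop(0)
    # Remaining turns are expected to alternate template.roles[0] / template.roles[1].
    return template, messages
```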
16 changes: 12 additions & 4 deletions applications/ColossalChat/coati/models/loss.py
@@ -46,7 +46,10 @@ def forward(
action_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
skip = False
ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
if action_mask is None:
ratio_ = (log_probs - old_log_probs).exp()
else:
ratio_ = ((log_probs - old_log_probs) * action_mask).exp()

# note that if dropout is disabled (recommanded), ratio will always be 1.
if ratio_.mean() > self.skip_threshold:
@@ -56,7 +59,10 @@ def forward(
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
loss = -torch.min(surr1, surr2)
loss = masked_mean(loss, action_mask)
if action_mask is not None:
loss = masked_mean(loss, action_mask)
else:
loss = loss.mean(dim=1)
loss = loss.mean()
return loss, skip, ratio_.max()

@@ -81,8 +87,10 @@ def forward(
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
surr1 = (values_clipped - returns) ** 2
surr2 = (values - returns) ** 2
loss = torch.max(surr1, surr2) / torch.sum(action_mask)
loss = torch.sum(loss * action_mask)
if action_mask is not None:
loss = torch.sum(torch.max(surr1, surr2) / torch.sum(action_mask) * action_mask)
else:
loss = torch.mean(torch.max(surr1, surr2))
return 0.5 * loss


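The loss changes above make `action_mask` optional in both the clipped policy loss and the clipped value loss: with a mask, losses are averaged over masked positions; without one, they reduce to plain means. A self-contained sketch of the two code paths (the `masked_mean` stand-in here is a simplification; the real helper lives in `coati.models.utils`):

```python
import torch
from typing import Optional

def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # Simplified stand-in: mean over positions where mask == 1.
    return (x * mask).sum(dim=dim) / mask.sum(dim=dim).clamp(min=1)

def policy_loss(log_probs, old_log_probs, advantages, clip_eps: float = 0.2,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    if action_mask is None:
        ratio = (log_probs - old_log_probs).exp()
    else:
        # Masked positions get exp(0) = 1 and are then excluded by masked_mean.
        ratio = ((log_probs - old_log_probs) * action_mask).exp()
    surr1 = ratio * advantages
    surr2 = ratio.clamp(1 - clip_eps, 1 + clip_eps) * advantages
    loss = -torch.min(surr1, surr2)
    loss = masked_mean(loss, action_mask) if action_mask is not None else loss.mean(dim=1)
    return loss.mean()

def value_loss(values, old_values, returns, clip_eps: float = 0.2,
               action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    values_clipped = old_values + (values - old_values).clamp(-clip_eps, clip_eps)
    surr = torch.max((values_clipped - returns) ** 2, (values - returns) ** 2)
    if action_mask is not None:
        loss = torch.sum(surr / torch.sum(action_mask) * action_mask)
    else:
        loss = surr.mean()
    return 0.5 * loss
```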
7 changes: 4 additions & 3 deletions applications/ColossalChat/coati/models/utils.py
@@ -138,6 +138,7 @@ def disable_dropout(model: torch.nn.Module):
Returns:
None
"""
for module in model.modules():
if isinstance(module, torch.nn.Dropout):
module.p = 0.0
if model is not None:
for module in model.modules():
if isinstance(module, torch.nn.Dropout):
module.p = 0.0
9 changes: 9 additions & 0 deletions applications/ColossalChat/coati/trainer/dpo.py
@@ -56,6 +56,7 @@ def __init__(
beta: float = 0.1,
gamma: float = 0.0,
length_normalization: bool = False,
apply_loss_mask: bool = True,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
@@ -67,6 +68,7 @@ def __init__(
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.actor_loss_fn = DpoLoss(beta, gamma)
self.apply_loss_mask = apply_loss_mask
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
@@ -135,6 +137,10 @@ def _train(self, epoch: int):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
if not self.apply_loss_mask:
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
reject_loss_mask = reject_loss_mask.fill_(1.0)

batch_size = chosen_input_ids.size()[0]

actor_all_logits = self.model(
@@ -284,6 +290,9 @@ def _eval(self, epoch: int):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
if not self.apply_loss_mask:
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
reject_loss_mask = reject_loss_mask.fill_(1.0)

batch_size = chosen_input_ids.size()[0]

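The new `apply_loss_mask` switch (mirrored in the KTO, ORPO, and PPO trainers below) does not remove the mask tensors; it overwrites them with ones so that every response token contributes to the loss. A tiny sketch of the effect, using hypothetical mask values:

```python
import torch

apply_loss_mask = False  # trainer flag from the diff; True keeps the original masks

# Hypothetical loss masks for two sequences (1 = token counted in the loss).
chosen_loss_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 0.0]])
reject_loss_mask = torch.tensor([[1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]])

if not apply_loss_mask:
    # Same in-place fill as in the trainers: masking is effectively disabled.
    chosen_loss_mask = chosen_loss_mask.fill_(1.0)
    reject_loss_mask = reject_loss_mask.fill_(1.0)

print(chosen_loss_mask)  # all ones -> every position weighted equally in the loss
```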
37 changes: 34 additions & 3 deletions applications/ColossalChat/coati/trainer/kto.py
@@ -6,7 +6,7 @@
from typing import Any, Optional

import torch
import torch.distributed
import torch.distributed as dist
from coati.models.loss import KTOLoss
from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean
@@ -59,6 +59,7 @@ def __init__(
beta: float = 0.1,
desirable_weight: float = 1.0,
undesirable_weight: float = 1.0,
apply_loss_mask: bool = True,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
@@ -70,6 +71,7 @@ def __init__(
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.kto_loss = KTOLoss(beta=beta, desirable_weight=desirable_weight, undesirable_weight=undesirable_weight)
self.apply_loss_mask = apply_loss_mask
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
@@ -134,6 +136,10 @@ def _train(self, epoch: int):
batch["kl_attention_mask"],
batch["kl_loss_mask"],
)
if not self.apply_loss_mask:
loss_mask = loss_mask.fill_(1.0)
kl_loss_mask = kl_loss_mask.fill_(1.0)

batch_size = input_ids.size()[0]

# actor logits
@@ -182,8 +188,28 @@ def _train(self, epoch: int):

# sync
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean())
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean())
chosen_reward_mean = chosen_rewards.mean()
chosen_rewards_list = [
torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size())
]
dist.all_gather(chosen_rewards_list, chosen_reward_mean)
rejected_reward_mean = rejected_rewards.mean()
rejected_rewards_list = [
torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size())
]
dist.all_gather(rejected_rewards_list, rejected_reward_mean)
chosen_rewards_list = [i for i in chosen_rewards_list if not i.isnan()]
rejected_rewards_list = [i for i in rejected_rewards_list if not i.isnan()]
chosen_rewards_mean = (
torch.stack(chosen_rewards_list).mean()
if len(chosen_rewards_list) > 0
else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device)
)
rejected_rewards_mean = (
torch.stack(rejected_rewards_list).mean()
if len(rejected_rewards_list) > 0
else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device)
)
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item())
@@ -256,6 +282,11 @@ def _eval(self, epoch: int):
batch["kl_attention_mask"],
batch["kl_loss_mask"],
)

if not self.apply_loss_mask:
loss_mask = loss_mask.fill_(1.0)
kl_loss_mask = kl_loss_mask.fill_(1.0)

batch_size = input_ids.size()[0]

# actor logits
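The KTO logging change replaces an `all_reduce_mean` of the per-rank reward means with an all-gather that discards NaN entries, presumably because a rank's micro-batch can contain no chosen (or no rejected) samples, which makes that rank's mean NaN. A standalone sketch of the reduction, assuming an initialized `torch.distributed` process group:

```python
import torch
import torch.distributed as dist

def nan_safe_mean_across_ranks(local_mean: torch.Tensor) -> torch.Tensor:
    """Gather one scalar per rank and average only the non-NaN values, as in the trainer above."""
    gathered = [torch.zeros_like(local_mean) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local_mean)
    valid = [t for t in gathered if not torch.isnan(t)]
    if len(valid) == 0:
        return torch.tensor(torch.nan, dtype=local_mean.dtype, device=local_mean.device)
    return torch.stack(valid).mean()

# Hypothetical usage inside a training step:
# chosen_rewards_mean = nan_safe_mean_across_ranks(chosen_rewards.mean())
```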
12 changes: 12 additions & 0 deletions applications/ColossalChat/coati/trainer/orpo.py
@@ -52,6 +52,7 @@ def __init__(
tokenizer: PreTrainedTokenizerBase,
max_epochs: int = 1,
lam: float = 0.1,
apply_loss_mask: bool = True,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
@@ -67,6 +68,7 @@ def __init__(
self.save_dir = save_dir
self.num_train_step = 0
self.lam = lam
self.apply_loss_mask = apply_loss_mask
self.accumulation_steps = accumulation_steps
self.device = get_current_device()
self.accumulative_meter = AccumulativeMeanMeter()
@@ -130,6 +132,11 @@ def _train(self, epoch: int):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)

if not self.apply_loss_mask:
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
reject_loss_mask = reject_loss_mask.fill_(1.0)

batch_size = chosen_input_ids.size()[0]
actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
@@ -263,6 +270,11 @@ def _eval(self, epoch: int):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)

if not self.apply_loss_mask:
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
reject_loss_mask = reject_loss_mask.fill_(1.0)

batch_size = chosen_input_ids.size()[0]
actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
12 changes: 10 additions & 2 deletions applications/ColossalChat/coati/trainer/ppo.py
@@ -102,6 +102,7 @@ def __init__(
sample_buffer: bool = False,
dataloader_pin_memory: bool = True,
offload_inference_models: bool = True,
apply_loss_mask: bool = True,
accumulation_steps: int = 1,
save_interval: int = 0,
save_dir: str = None,
@@ -140,6 +141,7 @@ def __init__(
self.actor_optim = actor_optim
self.critic_optim = critic_optim
self.save_interval = save_interval
self.apply_loss_mask = apply_loss_mask
self.coordinator = coordinator
self.actor_save_dir = os.path.join(save_dir, "actor")
self.critic_save_dir = os.path.join(save_dir, "critic")
@@ -229,7 +231,10 @@ def _training_step(self, experience: Experience):
action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)

actor_loss, to_skip, max_ratio = self.actor_loss_fn(
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
action_log_probs,
experience.action_log_probs,
experience.advantages,
action_mask=experience.action_mask if self.apply_loss_mask else None,
)
actor_loss = (1 - self.ptx_coef) * actor_loss
if not to_skip:
@@ -249,7 +254,10 @@ def _training_step(self, experience: Experience):
input_ids=experience.sequences, attention_mask=experience.attention_mask
) # [batch size, prompt_length + response_length]
critic_loss = self.critic_loss_fn(
values[:, -num_actions:], experience.values, experience.advantages, action_mask=experience.action_mask
values[:, -num_actions:],
experience.values,
experience.advantages,
action_mask=experience.action_mask if self.apply_loss_mask else None,
)
critic_loss = critic_loss * self.vf_coef
self.critic_booster.backward(loss=critic_loss, optimizer=self.critic_optim)