[FP8] rebase main #5963

Merged
merged 132 commits into from
Aug 6, 2024
Changes from 128 commits

132 commits
82aecd6
add SimPO
YeAnbang Jun 24, 2024
4b59d87
Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into main
YeAnbang Jun 24, 2024
0b2d627
fix dataloader
YeAnbang Jun 24, 2024
f3de5a0
remove debug code
YeAnbang Jun 24, 2024
c8d1b4a
add orpo
YeAnbang Jun 27, 2024
8aad064
fix style
YeAnbang Jun 27, 2024
384c640
fix colossalai, transformers version
YeAnbang Jun 27, 2024
afa5306
fix colossalai, transformers version
YeAnbang Jun 27, 2024
b117274
fix colossalai, transformers version
YeAnbang Jun 27, 2024
e752776
Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into r…
YeAnbang Jun 28, 2024
a8af6cc
fix torch colossalai version
YeAnbang Jun 28, 2024
ff53520
update transformers version
YeAnbang Jun 28, 2024
3420921
[shardformer] DeepseekMoE support (#5871)
Hz188 Jul 5, 2024
8ec24b6
[Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Edenzzzz Jul 5, 2024
cba2052
[Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
LRY89757 Jul 8, 2024
66abf1c
[HotFix] CI,import,requirements-test for #5838 (#5892)
LRY89757 Jul 8, 2024
fbf33ec
[Feature] Enable PP + SP for llama (#5868)
Edenzzzz Jul 9, 2024
16f3451
Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into r…
YeAnbang Jul 10, 2024
669849d
[ShardFormer] Add Ulysses Sequence Parallelism support for Command-R,…
GuangyaoZhang Jul 10, 2024
d888c37
add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Sup…
YeAnbang Jul 10, 2024
f6ef5c3
fix style
YeAnbang Jul 10, 2024
33f1520
Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into r…
YeAnbang Jul 10, 2024
8a9721b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 10, 2024
dd9e1cd
Merge pull request #5850 from hpcaitech/rlhf_SimPO
YeAnbang Jul 11, 2024
e7a8634
fix eval
YeAnbang Jul 11, 2024
115c4cc
hotfix citation
YeAnbang Jul 11, 2024
c068ef0
[zero] support all-gather overlap (#5898)
ver217 Jul 11, 2024
b3594d4
fix orpo cross entropy loss
YeAnbang Jul 15, 2024
45c49dd
[Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
stephankoe Jul 15, 2024
1c961b2
[ShardFormer] fix qwen2 sp (#5903)
GuangyaoZhang Jul 15, 2024
d8bf7e0
Merge pull request #5901 from hpcaitech/colossalchat
YeAnbang Jul 16, 2024
2e28c79
[compatibility] support torch 2.2 (#5875)
GuangyaoZhang Jul 4, 2024
530283d
fix object_to_tensor usage when torch>=2.3.0 (#5820)
kurisusnowdeng Jul 4, 2024
27a72f0
[misc] support torch2.3 (#5893)
ver217 Jul 11, 2024
73494de
[release] update version (#5912)
ver217 Jul 17, 2024
e861279
[plugin] support all-gather overlap for hybrid parallel (#5919)
ver217 Jul 18, 2024
09d5ffc
add kto
YeAnbang Jul 18, 2024
845ea72
Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into kto
YeAnbang Jul 18, 2024
544b7a3
fix style, add kto data sample
YeAnbang Jul 18, 2024
8cc8f64
[Examples] Add lazy init to OPT and GPT examples (#5924)
Edenzzzz Jul 19, 2024
f585d4e
[ColossalChat] Hotfix for ColossalChat (#5910)
TongLi3701 Jul 19, 2024
d08c99b
Merge branch 'main' into kto
TongLi3701 Jul 19, 2024
d49550f
refactor tokenization
YeAnbang Jul 19, 2024
150505c
Merge branch 'kto' of https://github.com/hpcaitech/ColossalAI into kto
YeAnbang Jul 19, 2024
4ec17a7
[FIX BUG] UnboundLocalError: cannot access local variable 'default_co…
zhurunhua Jul 21, 2024
c5f582f
fix test data
YeAnbang Jul 22, 2024
12fe8b5
refactor evaluation
YeAnbang Jul 22, 2024
b0e15d5
remove real data path
YeAnbang Jul 22, 2024
9688e19
remove real data path
YeAnbang Jul 22, 2024
a521ffc
Add n_fused as an input from native_module (#5894)
insujang Jul 23, 2024
5fb958c
[FIX BUG] convert env param to int in (#5934)
flymin Jul 24, 2024
2069472
[Hotfix] Fix ZeRO typo #5936
Edenzzzz Jul 25, 2024
ad35a98
[Feature] Add a switch to control whether the model checkpoint needs …
zhurunhua Jul 26, 2024
8a3ff4f
fix style
YeAnbang Jul 26, 2024
de1bf08
fix style
YeAnbang Jul 26, 2024
6fd9e86
fix style
YeAnbang Jul 29, 2024
c8332b9
Merge pull request #5922 from hpcaitech/kto
YeAnbang Jul 29, 2024
9664b1b
[shardformer] hotfix attn mask (#5945)
ver217 Jul 29, 2024
7b38964
[shardformer] hotfix attn mask (#5947)
ver217 Jul 29, 2024
bcf0181
[Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
LRY89757 Jul 30, 2024
0608921
[zero] hotfix update master params (#5951)
ver217 Jul 30, 2024
09c5f72
[release] update version (#5952)
ver217 Jul 31, 2024
30f4e31
[Chat] Fix lora (#5946)
YeAnbang Jul 31, 2024
66fbf2e
Update README.md (#5958)
YeAnbang Jul 31, 2024
1aeb5e8
[hotfix] Remove unused plan section (#5957)
TongLi3701 Jul 31, 2024
f9b6fcf
[test] add mixtral for sequence classification
botbw Jul 2, 2024
0b76b57
[test] add mixtral transformer test
botbw Jul 2, 2024
8ae8525
[moe] fix plugin
botbw Jul 2, 2024
a249e71
[test] mixtra pp shard test
botbw Jul 4, 2024
0fad23c
[chore] handle non member group
botbw Jul 5, 2024
46c069b
[zero] solve hang
botbw Jul 5, 2024
37443cc
[test] pass mixtral shardformer test
botbw Jul 8, 2024
b5bfeb2
[moe] implement transit between non moe tp and ep
botbw Jul 8, 2024
13b48ac
[zero] solve hang
botbw Jul 9, 2024
fe24789
[misc] solve booster hang by rename the variable
Hz188 Jul 9, 2024
5ed5e8c
solve hang when parallel mode = pp + dp
Hz188 Jul 11, 2024
e28e053
[moe] implement submesh initialization
botbw Jul 11, 2024
9b9b76b
[moe] add mixtral dp grad scaling when not all experts are activated
botbw Jul 12, 2024
014faf6
[chore] manually revert unintended commit
botbw Jul 12, 2024
8dbb868
[chore] trivial fix
botbw Jul 12, 2024
102b784
[chore] arg pass & remove drop token
botbw Jul 12, 2024
0b5bbe9
[test] add mixtral modelling test
botbw Jul 15, 2024
dc583aa
[moe] implement tp
botbw Jul 16, 2024
74eccac
[moe] test deepseek
botbw Jul 16, 2024
3e2b613
[moe] clean legacy code
botbw Jul 16, 2024
404b16f
[Feature] MoE Ulysses Support (#5918)
Hz188 Jul 18, 2024
09d6280
[chore] minor fix
botbw Jul 18, 2024
877d94b
[moe] init moe plugin comm setting with sp
botbw Jul 18, 2024
2cddeac
moe sp + ep bug fix
Hz188 Jul 18, 2024
7077d38
[moe] finalize test (no pp)
botbw Jul 18, 2024
803878b
[moe] full test for deepseek and mixtral (pp + sp to fix)
botbw Jul 19, 2024
46037c2
[chore] minor fix after rebase
botbw Jul 19, 2024
52d346f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2024
70c9924
[chore] solve moe ckpt test failure and some other arg pass failure
botbw Jul 22, 2024
74b03de
[moe] remove ops
botbw Jul 22, 2024
067e18f
[test] fix test: test_zero1_2
botbw Jul 22, 2024
96d0fbc
[bug] fix: somehow logger hangs the program
botbw Jul 23, 2024
b2952a5
[moe] deepseek moe sp support
Hz188 Jul 23, 2024
6c39f0b
[test] add check
botbw Jul 23, 2024
c3dc9b4
[deepseek] replace attn (a workaround for bug in transformers)
botbw Jul 23, 2024
59bcf56
[misc] skip redunant test
Hz188 Jul 24, 2024
034020b
[misc] remove debug/print code
Hz188 Jul 24, 2024
cb01c0d
[moe] refactor mesh assignment
botbw Jul 25, 2024
5b4c123
Revert "[moe] implement submesh initialization"
botbw Jul 25, 2024
606b089
[chore] change moe_pg_mesh to private
botbw Jul 25, 2024
12d043c
[misc] remove incompatible test config
Hz188 Jul 25, 2024
70793ce
[misc] fix ci failure: change default value to false in moe plugin
Hz188 Jul 25, 2024
7e737df
[misc] remove useless condition
Hz188 Jul 25, 2024
f7c5485
[chore] docstring
botbw Jul 25, 2024
7bedd03
[moe] remove force_overlap_comm flag and add warning instead
botbw Jul 25, 2024
65daa87
[doc] add MoeHybridParallelPlugin docstring
botbw Jul 26, 2024
d1d1ab8
[moe] solve dp axis issue
botbw Jul 26, 2024
62cdac6
[chore] remove redundant test case, print string & reduce test tokens
botbw Jul 30, 2024
19d1510
[feat] Dist Loader for Eval (#5950)
TongLi3701 Aug 2, 2024
75c9636
[lora] lora support hybrid parallel plugin (#5956)
wangbluo Aug 2, 2024
82c8475
fp8 operators for compressed communication
BurkeHulk Jul 1, 2024
99a9bf3
fix scaling algorithm in FP8 casting
BurkeHulk Jul 12, 2024
7052579
support fp8 communication in pipeline parallelism
BurkeHulk Jul 12, 2024
c30a24e
add fp8_communication flag in the script
BurkeHulk Jul 12, 2024
3b1c861
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 4, 2024
fb9486c
fix typo
GuangyaoZhang Jul 10, 2024
e60fcdd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2024
778513e
shardformer fp8
GuangyaoZhang Jul 8, 2024
259e696
fix rebase
GuangyaoZhang Jul 17, 2024
f7c7273
remove all to all
GuangyaoZhang Jul 17, 2024
afe4200
fix shardformer fp8 communication training degradation
GuangyaoZhang Jul 18, 2024
1e7293f
[fp8] support all-gather flat tensor (#5932)
ver217 Jul 24, 2024
047feb9
Merge branch 'feature/fp8_comm' into feature/fp8_comm
flybird11111 Aug 2, 2024
811d5af
Merge branch 'feature/fp8_comm' into feature/fp8_comm
flybird11111 Aug 5, 2024
0e6e488
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2024
7e0c777
fix
flybird11111 Aug 5, 2024
6f29436
Update low_level_optim.py
flybird11111 Aug 6, 2024
2 changes: 2 additions & 0 deletions .compatibility
@@ -1 +1,3 @@
2.1.0-12.1.0
2.2.2-12.1.0
2.3.0-12.1.0
32 changes: 9 additions & 23 deletions .github/workflows/compatiblity_test_on_dispatch.yml
@@ -55,41 +55,27 @@ jobs:
steps:
- name: Install dependencies
run: |
pip install -U pip setuptools==68.2.2 wheel --user
- uses: actions/checkout@v2
with:
repository: hpcaitech/TensorNVMe
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
path: TensorNVMe
- name: Install tensornvme
run: |
cd TensorNVMe
apt update && apt install -y cmake
pip install -r requirements.txt
DISABLE_URING=1 pip install -v .
pip install -U pip setuptools==68.2.2 wheel --user

- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
- name: Download cub for CUDA 10.2
run: |
CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')

# check if it is CUDA 10.2
# download cub
if [ "$CUDA_VERSION" = "10.2" ]; then
wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
unzip 1.8.0.zip
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
fi
- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install -v .
pip install -r requirements/requirements-test.txt
pip install --no-cache-dir -r requirements/requirements-test.txt

- name: Install tensornvme
run: |
DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git

- name: Unit Testing
run: |
PYTHONPATH=$PWD pytest --durations=0 tests
env:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LD_LIBRARY_PATH: /github/home/.tensornvme/lib
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors
33 changes: 9 additions & 24 deletions .github/workflows/compatiblity_test_on_pr.yml
@@ -49,42 +49,27 @@ jobs:
steps:
- name: Install dependencies
run: |
pip install -U pip setuptools==68.2.2 wheel --user
- uses: actions/checkout@v2
with:
repository: hpcaitech/TensorNVMe
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
path: TensorNVMe
- name: Install tensornvme
run: |
cd TensorNVMe
apt update && apt install -y cmake
pip install -r requirements.txt
DISABLE_URING=1 pip install -v .
pip install -U pip setuptools==68.2.2 wheel --user

- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
- name: Download cub for CUDA 10.2
run: |
CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')

# check if it is CUDA 10.2
# download cub
if [ "$CUDA_VERSION" = "10.2" ]; then
wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
unzip 1.8.0.zip
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
fi

- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install -v .
pip install -r requirements/requirements-test.txt
pip install --no-cache-dir -r requirements/requirements-test.txt

- name: Install tensornvme
run: |
DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git

- name: Unit Testing
run: |
PYTHONPATH=$PWD pytest --durations=0 tests
env:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LD_LIBRARY_PATH: /github/home/.tensornvme/lib
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors
33 changes: 7 additions & 26 deletions .github/workflows/compatiblity_test_on_schedule.yml
@@ -43,47 +43,28 @@ jobs:
steps:
- name: Install dependencies
run: |
apt update && apt install -y cmake
pip install -U pip setuptools==68.2.2 wheel --user

- uses: actions/checkout@v2
with:
repository: hpcaitech/TensorNVMe
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
path: TensorNVMe

- name: Install tensornvme
run: |
cd TensorNVMe
apt update && apt install -y cmake
pip install -r requirements.txt
DISABLE_URING=1 pip install -v .
- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}

- name: Download cub for CUDA 10.2
run: |
CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')

# check if it is CUDA 10.2
# download cub
if [ "$CUDA_VERSION" = "10.2" ]; then
wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
unzip 1.8.0.zip
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
fi

- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install -v .
pip install -r requirements/requirements-test.txt
pip install --no-cache-dir -r requirements/requirements-test.txt

- name: Install tensornvme
run: |
DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git

- name: Unit Testing
run: |
PYTHONPATH=$PWD pytest --durations=0 tests
env:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LD_LIBRARY_PATH: /github/home/.tensornvme/lib
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors

2 changes: 2 additions & 0 deletions .github/workflows/run_chatgpt_examples.yml
@@ -52,6 +52,7 @@ jobs:
mkdir sft_data
mkdir prompt_data
mkdir preference_data
mkdir kto_data
./tests/test_data_preparation.sh
./tests/test_train.sh
env:
@@ -61,3 +62,4 @@
SFT_DATASET: ./sft_data
PROMPT_DATASET: ./prompt_data
PREFERENCE_DATASET: ./preference_data
KTO_DATASET: ./kto_data
4 changes: 3 additions & 1 deletion applications/Colossal-LLaMA/prepare_sft_dataset.py
@@ -10,7 +10,7 @@
import os
from multiprocessing import cpu_count

from colossal_llama.dataset.conversation import LLaMA2_Conv
from colossal_llama.dataset.conversation import LLaMA2_Conv, LLaMA3_Conv
from colossal_llama.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft
from datasets import dataset_dict, load_dataset
from transformers import AddedToken, AutoTokenizer
@@ -75,6 +75,8 @@ def main():
# Prepare to the tokenizer.
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)

default_conversation = LLaMA3_Conv

# Fix </s> split issue: https://github.com/huggingface/transformers/issues/23833
if args.llama_version == 2:
tokenizer.add_tokens(AddedToken("</s>", normalized=False, special=True), special_tokens=True)
18 changes: 15 additions & 3 deletions applications/Colossal-LLaMA/train.py
@@ -128,6 +128,12 @@ def main() -> None:
parser.add_argument("--zero", type=int, default=1)
parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos")
parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length")
parser.add_argument(
"--skip_save_each_epoch",
action="store_true",
default=False,
help="skip saving the model checkpoint after each epoch is completed.",
)
args = parser.parse_args()

with open(args.config_file, "w") as f:
@@ -370,11 +376,17 @@ def main() -> None:
)
total_loss.fill_(0.0)
pbar.update()

# Save modeling.

if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
step + 1
) == len(dataloader):
save_model_condition = (
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
)

if not args.skip_save_each_epoch:
save_model_condition = save_model_condition or (step + 1) == len(dataloader)

if save_model_condition:
coordinator.print_on_master("\nStart saving model checkpoint with running states")

if args.use_neft:
3 changes: 3 additions & 0 deletions applications/ColossalChat/.gitignore
@@ -146,6 +146,9 @@ docs/.build
examples/wandb/
examples/logs/
examples/output/
examples/training_scripts/logs
examples/training_scripts/wandb
examples/training_scripts/output

examples/awesome-chatgpt-prompts/
temp/
74 changes: 48 additions & 26 deletions applications/ColossalChat/README.md
@@ -23,6 +23,10 @@
- [Open QA](#open-qa)
- [Limitation for LLaMA-finetuned models](#limitation)
- [Limitation of dataset](#limitation)
- [Alternative Option For RLHF: DPO](#alternative-option-for-rlhf-direct-preference-optimization)
- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
- [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
- [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
- [FAQ](#faq)
- [How to save/load checkpoint](#faq)
- [How to train with limited resources](#faq)
@@ -135,17 +139,15 @@ The first step in Stage 1 is to collect a dataset of human demonstrations of the
{"messages":
[
{
"from": "human",
"from": "user",
"content": "what are some pranks with a pen i can do?"
},
{
"from": "assistant",
"content": "Are you looking for practical joke ideas?"
},
...
]
},
...
]
```

@@ -171,23 +173,20 @@ Below shows the preference dataset format used in training the reward model.
"from": "human",
"content": "Introduce butterflies species in Oregon."
}
]
],
"chosen": [
{
"from": "assistant",
"content": "About 150 species of butterflies live in Oregon, with about 100 species are moths..."
},
...
],
"rejected": [
{
"from": "assistant",
"content": "Are you interested in just the common butterflies? There are a few common ones which will be easy to find..."
},
...
]
},
...
]
```

@@ -216,7 +215,6 @@ PPO uses two kind of training data--- the prompt data and the sft data (optional
"from": "human",
"content": "what are some pranks with a pen i can do?"
}
...
]
},
]
@@ -262,9 +260,8 @@ experience buffer size
= train_batch_size * accumulation_steps * num_tp_group
```
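As a sanity check, the relationship above can be expressed as a tiny helper (a sketch only; the argument names mirror the formula, not variables in the actual training script):

```python
def experience_buffer_size(train_batch_size: int,
                           accumulation_steps: int,
                           num_tp_group: int) -> int:
    # number of experiences collected before each PPO update, per the formula above
    return train_batch_size * accumulation_steps * num_tp_group

print(experience_buffer_size(4, 2, 2))  # 16
```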

## Alternative Option For RLHF: Direct Preference Optimization

For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. DPO, as detailed in the paper (available at [https://arxiv.org/abs/2305.18290](https://arxiv.org/abs/2305.18290)), DPO offers an low-cost way to perform RLHF and usually request less computation resources compares to PPO.
## Alternative Option For RLHF: Direct Preference Optimization (DPO)
For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. As detailed in this [paper](https://arxiv.org/abs/2305.18290), DPO offers a low-cost way to perform RLHF and usually requires fewer compute resources than PPO. Read this [README](./examples/README.md) for more information.

### DPO Training Stage1 - Supervised Instructs Tuning

@@ -277,6 +274,15 @@ For DPO training, you only need the preference dataset. Please follow the instru
#### Step 2: Training
You can run [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. More details can be found in the [example guideline](./examples/README.md).
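For intuition, the DPO objective for a single preference pair can be sketched in a few lines of plain Python. This is an illustrative sketch of the published formula, not the ColossalChat implementation, which operates on batched per-token log-probabilities; the default `beta` here is only a common choice:

```python
import math

def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float,
             beta: float = 0.1) -> float:
    """DPO loss for one preference pair (Rafailov et al., arXiv:2305.18290).
    Each argument is the summed log-probability of a full response."""
    # implicit rewards: beta-scaled log-ratio between policy and frozen reference
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    # -log sigmoid(margin): shrinks as the chosen response becomes more likely
    margin = chosen_reward - rejected_reward
    return -math.log(1.0 / (1.0 + math.exp(-margin)))
```

When policy and reference agree, the margin is zero and the loss is log 2; raising the policy's probability of the chosen response lowers the loss.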

## Alternative Option For RLHF: Simple Preference Optimization (SimPO)
Simple Preference Optimization (SimPO), from this [paper](https://arxiv.org/pdf/2405.14734), is similar to DPO but abandons the reference model, which makes training more efficient. It adds a reward-shaping term called the target reward margin to enhance training stability, and uses length normalization to better align with the inference process. Read this [README](./examples/README.md) for more information.
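A minimal sketch of the SimPO objective for one pair, following the paper's formula rather than the repository code (the `beta` and `gamma` defaults below are illustrative values, not the script defaults):

```python
import math

def simpo_loss(chosen_logp: float, rejected_logp: float,
               chosen_len: int, rejected_len: int,
               beta: float = 2.0, gamma: float = 0.5) -> float:
    """SimPO loss for one pair (Meng et al., arXiv:2405.14734): implicit
    rewards are length-normalized log-probs, no reference model is used,
    and gamma is the target reward margin."""
    chosen_reward = beta * chosen_logp / chosen_len
    rejected_reward = beta * rejected_logp / rejected_len
    margin = chosen_reward - rejected_reward - gamma
    return -math.log(1.0 / (1.0 + math.exp(-margin)))  # -log sigmoid
```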

## Alternative Option For RLHF: Odds Ratio Preference Optimization (ORPO)
Odds Ratio Preference Optimization (ORPO), from this [paper](https://arxiv.org/pdf/2403.07691), is a reference-model-free alignment method that combines the SFT loss with a reinforcement learning loss computed from an odds-ratio-based implicit reward, making training more efficient and stable. Read this [README](./examples/README.md) for more information.
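The mixture of SFT loss and odds-ratio loss can be sketched as below. This follows the paper's formulation, not the repository code; inputs are average per-token log-probabilities (strictly negative, so probabilities stay below 1), and `lam` is an illustrative weighting:

```python
import math

def orpo_loss(chosen_avg_logp: float, rejected_avg_logp: float,
              lam: float = 0.1) -> float:
    """ORPO loss for one pair (Hong et al., arXiv:2403.07691): SFT NLL on
    the chosen response plus a weighted log-odds-ratio penalty; no
    reference model. Inputs must be < 0."""
    def log_odds(avg_logp: float) -> float:
        p = math.exp(avg_logp)          # average per-token probability
        return math.log(p) - math.log(1.0 - p)
    ratio = log_odds(chosen_avg_logp) - log_odds(rejected_avg_logp)
    l_or = -math.log(1.0 / (1.0 + math.exp(-ratio)))  # -log sigmoid(ratio)
    l_sft = -chosen_avg_logp                          # NLL of chosen response
    return l_sft + lam * l_or
```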

## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
We support Kahneman-Tversky Optimization (KTO), the method introduced in the paper [KTO: Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306). It is an alignment method that directly maximizes the "human utility" of generation results. Read this [README](./examples/README.md) for more information.
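A per-example sketch of the KTO objective, paraphrasing the paper's formula (not the ColossalChat implementation): unpaired examples are labeled desirable or undesirable, and `z_ref` stands in for the paper's detached KL estimate between policy and reference.

```python
import math

def kto_loss(policy_logp: float, ref_logp: float, z_ref: float,
             desirable: bool, beta: float = 0.1,
             lambda_d: float = 1.0, lambda_u: float = 1.0) -> float:
    """KTO loss for one example (Ethayarajh et al., arXiv:2402.01306).
    Desirable examples are pushed above the KL reference point z_ref,
    undesirable ones below it."""
    sigmoid = lambda x: 1.0 / (1.0 + math.exp(-x))
    reward = policy_logp - ref_logp  # implicit reward: policy/reference log-ratio
    if desirable:
        return lambda_d * (1.0 - sigmoid(beta * (reward - z_ref)))
    return lambda_u * (1.0 - sigmoid(beta * (z_ref - reward)))
```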

### Inference Quantization and Serving - After Training

We provide an online inference server and a benchmark. We aim to run inference on single GPU, so quantization is essential when using large models.
@@ -441,20 +447,6 @@ If you only have a single 24G GPU. Generally, using lora and "zero2-cpu" will be
If you have multiple GPUs, each with very limited VRAM (say 8GB), you can try the `3d` plugin option, which supports tensor parallelism; set `--tp` to the number of GPUs that you have.
</details>

## The Plan

- [x] implement PPO fine-tuning
- [x] implement training reward model
- [x] support LoRA
- [x] support inference
- [x] support llama from [facebook](https://github.com/facebookresearch/llama)
- [x] implement PPO-ptx fine-tuning
- [x] support flash-attention
- [x] implement DPO fine-tuning
- [ ] integrate with Ray
- [ ] support more RL paradigms, like Implicit Language Q-Learning (ILQL),
- [ ] support chain-of-thought by [langchain](https://github.com/hwchase17/langchain)

### Real-time progress

You will find our progress on the GitHub [project board](https://github.com/orgs/hpcaitech/projects/17/views/1).
@@ -522,7 +514,7 @@ Coati is developed by ColossalAI Team:
- [Fazzie](https://fazzie-key.cool/about/index.html) Contributing to the algorithm and development for SFT.
- [ofey404](https://github.com/ofey404) Contributing to both front-end and back-end development.
- [Wenhao Chen](https://github.com/CWHer) Contributing to subsequent code enhancements and performance improvements.
- [Anbang Ye](https://github.com/YeAnbang) Contributing to the refactored version with updated acceleration framework, LoRA, DPO and PPO.
- [Anbang Ye](https://github.com/YeAnbang) Contributing to the refactored PPO version with the updated acceleration framework; added support for DPO, SimPO, and ORPO.

The PhD student from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project.
- [Zangwei Zheng](https://github.com/zhengzangw)
@@ -572,6 +564,36 @@ We also appreciate the valuable suggestions provided by [Jian Hu](https://github
journal = {GitHub repository},
howpublished = {\url{https://github.com/XueFuzhao/InstructionWild}},
}

@misc{meng2024simposimplepreferenceoptimization,
title={SimPO: Simple Preference Optimization with a Reference-Free Reward},
author={Yu Meng and Mengzhou Xia and Danqi Chen},
year={2024},
eprint={2405.14734},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2405.14734},
}

@misc{rafailov2023directpreferenceoptimizationlanguage,
title={Direct Preference Optimization: Your Language Model is Secretly a Reward Model},
author={Rafael Rafailov and Archit Sharma and Eric Mitchell and Stefano Ermon and Christopher D. Manning and Chelsea Finn},
year={2023},
eprint={2305.18290},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2305.18290},
}

@misc{hong2024orpomonolithicpreferenceoptimization,
title={ORPO: Monolithic Preference Optimization without Reference Model},
author={Jiwoo Hong and Noah Lee and James Thorne},
year={2024},
eprint={2403.07691},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2403.07691},
}
```

## Licenses