merge upstream changes and add support for torchbench (#9)
* Set `record_shapes=True` for profiler

ghstack-source-id: 6f1ed49d15ce311f1bf118820965cdb5309a8030
Pull Request resolved: pytorch#419

* Improved `repeat_kv` eager perf

ghstack-source-id: 39e484954814e61cdfb2ba661f0a98c83bc0ce60
Pull Request resolved: pytorch#418

* Adding FSDP Memory Tracking and Estimation

ghstack-source-id: c8ed20fc585957bd164dd963307616a53991615d
Pull Request resolved: pytorch#425

* Adding integration test for FSDP Memory Tracking and Estimation

ghstack-source-id: cc224db8951ec7a133fd769845a4765cbedc6454
Pull Request resolved: pytorch#426

* by default disable heavy memory profiling

ghstack-source-id: cad7b3c41fd60ec19c0e6e7d058e8aa00602a187
Pull Request resolved: pytorch#430

* Add the option to turn on async-TP

ghstack-source-id: 0a03379eeb3a63b2d1ad4dff84d0e61ca82b1bbf
Pull Request resolved: pytorch#429

* Modifying memory estimation options and minor changes

ghstack-source-id: 5f09824cddaed6585cc094095e1e95dd070d76f4
Pull Request resolved: pytorch#435

* add comment pointing to Sequence Parallel optimization example

ghstack-source-id: 6fa0dcd4bca876e10a6a8349283fb940a59ad234
Pull Request resolved: pytorch#438

* switch float8 logic from Float8DynamicLinear to Float8Linear (pytorch#436)

Summary:

After pytorch-labs/float8_experimental#300,
`Float8Linear` with default settings is equivalent to
`Float8DynamicLinear`. This PR changes `torchtitan` to use
`Float8Linear`.

To support the new UX of `float8_experimental` better, I also switched
the `fp8_linear` configuration to be a boolean on whether to swap the
linears or not. In the future we can add new options for how to configure
each linear (scaling type, scaling granularity, etc.), saving that for a
future PR.
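
As a rough sketch of what a boolean-gated swap can look like (the helper and the stub class below are illustrative assumptions, not the actual torchtitan or float8_experimental API):

```python
import torch.nn as nn

# Illustrative sketch only: gate the linear swap on a single boolean config flag.
# Float8LinearStub stands in for the real Float8Linear.
class Float8LinearStub(nn.Linear):
    pass

def maybe_swap_linears(module: nn.Module, enable_fp8_linear: bool) -> nn.Module:
    if not enable_fp8_linear:
        return module  # keep the model in bf16/fp32
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            fp8 = Float8LinearStub(
                child.in_features, child.out_features, bias=child.bias is not None
            )
            fp8.load_state_dict(child.state_dict())
            setattr(module, name, fp8)
        else:
            maybe_swap_linears(child, enable_fp8_linear)
    return module
```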

Test Plan:

```
// run baseline (Float8DynamicLinear) for llama3_8b for 50 iterations on 4 GPUs,
// verify performance and loss values do not change meaningfully between
// baseline and this PR

// baseline (before this PR)
// 1. compile, bf16
// 2. compile, float8
// 3. compile, float8, fsdp_fp8_allgather=True
// 4. compile, float8, fsdp_fp8_allgather=True, tp=2
// logs: https://gist.github.com/vkuzo/e6d5f3b15349862bfad3706baad8c9ce

// experiment (this PR): repeat all of the above, but with Float8Linear
// logs: https://gist.github.com/vkuzo/a4d6754358facffa64df931654459631
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Removed `_experimental_support_context_fn_in_torch_utils_checkpoint`

ghstack-source-id: 50b2d0c2b4c22e2f045cafd8630c16f3a8c6d35f
Pull Request resolved: pytorch#444

* Reordered TP parallel plan to follow execution order

ghstack-source-id: b4924952adeb5f16d08b60faa54690762841c422
Pull Request resolved: pytorch#445

* Made some stylistic changes to `apply_dp`

ghstack-source-id: fb78e9eb8aa406ba87d6ad6cf2229c1027dae42f
Pull Request resolved: pytorch#446

* Refactored activation checkpointing

ghstack-source-id: 785c7e47651cda97ea22d0147d14b8d061ce042d
Pull Request resolved: pytorch#447

* compiled RMSNorm

ghstack-source-id: c4efb81ec6acc5442955908cc376df3e6d889af3
Pull Request resolved: pytorch#442

* Renamed parallel styles for transformer block weights

ghstack-source-id: 5fb0bf3d08cacf27242ec0f85d5dd3cdc03b739e
Pull Request resolved: pytorch#448

* Added type annotations and more stylistic changes

ghstack-source-id: 1bd5b9d5abc8644785132f8eb2baaf8b1cfc5fb5
Pull Request resolved: pytorch#449

* [Cleanup] Remove libuv from run_llama_train.sh

libuv is now enabled by default.

We can probably do without the educational blurb there, and we don't need
the env var either since the default has landed.

ghstack-source-id: 68c8d2abe7eb0777e2add8df7634367c31b7ec06
Pull Request resolved: pytorch#453

* [Cleanup] Organize run_llama_train.sh options

Just a little code motion, but it looks cleaner to me this way.

ghstack-source-id: 055fbd557cd9cf189e6b9bd6a7048f1204e1dc5c
Pull Request resolved: pytorch#454

* [Cleanup] Split run_llama_train.sh and run_memory_estimation.sh

Make each script simpler to read

ghstack-source-id: ba3aa65feb6e304736c73daf5bc8ab5fb254f196
Pull Request resolved: pytorch#455

* [Cleanup] Remove unused TRAINER_DIR

This argument seems to be left over from older times; it is not used
anywhere in the codebase.

ghstack-source-id: abbcf82ed4d1b8fbb71c6a6b48acbc1296dbec64
Pull Request resolved: pytorch#456

* Add educational code pointers to top level README

ghstack-source-id: 522aa2fa0bf1679f55d9f3a8a38fdcd319d5e3df
Pull Request resolved: pytorch#457

* enable FSDP2 + fp8 all-gather and fix TP fp8 all-gather (pytorch#413)

We have landed fp8 all-gather optimizations in float8_experimental:
pytorch-labs/float8_experimental#266

This PR proposes the corresponding torchtitan changes and also includes fp8 in CI.
```
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
# inside the training loop
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)
```

FSDP2 fp8 all-gather is added to CI:
```
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp
```

TP fp8 all-gather is tested locally. We will add it to CI after
uploading a new tokenizer with vocab size 2560 (divisible by 16):
```
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 2 --training.tensor_parallel_degree 2
```

precompute scales after optimizer.step
<img width="319" alt="Screenshot 2024-07-12 at 5 11 14 PM"
src="https://github.com/user-attachments/assets/1c55bd89-9183-42ca-9445-23f3b95e0817">

FSDP2 pre-all-gather does not issue any small all-reduces
<img width="794" alt="Screenshot 2024-07-12 at 5 13 04 PM"
src="https://github.com/user-attachments/assets/1a00dc70-a8ca-4ce1-a93c-316f22efdb08">

TODO
* upload tokenizer with vocab size 2560 to enable CI on TP fp8
all-gather
* torch.compile complains about fp8
* add delayed scaling and brainstorm about best config option to express
fp8
* compare perf between delayed scaling and dynamic scaling
https://github.com/pytorch-labs/float8_experimental/pull/312/files

* import float8_experimental only when fp8 is enabled and install it in CI (pytorch#464)

Make sure to only import float8_experimental when fp8 is enabled.

For the 4-GPU CI, make sure we can import float8_experimental correctly in
CI:

`python -m pip install
git+https://github.com/pytorch-labs/float8_experimental.git`
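
A minimal sketch of the conditional-import pattern (the function and flag names are assumptions for illustration):

```python
# Illustrative sketch: import the float8 package lazily, only when float8
# training is enabled, so environments without it installed are unaffected.
def maybe_import_float8(enable_fp8_linear: bool):
    if not enable_fp8_linear:
        return None
    try:
        import float8_experimental
    except ImportError as exc:
        raise ImportError(
            "fp8 is enabled but float8_experimental is not installed; "
            "install it with: pip install "
            "git+https://github.com/pytorch-labs/float8_experimental.git"
        ) from exc
    return float8_experimental
```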

* skip fp8 CI on non-H100 GPUs (pytorch#465)

Skip fp8 tests on non-H100 GPUs by checking
`torch.cuda.get_device_capability() >= (9, 0)`.

This makes the 4-GPU CI healthy again.
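
A small sketch of such a guard (the helper and marker names are assumptions; the capability check is the one quoted above):

```python
import pytest
import torch

def has_sm90_or_later() -> bool:
    # Float8 matmuls need H100-class (SM 9.0+) GPUs.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

# Apply to fp8 tests so they are skipped on older GPUs instead of failing.
requires_h100 = pytest.mark.skipif(
    not has_sm90_or_later(), reason="float8 tests require SM 9.0+ (H100 or newer)"
)
```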

* clean up float8 configs in torchtitan (pytorch#466)

Summary:

1. standardizes on `float8` instead of `fp8` for config names
2. removes usage of non-public objects such as `Float8Linear`

Test Plan:

```
with-proxy NGPU=1 CUDA_VISIBLE_DEVICES=7 CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.compile --training.enable_float8_linear
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Add support of DDP and experimental CompiledAutograd

Summary:
Address the comments in pytorch#319 and resubmit the PR to fit the current code base.

Test Plan:
```
CONFIG_FILE=./train_configs/debug_model.toml ./run_llama_train.sh --comm.train_timeout_seconds=3600   --training.tensor_parallel_degree=1 --training.data_parallel_degree=8 --experimental.data_parallel_type=ddp --training.steps=1000 --metrics.log_freq=10 --profiling.profile_freq=1000
```

ghstack-source-id: 81dc85d42df13df4ed727bebd825681879af936b
Pull Request resolved: pytorch#432

* add torch.compile + FSDP2 float8 all-gather in CI (pytorch#468)

Fixed my bug in float8_experimental. Now we can torch.compile
transformer blocks with FSDP float8 all-gather:
pytorch-labs/float8_experimental#321

local test: `CONFIG_FILE="./train_configs/debug_model.toml"
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp --training.compile`

Profiler traces: I can see the compiled region in the CPU thread and the float8
matmul `sm90_xmma_gemm_e4m3bf16...` in the CUDA stream
<img width="1468" alt="Screenshot 2024-07-18 at 4 22 17 PM"
src="https://github.com/user-attachments/assets/0cf58dee-aae1-4582-a3f1-b8aa48b45129">

* [float8] keep model.output as `nn.Linear` (high precision, not fp8) (pytorch#469)

**Keep model.output as nn.Linear**: it's common practice to NOT apply
fp8 to the final output layer.
* specify `skip_fqn_list` in the swapping call
* when applying TP to model.output, use plain `ColwiseParallel` instead
of `Float8ColwiseParallel` (see the sketch after this list)

Credit to @awgu: we do not need the tokenizer vocab size to be divisible by
16 (pytorch#461).
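
For illustration, the TP-plan fragment for the output layer could look like the following (simplified; the `ColwiseParallel` import is the real torch API, the plan dict itself is just a sketch):

```python
from torch.distributed.tensor.parallel import ColwiseParallel

# Sketch: parallelize model.output with the plain ColwiseParallel style so the
# final projection stays in high precision; other linears may use the
# float8-aware parallel styles.
output_tp_plan = {
    "output": ColwiseParallel(),  # plain, not Float8ColwiseParallel
}
```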

1D TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4`

1D TP + float8 all-gather, compile mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.data_parallel_degree 1 --training.tensor_parallel_degree 4
--training.compile`

2D FSDP2 + TP + float8 all-gather, eager mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2`

2D FSDP2 + TP + float8 all-gather, compile mode:
`CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4
./run_llama_train.sh --training.enable_float8_linear
--training.enable_fsdp_float8_all_gather
--training.precompute_float8_dynamic_scale_for_fsdp
--training.tensor_parallel_degree 2 --training.compile`

1D TP + float8 all-gather trace: see float8 and all-gather in the trace
<img width="1611" alt="Screenshot 2024-07-19 at 1 16 59 PM"
src="https://github.com/user-attachments/assets/9a95dfd9-40e0-4133-b2bb-e22ddf5b8472">

2D + float8 all-gather trace: see float8 and FSDP collectives and TP
collectives
<img width="1038" alt="Screenshot 2024-07-19 at 1 29 59 PM"
src="https://github.com/user-attachments/assets/6a34bcaa-bcae-402b-9994-cc892554fec7">

* remove CI for FSDP2 + fp8 all-gather (pytorch#470)

Per the discussion in
pytorch#469 (comment),

we are planning BC-breaking changes in float8_experimental. Remove the CI
for FSDP2 + fp8 all-gather for now; when the public APIs are finalized, we
can discuss bringing it back.

* dynamically update torch.compile cache config to ensure async tp support, enhance async tp UX (pytorch#471)

This PR adds some enhancements for supporting async tp:

1 - if async TP is active, auto-update the torch.dynamo cache limit to
10K. If this is not updated, async TP will not be activated on larger
models, as compilation quietly stops due to 'cache limit reached' with no
info for the user. This config update is logged.

2 - if async TP is enabled, verify that torch.compile is set to true
for this job config. If not, warn and then activate torch.compile to
ensure the user gets working async TP (see the WARNING in the screenshot
below; a minimal sketch of items 1 and 2 follows this list).

<img width="1345" alt="Screenshot 2024-07-20 at 4 33 04 PM"
src="https://github.com/user-attachments/assets/26e5a48e-4bb8-4f33-b1b5-8939c1517c1d">

3 - Update the 'Applied Tensor Parallel' log message to 'Applied
Async Tensor Parallel' when async TP is active, to make it clear in the
logs which TP variant is in use (see the screenshot above).
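
A minimal sketch of items 1 and 2 (the job-config fields and logger setup are assumptions; `torch._dynamo.config.cache_size_limit` is the real knob):

```python
import logging
import torch

logger = logging.getLogger(__name__)

def apply_async_tp_prerequisites(job_config) -> None:
    # 1) raise the dynamo cache limit so async TP is not silently disabled on
    #    larger models due to "cache limit reached"
    torch._dynamo.config.cache_size_limit = 10_000
    logger.info("Set torch._dynamo.config.cache_size_limit = 10000 for async TP")

    # 2) async TP requires torch.compile; warn and enable it if it was off
    if not job_config.training.compile:
        logger.warning("Async TP requires torch.compile; enabling it for this job")
        job_config.training.compile = True
```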

* Fix 8gpu PP failure due to 2D DCP disablement

DCP recently added safeties to avoid using it for 2D/3D since strided
sharding (a feature needed for safe 2D/3D resharding) is not ready yet.

PP uses DCP to load a seed checkpoint. Disabling the safety mechanism
is enough to make 3D/PP still work (for the case where we train from the
beginning or do not re-shard).

(Resharding refers to saving a checkpoint from one world
size/parallelism config and loading/resuming under a different one).

ghstack-source-id: c069d2186c79517c72f5b3c99485cebdc15df08f
Pull Request resolved: pytorch#460

* update float8 integration after UX changes (pytorch#484)

Summary:

float8_experimental landed various BC-breaking UX changes last week.
This PR updates torchtitan to work with the version of
float8_experimental after
pytorch-labs/float8_experimental#332 and
pytorch-labs/float8_experimental#337

Test Plan:

```
with-proxy CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NGPU=8 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Re-enable FSDP2 Mem Tracker integration tests

ghstack-source-id: 8344603f7a5596cb2909c9bf04dd1b9e4730c9b8
Pull Request resolved: pytorch#485

* Used `partial` instead of global vars for LR scheduling

ghstack-source-id: 12c4418b0574d93e1441f4ca3d1de79c8aad7a40
Pull Request resolved: pytorch#487

* [EZ] Add logs for some basic training params so that we can verify in… (pytorch#491)

As titled: while testing the 405B model, I found we needed logs for some
basic training params, so I added some here. Tested locally; the logging
is shown in the screenshot:


<img width="900" alt="image"
src="https://github.com/user-attachments/assets/b94e34f5-3e88-4c5f-94ed-75f50dde9786">

* make float8 scaling type configurable (pytorch#489)

Summary:

Adds config options to configure float8 scaling type for input, weight,
grad_output.

Performance is not ideal yet, but that's because we have not optimized
it.

Test Plan:

```
// repeat for input, weight, grad_out
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.float8_scaling_type_weight delayed --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* [PP] add flexible interleaved 1f1b schedule pytorch#490 (pytorch#493)

This was approved in pytorch#490 but
merged into the wrong branch; merging it into main.

* move float8 callsites to torchao.float8 (pytorch#492)

Summary:

The `float8_experimental` repository moved to `torchao.float8` in
pytorch/ao#551

This PR updates `torchtitan` to use float8 from the new location.

Test Plan:

```
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

* [BE][1/n] simplify train.py

ghstack-source-id: 3879e764e7b33afde5d778810c71d1d2a8f82f6d
Pull Request resolved: pytorch#494

* [BE][2/n] use proper method signatures in parallelize_llama

ghstack-source-id: 17a1ee9f03f13423a30183c5c8d7ad30f8c8dbfc
Pull Request resolved: pytorch#495

* [BE][3/n] wrap fp8 logic using Float8Handler

ghstack-source-id: e94c7f6f4fad87c5432262c54beabd02de5541b8
Pull Request resolved: pytorch#496

* Bring LLaMa 3.1 405B to TorchTitan family (pytorch#481)

With the official launch of the LLaMa 3.1 model, we want to add its config
to TorchTitan. Of course, there is more work to be done, but we want to
proceed incrementally, so more PRs will follow.

For now, we tried it on 128 GPUs with the current config (TP=8, FSDP=16). The
perf numbers are wps: 109, mfu: 29%.

Loss curve for 3000 steps with 600 warmup (lr = 0.8e-4).
<img width="1037" alt="image"
src="https://github.com/user-attachments/assets/f57dd3fa-07d8-4ef4-8f68-8f7a08e9652e">


Loss curve for 3000 steps with 600 warmup (lr = 1.1e-4).

![image](https://github.com/user-attachments/assets/429b9738-94cb-4b37-90ef-049a5587ddd0)

* [TP] Infer local n_heads instead of ad-hoc model changes

ghstack-source-id: 587e3d6e5270714ca734b8031ce41a962e6394ea
Pull Request resolved: pytorch#498

* some compile-related updates

ghstack-source-id: 63af8025c184fd5ad34f2f57bf78a37dda2cd33d
Pull Request resolved: pytorch#443

* [EZ][405B] Use scientific notation for 405B model lr (pytorch#504)

As title, use `8e-5` rather than `0.8e-4`.

* [BE][4/n] split pipeline_llama into a separate file

ghstack-source-id: 5ebb4adf3152f413fa33a923c272c9aa3ce1f775
Pull Request resolved: pytorch#499

* [fix] float8 should be applied on all model_parts

ghstack-source-id: 52ed6836de39e82c4c5824a40ecfc1d9ec7ed2bd
Pull Request resolved: pytorch#500

* Add warning to compile rmsnorm (pytorch#505)

As titled, add a warning about compiling RMSNorm, as it's not fully ready yet;
see issue pytorch#497.

We can remove this warning once we fix the issue.

* add float8 to README (pytorch#509)

Add a float8 link to the README so we can redirect people from the dev-discuss
post to the torchtitan repo.


README looks like this after rendering
<img width="518" alt="Screenshot 2024-08-06 at 5 42 10 PM"
src="https://github.com/user-attachments/assets/50af99d7-93be-459a-89d7-8c08b8fb95d4">

float8.md looks like this
<img width="563" alt="Screenshot 2024-08-06 at 5 04 17 PM"
src="https://github.com/user-attachments/assets/06d30aad-4133-4cec-9037-cfcf155b45c4">

I tried the command locally and traces are looking good
<img width="726" alt="Screenshot 2024-08-06 at 5 00 00 PM"
src="https://github.com/user-attachments/assets/bdfa3d7e-efe1-4009-92a1-0f5c310013fb">

* address TODOs as 2D recompiles is fixed

ghstack-source-id: 2927f0a8082171da3e9f59a5d04f8325cbdf3653
Pull Request resolved: pytorch#508

* [BE][5/n] simplify pp vs. non-pp setup

ghstack-source-id: 003bfbfbcf1511ddbd18e15d031b39f597d8e7db
Pull Request resolved: pytorch#510

* [BE][6/n] replace the large c4_mini dataset with c4_test containing the first 2K entries

ghstack-source-id: 319f4961b092778703101b98937803073132afa1
Pull Request resolved: pytorch#512

* Create composability.md (pytorch#511)

Explain the rationale and challenges behind certain changes we made to
the llama model to support 3D parallelism.

---------

Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>

* depend on torchdata 0.8.0 instead of nightly

ghstack-source-id: 1965d3122885fed3c28e2e058c55581187e7816c
Pull Request resolved: pytorch#513

* add support for torchbench

---------

Co-authored-by: Andrew Gu <andgu@fb.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@meta.com>
Co-authored-by: Yifu Wang <yifu@fb.com>
Co-authored-by: Vasiliy Kuznetsov <vkuzo@users.noreply.github.com>
Co-authored-by: Will Constable <whc@meta.com>
Co-authored-by: Wei (Will) Feng <134637289+weifengpy@users.noreply.github.com>
Co-authored-by: Chien-Chin Huang <chienchin@fb.com>
Co-authored-by: Less Wright <lessw@etrillium.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@fb.com>
Co-authored-by: Hugo <6937752+fduwjj@users.noreply.github.com>
Co-authored-by: Howard Huang <howardhuang96@gmail.com>
Co-authored-by: Ke Wen <kw2501@meta.com>
Co-authored-by: Wanchao <wanchaol@users.noreply.github.com>
Co-authored-by: Will Constable <willconstable@gmail.com>
15 people committed Aug 13, 2024
1 parent d86885f commit 42589ae
Showing 50 changed files with 3,672 additions and 945 deletions.
1 change: 1 addition & 0 deletions .ci/docker/requirements.txt
@@ -1,4 +1,5 @@
torch >= 2.3.0
torchdata >= 0.8.0
datasets >= 2.19.0
tomli >= 1.1.0 ; python_version < "3.11"
tensorboard
2 changes: 1 addition & 1 deletion .github/workflows/integration_test_4gpu.yaml
@@ -38,6 +38,6 @@ jobs:
pip config --user set global.progress_bar off
python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
mkdir artifacts-to-be-uploaded
python ./test_runner.py artifacts-to-be-uploaded --ngpu 4
1 change: 0 additions & 1 deletion .github/workflows/integration_test_8gpu.yaml
@@ -37,6 +37,5 @@ jobs:
pip config --user set global.progress_bar off
python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
mkdir artifacts-to-be-uploaded
python ./test_runner.py artifacts-to-be-uploaded --ngpu 8
1 change: 0 additions & 1 deletion .github/workflows/unit_test_cpu.yaml
@@ -25,5 +25,4 @@ jobs:
pip config --user set global.progress_bar off
pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly
pytest test --cov=. --cov-report=xml --durations=20 -vv
23 changes: 16 additions & 7 deletions README.md
@@ -18,6 +18,16 @@ Our guiding principles when building `torchtitan`:

[![Welcome to torchtitan!](assets/images/titan_play_video.png)](https://youtu.be/ee5DOEqD35I?si=_B94PbVv0V5ZnNKE "Welcome to torchtitan!")

### Dive into the code

You may want to see how the model is defined or how parallelism techniques are applied. For a guided tour, see these files first:
* [train.py](https://github.com/pytorch/torchtitan/blob/main/train.py) - the main training loop and high-level setup code
* [torchtitan/parallelisms/parallelize_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py) - helpers for applying Data Parallel, Tensor Parallel, activation checkpointing, and `torch.compile` to the model
* [torchtitan/parallelisms/pipeline_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/pipeline_llama.py) - helpers for applying Pipeline Parallel to the model
* [torchtitan/checkpoint.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/checkpoint.py) - utils for saving/loading distributed checkpoints
* [torchtitan/float8.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/float8.py) - utils for applying Float8 techniques
* [torchtitan/models/llama/model.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py) - the Llama model definition (shared for Llama2 and Llama3 variants)

## Pre-Release Updates:
#### (4/25/2024): `torchtitan` is now public but in a pre-release state and under development.
Currently we showcase pre-training **Llama 3 and Llama 2** LLMs of various sizes from scratch. `torchtitan` is tested and verified with the PyTorch nightly version `torch-2.4.0.dev20240412`. (We recommend latest PyTorch nightly).
@@ -33,18 +43,18 @@ Currently we showcase pre-training **Llama 3 and Llama 2** LLMs of various sizes
6. Learning rate scheduler, meta init, Optional Fused RMSNorm
7. All options easily configured via [toml files](train_configs/)
8. [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly into [`torchtune`](https://github.com/pytorch/torchtune) for fine tuning
9. [Float8 support](docs/float8.md)

We report our [Performance](docs/performance.md) verified on 64 A100 GPUs


### Coming soon

1. Async checkpointing
2. FP8 support
3. Context Parallel
4. 3D Pipeline Parallel
5. `torch.compile` support
6. Scalable data loading solution
2. Context Parallel
3. 3D Pipeline Parallel
4. `torch.compile` support
5. Scalable data loading solution


## Installation
@@ -54,7 +64,6 @@ git clone https://github.com/pytorch/torchtitan
cd torchtitan
pip install -r requirements.txt
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118
pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly
```

### Downloading a tokenizer
@@ -66,7 +75,7 @@ Once you have confirmed access, you can run the following command to download th
```bash
# Get your HF token from https://huggingface.co/settings/tokens

# llama3 tokenizer.model
# llama3 or 3.1 tokenizer.model
python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3-8B --tokenizer_path "original" --hf_token=...

# llama2 tokenizer.model
232 changes: 232 additions & 0 deletions benchmark.py
@@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import time
from datetime import timedelta

import torch
from torch.distributed.elastic.multiprocessing.errors import record

from torchbenchmark.util.experiment.instantiator import (
load_model,
TorchBenchModelConfig,
)
from torchbenchmark.util.experiment.metrics import get_model_flops
from torchbenchmark.util.input import input_cast

from torchtitan import utils
from torchtitan.checkpoint import TrainState
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import init_logger, logger
from torchtitan.metrics import build_gpu_memory_monitor
from torchtitan.parallelisms import ParallelDims
from torchtitan.parallelisms.parallelize_llama import torch_spmd_parallelize
from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling


# Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
@record
def main(job_config: JobConfig):
init_logger()
logger.info(f"Starting job: {job_config.job.description}")

# used for colorful printing
color = utils.Color if job_config.metrics.enable_color_printing else utils.NoColor

# take control of garbage collection to avoid stragglers
gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)

# init distributed
world_size = int(os.environ["WORLD_SIZE"])
parallel_dims = ParallelDims(
dp=job_config.training.data_parallel_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=job_config.training.enable_loss_parallel,
dp_type=job_config.training.data_parallel_type,
)
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
torch.cuda.set_device(device)
utils.init_distributed(job_config)
# initialize GPU memory monitor and get peak flops for MFU calculation
gpu_memory_monitor = build_gpu_memory_monitor()
gpu_peak_flops = utils.get_peak_flops(gpu_memory_monitor.device_name)

# build meshes
world_mesh = parallel_dims.build_mesh(device_type="cuda")
if parallel_dims.dp_enabled:
dp_mesh = world_mesh["dp"]
dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
else:
dp_degree, dp_rank = 1, 0

if parallel_dims.pp_enabled:
pp_mesh = world_mesh["pp"]

model_name = job_config.model.name

# initiate model from torchbench
config = TorchBenchModelConfig(
name=model_name,
test="train",
device="cuda",
batch_size=job_config.training.batch_size,
extra_args=[],
)
model_flops = get_model_flops(config)
benchmark_model = load_model(config)
model, _ = benchmark_model.get_module()

# TODO: there seems to be a bug with dtype conversion (e.g. use resnet50)
# cast input dtype if needed
param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param]
input_cond = lambda x: x.dtype == torch.float32
input_action = lambda x: x.to(param_dtype)
if hasattr(benchmark_model, "example_inputs"):
benchmark_model.example_inputs = input_cast(
input_cond, input_action, benchmark_model.example_inputs
)
else:
logger.warning(
f"{model_name} example inputs haven't been cast to {param_dtype} yet!"
)

# log model size
model_param_count = utils.get_num_params(model)
logger.info(
f"{color.blue}Model {model_name} "
f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
)

# apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
model = torch_spmd_parallelize(model, world_mesh, parallel_dims, job_config)

# update model and optimizer after applying parallelisms
benchmark_model.set_module(model)
optimizer = benchmark_model.get_optimizer()
optimizer.add_param_group({"params": model.parameters()})

model.train()

gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
logger.info(
f"GPU memory usage for model: "
f"{gpu_mem_stats.max_reserved_gib:.2f}GiB"
f"({gpu_mem_stats.max_reserved_pct:.2f}%)"
)

train_state = TrainState()

# variables used to keep info for metrics logging
losses_since_last_log = []
gpu_memory_monitor.reset_peak_stats()

# train loop
logger.info(
f"Training starts at step {train_state.step + 1}, "
f"with local batch size {job_config.training.batch_size}, "
f"global batch size {job_config.training.batch_size * dp_degree}, "
f"total steps {job_config.training.steps}"
)
with maybe_enable_profiling(
job_config, global_step=train_state.step
) as torch_profiler, maybe_enable_memory_snapshot(
job_config, global_step=train_state.step
) as memory_profiler:
while train_state.step < job_config.training.steps:
train_state.step += 1
gc_handler.run(train_state.step)

torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

# Use time_ns() instead of time(), which does not guarantee better than 1-second
# precision; see https://docs.python.org/3/library/time.html#time.time.
t0 = time.time_ns()
start_event.record()

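# some torchbench models expose staged forward/backward/optimizer_step hooks;
# use that staged path when available and the model has no built-in train()
# step, otherwise fall back to benchmark_model.train()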
is_staged = (
hasattr(benchmark_model, "forward")
and hasattr(benchmark_model, "backward")
and hasattr(benchmark_model, "optimizer_step")
)
if is_staged and (getattr(benchmark_model, "train", None) is None):
if optimizer is not None:
optimizer.zero_grad()
loss = benchmark_model.forward()
benchmark_model.backward(loss)
if optimizer is not None:
benchmark_model.optimizer_step()
else:
loss = benchmark_model.train()

end_event.record()
torch.cuda.synchronize()
t1 = time.time_ns()
time_delta = start_event.elapsed_time(end_event), (t1 - t0) / 1_000_000

# log metrics
losses_since_last_log.append(loss)
if (
train_state.step == 1
or train_state.step % job_config.metrics.log_freq == 0
):
losses = [
loss.item() if isinstance(loss, torch.Tensor) else loss
for loss in losses_since_last_log
]
avg_loss, max_loss = sum(losses) / len(losses), max(losses)
if parallel_dims.dp_enabled:
global_avg_loss, global_max_loss = (
utils.dist_mean(avg_loss, dp_mesh),
utils.dist_max(max_loss, dp_mesh),
)
else:
global_avg_loss, global_max_loss = avg_loss, max_loss

gpu_mem_stats = gpu_memory_monitor.get_peak_stats()

logger.info(
f"{color.cyan}step: {train_state.step:2} "
f"{color.green}loss: {global_avg_loss:7.4f} "
f"{color.yellow}memory: {gpu_mem_stats.max_reserved_gib:5.2f}GiB"
f"({gpu_mem_stats.max_reserved_pct:.2f}%) "
f"{color.blue}GPU time: {time_delta[0]:.3f}ms "
f"CPU wall time: {time_delta[1]:.3f}ms{color.reset}"
)

losses_since_last_log.clear()
gpu_memory_monitor.reset_peak_stats()

# signal the profiler that the next profiling step has started
if torch_profiler:
torch_profiler.step()
if memory_profiler:
memory_profiler.step()

# reduce timeout after first train step for faster signal
# (assuming lazy init and compilation are finished)
if train_state.step == 1:
utils.set_pg_timeouts(
timeout=timedelta(seconds=job_config.comm.train_timeout_seconds),
world_mesh=world_mesh,
)

if torch.distributed.get_rank() == 0:
logger.info("Sleeping 2 seconds for other ranks to complete")
time.sleep(2)

logger.info("Training completed")


if __name__ == "__main__":
config = JobConfig()
config.parse_args()
main(config)
torch.distributed.destroy_process_group()
2 changes: 0 additions & 2 deletions create_seed_checkpoint.sh
@@ -18,8 +18,6 @@

set -ex

export USE_LIBUV=1
TRAINER_DIR=${1:-/home/$USER/local/torchtitan}
NGPU=1
LOG_RANK=0
CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}
23 changes: 23 additions & 0 deletions docs/composability.md
@@ -0,0 +1,23 @@
# Building a Clean, Readable Distributed LLM
One of the main goals for TorchTitan was to provide a distributed LLM implementation that is not only high performance, but also built on native PyTorch techniques and readable code. The challenge is how to compose so many individual library components (FSDP, TP, PP, FP8, Compile, DCP, ...) while avoiding too many changes to the model guts in the process. A lot of the work is behind the scenes: designing individual components to make fewer assumptions, use common abstractions (e.g. DTensor), and generally 'get along'. But we found a few tweaks to the model code invaluable as well, and wanted to share those changes and the rationale for them.



# Making the model "pipeline friendly"
When applying Pipeline Parallelism, you will have to construct nn.Module objects representing the portion of the model that runs on a given pipeline stage. Whether you plan to manually edit your model code, or use techniques like tracing to extract model chunks, a few changes to the original model code can go a long way to making this process easier.

### Simplifying the top-level model forward
Most likely, you can write your model in such a way that the top-level nn.Module owns a sequence of child modules that it calls during forward, delegating most of the complexity to the child module forwards. If you can reduce your top level forward to mostly a for-loop over child module calls, then you'll simplify the pipeline-partitioning task to choosing the set of submodules to keep per stage. If you have non-trivial logic in the top-level forward, you'll have to find a way to patch that logic back onto the resulting pipeline stage model, which can be annoying.

example ([PR #321](https://github.com/pytorch/torchtitan/pull/321)):
we used to slice the `freqs_cis` buffer by `seq_len` in the top-level forward, pass that into the child modules, and expect that inside the child modules the `seq_len` would match up with the size of other local tensors. But we don't know whether TP was applied when we consider PP splitting, which could create a mismatch. It's just as easy to perform the `freqs_cis` slicing inside the child submodule, using the runtime-accurate local `seq_len`, and this sidesteps the issue at PP slicing time.
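
A simplified sketch of the resulting pattern (illustrative, not the actual model code):

```python
import torch
import torch.nn as nn

class TransformerBlockSketch(nn.Module):
    # Each block slices the precomputed freqs_cis buffer with its own runtime
    # seq_len, so the top-level forward stays a plain loop over blocks.
    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        seq_len = x.shape[1]
        local_freqs = freqs_cis[:seq_len]  # runtime-accurate local slice
        # ... attention/feed-forward would consume local_freqs here (omitted)
        return x
```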

example ([PR #322](https://github.com/pytorch/torchtitan/pull/322)): We decided to reuse the top-level model object on every PP stage, just delete the layers we don't want, and make sure that the top-level forward does the right thing. This means we don't have to write a separate runtime pp_forward that glues together child modules per stage. The first change was using a ModuleDict instead of a ModuleList to store layers. This preserves layer Fully Qualified Names (FQNs) even when deleting some layers, e.g. layers.1 stays layers.1 even if you remove layers.0, which isn't true for a list; this matters for checkpoint save/load. Preserving FQNs is a requirement for using Distributed Checkpointing (DCP), since it uses FQNs as globally unique IDs for sharding metadata. The second change was making the input and output layers optional: if the layer exists, we run it; otherwise we feed the input through to bypass it. With these two changes, we can just (meta-)initialize the whole model, delete the unused parts per stage, then materialize the remaining part on GPU before loading a checkpoint.
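
A toy example of the FQN point (illustrative only):

```python
import torch.nn as nn

# With ModuleDict, deleting a layer leaves the remaining keys (and thus FQNs)
# untouched, which is what DCP relies on.
layers = nn.ModuleDict({str(i): nn.Linear(8, 8) for i in range(4)})
del layers["0"]
print(list(layers.keys()))  # ['1', '2', '3'] -- "layers.1" is still "layers.1"

# With ModuleList, removing the first entry renumbers everything after it.
layer_list = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
del layer_list[0]
print(len(layer_list))  # 3 -- the old "layers.1" becomes "layers.0"
```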

# Using a seed checkpoint for init
Initializing the pipeline-parallel model is challenging because we assume the model could be so large as to not fit on a local GPU (or possibly even in CPU memory), and we also want to use the (bitwise) same initialization as we use for 1D or 2D parallel models, to ease debugging and comparisons between runs. It's not that easy to rewrite the original model's `init_weights` function to be tolerant of initializing only some layers, while also serializing initialization operations globally for a consistent RNG order.

For now, we sidestep all these problems with a simple but brutal solution: Initialize the whole model on some CPU instance, save a checkpoint file, and then lean on Distributed Checkpointing's "load" functionality to initialize the FQNs that are present on a given PP stage after stage creation. For future work, we consider adding a more elaborate initialization scheme to `torch.pipelining`.

One issue with seed checkpoints is that we rely on initializing _every_ model state from the checkpoint, which means the model can't have any non-persistent buffers, or else we have to specially initialize those in `train.py` after pipeline splitting. `freqs_cis` was originally a non-persistent buffer, and we changed this to persistent in order to load it from the seed checkpoint.
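
A toy example of the persistent vs. non-persistent buffer distinction (illustrative only):

```python
import torch
import torch.nn as nn

class BufferDemo(nn.Module):
    def __init__(self):
        super().__init__()
        # persistent buffers appear in state_dict and can be loaded from a seed checkpoint
        self.register_buffer("freqs_cis", torch.randn(4), persistent=True)
        # non-persistent buffers are excluded, so a checkpoint cannot initialize them
        self.register_buffer("scratch", torch.zeros(4), persistent=False)

print(list(BufferDemo().state_dict().keys()))  # ['freqs_cis']
```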

18 changes: 18 additions & 0 deletions docs/float8.md
@@ -0,0 +1,18 @@
## Enable Float8 Training on H100s

Please install the latest [TorchAO](https://github.com/pytorch/ao/tree/main/torchao/float8) to support the float8 dtype:
```
USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
```

Launch the training job with the following command (or alternatively set the configs in the toml files):
```
CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
```
* `--float8.enable_float8_linear`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul.
* `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth.
* `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of many small all-reduces, one per parameter.

For parallelisms, we support float8 all-gather for FSDP (optional) and for TP (by default for `Float8Linear`).

For scaling strategy, we currently support tensor-wise scaling with dynamic scales, and are actively working on tensor-wise scaling with delayed scales. Row-wise scaling is under exploration.
