From b502cdcb3ff3f315c3243cf6e6fb2896911deca3 Mon Sep 17 00:00:00 2001 From: "Wei (Will) Feng" <134637289+weifengpy@users.noreply.github.com> Date: Fri, 19 Jul 2024 16:07:12 -0700 Subject: [PATCH] [float8] keep model.output as `nn.Linear` (high precision, not fp8) (#469) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **keep model.output as nn.Linear**: it's a common practice to NOT apply fp8 on final output layer * specify `skip_fqn_list` in swapping * when applying TP to model.output, use plain `ColwiseParallel` instead of `Float8ColwiseParallel` credit to @awgu, we do not need tokentizer vacab size to be divisible by 16 https://github.com/pytorch/torchtitan/issues/461 1D TP + float8 all-gather, eager mode: `CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4 ./run_llama_train.sh --training.enable_float8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4` 1D TP + float8 all-gather, compile mode: `CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4 ./run_llama_train.sh --training.enable_float8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4 --training.compile` 2D FSDP2 + TP + float8 all-gather, eager mode: `CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4 ./run_llama_train.sh --training.enable_float8_linear --training.enable_fsdp_float8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp --training.tensor_parallel_degree 2` 2D FSDP2 + TP + float8 all-gather, eager mode: `CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4 ./run_llama_train.sh --training.enable_float8_linear --training.enable_fsdp_float8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp --training.tensor_parallel_degree 2 --training.compile` 1D TP + float8 all-gather trace: see float8 and all-gather in the trace Screenshot 2024-07-19 at 1 16 59 PM 2D + float8 all-gather trace: see float8 and FSDP collectives and TP collectives Screenshot 2024-07-19 at 1 29 59 PM --- torchtitan/float8_linear.py | 4 +++- torchtitan/parallelisms/parallelize_llama.py | 19 ++++++++++--------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index 50c971ae..770531d5 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -74,7 +74,9 @@ def maybe_build_fp8_linear( ) with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather): swap_linear_with_float8_linear( - model, scaling_type_w=TensorScalingType.DYNAMIC + model, + scaling_type_w=TensorScalingType.DYNAMIC, + skip_fqn_list=["output"], ) logger.info( f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}" diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 33b9d6d3..634c70a0 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -117,7 +117,7 @@ def selective_checkpointing_context_fn(): return module -def get_tp_parallel_strategy( +def get_tp_parallel_strategy_for_transformer_block( job_config: JobConfig, model: nn.Module, ) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]: @@ -346,13 +346,6 @@ def apply_tp( """Apply tensor parallelism.""" tp_mesh = world_mesh["tp"] - # Parallel styles used for transformer block linear weights and their - # inputs may be different for float8 linears - ( - rowwise_parallel_weight, - colwise_parallel_weight, - prepare_module_input, - ) = get_tp_parallel_strategy(job_config, model) loss_parallel = parallel_dims.loss_parallel_enabled # 1. Parallelize the embedding and shard its outputs (which are the first @@ -368,7 +361,7 @@ def apply_tp( output_layouts=Shard(1), ), "norm": SequenceParallel(), - "output": colwise_parallel_weight( + "output": ColwiseParallel( input_layouts=Shard(1), output_layouts=Shard(-1) if loss_parallel else Replicate(), use_local_output=not loss_parallel, @@ -376,6 +369,14 @@ def apply_tp( }, ) + # Parallel styles used for transformer block linear weights and their + # inputs may be different for float8 linears + ( + rowwise_parallel_weight, + colwise_parallel_weight, + prepare_module_input, + ) = get_tp_parallel_strategy_for_transformer_block(job_config, model) + # Apply tensor + sequence parallelism to every transformer block # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension.