Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[float8] keep model.output as nn.Linear (high precision, not fp8) #469

Merged
merged 3 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion torchtitan/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def maybe_build_fp8_linear(
)
with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather):
swap_linear_with_float8_linear(
model, scaling_type_w=TensorScalingType.DYNAMIC
model,
scaling_type_w=TensorScalingType.DYNAMIC,
skip_fqn_list=["output"],
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @vkuzo in case API changes in the future for argumetn skip_fqn_list

)
logger.info(
f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
Expand Down
19 changes: 10 additions & 9 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def selective_checkpointing_context_fn():
return module


def get_tp_parallel_strategy(
def get_tp_parallel_strategy_for_transformer_block(
job_config: JobConfig,
model: nn.Module,
) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
Expand Down Expand Up @@ -346,13 +346,6 @@ def apply_tp(
"""Apply tensor parallelism."""

tp_mesh = world_mesh["tp"]
# Parallel styles used for transformer block linear weights and their
# inputs may be different for float8 linears
(
rowwise_parallel_weight,
colwise_parallel_weight,
prepare_module_input,
) = get_tp_parallel_strategy(job_config, model)
loss_parallel = parallel_dims.loss_parallel_enabled

# 1. Parallelize the embedding and shard its outputs (which are the first
Expand All @@ -368,14 +361,22 @@ def apply_tp(
output_layouts=Shard(1),
),
"norm": SequenceParallel(),
"output": colwise_parallel_weight(
"output": ColwiseParallel(
input_layouts=Shard(1),
output_layouts=Shard(-1) if loss_parallel else Replicate(),
use_local_output=not loss_parallel,
),
},
)

# Parallel styles used for transformer block linear weights and their
# inputs may be different for float8 linears
(
rowwise_parallel_weight,
colwise_parallel_weight,
prepare_module_input,
) = get_tp_parallel_strategy_for_transformer_block(job_config, model)

# Apply tensor + sequence parallelism to every transformer block
# NOTE: At the cost of model code change, we can accelerate Sequence Parallel
# by folding (and unfolding) the batch dimension and the sequence dimension.
Expand Down
Loading