[BE][3/n] wrap fp8 logic using Float8Handler
ghstack-source-id: e94c7f6f4fad87c5432262c54beabd02de5541b8
Pull Request resolved: #496
tianyu-l committed Jul 31, 2024
1 parent 389116b commit a713124
Showing 11 changed files with 163 additions and 150 deletions.
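At a glance, the commit folds three module-level helpers into a single object. The before/after call pattern below is distilled from the diffs that follow; the surrounding training-script context is illustrative rather than part of this commit.

# before: free functions, each re-reading job_config
maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)
maybe_sync_float8_amax_and_scale_history(model, job_config)
maybe_precompute_fp8_dynamic_scale_for_fsdp(whole_model, job_config)

# after: one handler, constructed once, a no-op when fp8 is disabled
float8_handler = Float8Handler(job_config, parallel_dims)
float8_handler.convert_to_float8_training(whole_model)
float8_handler.sync_float8_amax_and_scale_history(model)
float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model)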
14 changes: 7 additions & 7 deletions estimation.py
@@ -16,10 +16,7 @@

 from torchtitan.config_manager import JobConfig
 from torchtitan.datasets import build_tokenizer
-from torchtitan.float8_linear import (
-    maybe_build_fp8_linear,
-    maybe_precompute_fp8_dynamic_scale_for_fsdp,
-)
+from torchtitan.float8_linear import Float8Handler
 from torchtitan.logging import init_logger, logger
 from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
 from torchtitan.optimizer import build_lr_schedulers, build_optimizers
@@ -127,8 +124,10 @@ def loss_fn(pred, labels):
     with torch.device("meta"):
         whole_model = model_cls.from_model_args(model_config)

+    # a no-op handler if fp8 is not enabled
+    float8_handler = Float8Handler(job_config, parallel_dims)
     # swap to Float8Linear based on fp8 config
-    maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)
+    float8_handler.convert_to_float8_training(whole_model)

     # apply PT-D DP/TP parallelisms and activation checkpointing
     model_parts = [whole_model]
@@ -184,13 +183,14 @@ def loss_fn(pred, labels):
             torch.nn.utils.clip_grad_norm_(
                 model.parameters(), job_config.training.max_norm, foreach=True
             )
+            # sync float8 amaxes and scales
+            float8_handler.sync_float8_amax_and_scale_history(model)
             # optimizer step
             optimizers.step()
             lr_schedulers.step()
-            # when fp8 config is on,
             # calculate float8 dynamic amax/scale for all parameters for FSDP2
             # it issues a single all-reduce for all parameters at once for better performance
-            maybe_precompute_fp8_dynamic_scale_for_fsdp(whole_model, job_config)
+            float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model)
             optimizers.zero_grad()
             print(f"Peak Memory at iter: {iter_idx}")
             fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
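The placement of the two new handler calls follows the float8 recipe: amax/scale histories are synced after backward and before the optimizer step (this only does work when a delayed scaling type is configured), while the dynamic scales for FSDP2 float8 all-gather are precomputed after the step, since they must reflect the freshly updated weights. A condensed sketch of one iteration; `optimizer`, `lr_scheduler`, and `max_norm` stand in for objects built elsewhere in the script:

loss = loss_fn(model(input_ids), labels)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, foreach=True)
float8_handler.sync_float8_amax_and_scale_history(model)  # delayed-scaling bookkeeping
optimizer.step()
lr_scheduler.step()
float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model)  # scales for the next all-gather
optimizer.zero_grad()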
83 changes: 43 additions & 40 deletions torchtitan/config_manager.py
@@ -348,46 +348,6 @@ def __init__(self):
             action="store_true",
             help="Whether to compile the model",
         )
-        self.parser.add_argument(
-            "--training.enable_float8_linear",
-            action="store_true",
-            help="""
-                If true, swaps `torch.nn.Linear` with `Float8Linear`.
-                This feature requires you to install 'torchao' which can be found
-                here: https://github.com/pytorch/ao
-            """,
-        )
-        self.parser.add_argument(
-            "--training.enable_fsdp_float8_all_gather",
-            action="store_true",
-            default=False,
-            help="Whether enable float8 all-gather in FSDP",
-        )
-        self.parser.add_argument(
-            "--training.precompute_float8_dynamic_scale_for_fsdp",
-            action="store_true",
-            default=False,
-            help="Whether precompute float8 scales dynamically for FSDP",
-        )
-        self.parser.add_argument(
-            "--training.float8_scaling_type_input",
-            type=str,
-            default="dynamic",
-            help="float8 scaling for input, dynamic (default) or delayed",
-            choices=["dynamic", "delayed"],
-        )
-        self.parser.add_argument(
-            "--training.float8_scaling_type_weight",
-            type=str,
-            default="dynamic",
-            help="float8 scaling for input, dynamic (default) or delayed",
-        )
-        self.parser.add_argument(
-            "--training.float8_scaling_type_grad_output",
-            type=str,
-            default="dynamic",
-            help="float8 scaling for input, dynamic (default) or delayed",
-        )
         self.parser.add_argument(
             "--training.gc_freq",
             type=int,
@@ -483,6 +443,7 @@ def __init__(self):
                 0 is the default value.
             """,
         )
+
         # activation checkpointing configs
         self.parser.add_argument(
             "--activation_checkpoint.mode",
@@ -500,6 +461,48 @@ def __init__(self):
             """
         )

+        # float8 configs
+        self.parser.add_argument(
+            "--float8.enable_float8_linear",
+            action="store_true",
+            help="""
+                If true, swaps `torch.nn.Linear` with `Float8Linear`.
+                This feature requires you to install 'torchao' which can be found
+                here: https://github.com/pytorch/ao
+            """,
+        )
+        self.parser.add_argument(
+            "--float8.enable_fsdp_float8_all_gather",
+            action="store_true",
+            default=False,
+            help="Whether to enable float8 all-gather in FSDP",
+        )
+        self.parser.add_argument(
+            "--float8.precompute_float8_dynamic_scale_for_fsdp",
+            action="store_true",
+            default=False,
+            help="Whether to precompute float8 scales dynamically for FSDP",
+        )
+        self.parser.add_argument(
+            "--float8.scaling_type_input",
+            type=str,
+            default="dynamic",
+            help="float8 scaling for input, dynamic (default) or delayed",
+            choices=["dynamic", "delayed"],
+        )
+        self.parser.add_argument(
+            "--float8.scaling_type_weight",
+            type=str,
+            default="dynamic",
+            help="float8 scaling for weight, dynamic (default) or delayed",
+        )
+        self.parser.add_argument(
+            "--float8.scaling_type_grad_output",
+            type=str,
+            default="dynamic",
+            help="float8 scaling for grad_output, dynamic (default) or delayed",
+        )
+
         # communications library settings
         self.parser.add_argument(
             "--comm.init_timeout_seconds",
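Because the arguments moved from the `--training.*` namespace to `--float8.*`, downstream code reads them from `job_config.float8` instead of `job_config.training`. A minimal sketch of the new spelling, assuming `JobConfig.parse_args` accepts a CLI-style argument list as the argparse setup above suggests:

from torchtitan.config_manager import JobConfig

config = JobConfig()
config.parse_args(
    [
        "--float8.enable_float8_linear",
        "--float8.enable_fsdp_float8_all_gather",
        "--float8.precompute_float8_dynamic_scale_for_fsdp",
        "--float8.scaling_type_weight",
        "dynamic",
    ]
)
# dotted flags become nested attributes
assert config.float8.enable_float8_linear
assert config.float8.scaling_type_weight == "dynamic"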
173 changes: 87 additions & 86 deletions torchtitan/float8_linear.py
@@ -12,127 +12,128 @@

 # Note: Performance
 # Float8 experimental is intended to be run under `torch.compile` for competitive performance
 import functools
 from typing import Optional

 import torch
 import torch.nn as nn
-from torch._logging import warning_once

 from torchtitan.config_manager import JobConfig
 from torchtitan.logging import logger
+from torchtitan.parallelisms import ParallelDims


 @functools.lru_cache(None)
 def is_sm90_or_later():
     # Float8 is only supported on H100+ GPUs
     return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


-def maybe_build_fp8_linear(
-    model: nn.Module, job_config: JobConfig, dp_enabled: Optional[bool] = False
-):
-    """
-    This function converts the linear layers to `Float8Linear`. Note that today,
-    only dynamic tensor scaling (the default) is supported.
-    This will mutate the model inplace.
-    """
-    enable_float8_linear = job_config.training.enable_float8_linear
-    if not enable_float8_linear:
-        return
-    if not is_sm90_or_later():
-        warning_once(
-            logger,
-            "Failed to swap to Float8Linear because SM90 or later is not available",
-        )
-        return
-    try:
-        from torchao.float8 import (
-            CastConfig,
-            convert_to_float8_training,
-            Float8LinearConfig,
-            ScalingType,
-        )
+class Float8Handler:
+    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
+        self.enabled = False
+
+        float8_config = job_config.float8
+        if not float8_config.enable_float8_linear:
+            return
+        if not is_sm90_or_later():
+            logger.warning(
+                "Failed to swap to Float8Linear because SM90 or later is not available",
+            )
+            return
+        try:
+            from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType
+        except ImportError as e:
+            raise ImportError(
+                "torchao is not installed. Please install it to use fp8 linear layers."
+            ) from e

-        # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
         enable_fsdp_float8_all_gather = (
-            job_config.training.enable_fsdp_float8_all_gather and dp_enabled
+            parallel_dims.dp_enabled
+            and parallel_dims.dp_type == "fsdp"
+            and float8_config.enable_fsdp_float8_all_gather
         )
-        scaling_type_input = ScalingType(job_config.training.float8_scaling_type_input)
-        scaling_type_weight = ScalingType(
-            job_config.training.float8_scaling_type_weight
-        )
-        scaling_type_grad_output = ScalingType(
-            job_config.training.float8_scaling_type_grad_output
-        )
-        float8_config = Float8LinearConfig(
+        scaling_type_input = ScalingType(float8_config.scaling_type_input)
+        scaling_type_weight = ScalingType(float8_config.scaling_type_weight)
+        scaling_type_grad_output = ScalingType(float8_config.scaling_type_grad_output)
+        self.config = Float8LinearConfig(
             enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
             cast_config_input=CastConfig(scaling_type=scaling_type_input),
             cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
             cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
             enable_pre_and_post_forward=False,
         )
+
+        self.enabled = True
+
+        # for precompute_fp8_dynamic_scale_for_fsdp
+        self.precompute_scale = (
+            enable_fsdp_float8_all_gather
+            and float8_config.precompute_float8_dynamic_scale_for_fsdp
+        )
+
+        # for sync_float8_amax_and_scale_history
+        self.delayed_scaling = (
+            scaling_type_input == "delayed"
+            or scaling_type_weight == "delayed"
+            or scaling_type_grad_output == "delayed"
+        )
+        self._sync_float8_amax_and_scale_history = None
+        self.compile = job_config.training.compile
+
+        logger.info("Float8 training active")
+
+    def convert_to_float8_training(self, model: nn.Module):
+        """
+        This function converts the linear layers of `model` to `Float8Linear`.
+        Note that today, only dynamic tensor scaling (the default) is supported.
+        This will mutate the model inplace.
+        """
+        if not self.enabled:
+            return
+
+        from torchao.float8 import convert_to_float8_training
+
+        # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
         convert_to_float8_training(
             model,
-            config=float8_config,
+            config=self.config,
             module_filter_fn=lambda mod, fqn: fqn != "output",
         )
         logger.info(
-            f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
+            "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
+            f"{self.config.enable_fsdp_float8_all_gather}"
         )
-    except ImportError as exc:
-        raise ImportError(
-            "torchao is not installed. Please install it to use fp8 linear layers."
-        ) from exc

-
-def maybe_precompute_fp8_dynamic_scale_for_fsdp(
-    model: nn.Module, job_config: JobConfig
-):
-    if not (
-        job_config.training.enable_float8_linear
-        and job_config.training.enable_fsdp_float8_all_gather
-        and job_config.training.precompute_float8_dynamic_scale_for_fsdp
-    ):
-        return
-    if not is_sm90_or_later():
-        warning_once(
-            logger,
-            "Skipped precomputing fp8 scales because SM90 or later is not available",
-        )
-        return
-    from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
-
-    precompute_float8_dynamic_scale_for_fsdp(model)
+    def precompute_fp8_dynamic_scale_for_fsdp(self, model: nn.Module):
+        if not self.enabled:
+            return
+
+        if not self.precompute_scale:
+            return
+
+        from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
+
+        precompute_float8_dynamic_scale_for_fsdp(model)

-
-_sync_float8_amax_and_scale_history = None
-
-
-def maybe_sync_float8_amax_and_scale_history(model: nn.Module, job_config: JobConfig):
-    if not (
-        job_config.training.enable_float8_linear
-        and (
-            job_config.training.float8_scaling_type_input == "delayed"
-            or job_config.training.float8_scaling_type_weight == "delayed"
-            or job_config.training.float8_scaling_type_grad_output == "delayed"
-        )
-    ):
-        return
-
-    from torchao.float8 import sync_float8_amax_and_scale_history
-
-    # TODO(future): see if precalculating the modules to sync over is going to
-    # meaningfully help performance
-
-    global _sync_float8_amax_and_scale_history
-    if _sync_float8_amax_and_scale_history is None:
-        if job_config.training.compile:
-            _sync_float8_amax_and_scale_history = torch.compile(
-                sync_float8_amax_and_scale_history
-            )
-        else:
-            _sync_float8_amax_and_scale_history = sync_float8_amax_and_scale_history
-
-    sync_float8_amax_and_scale_history(model)
+    def sync_float8_amax_and_scale_history(self, model: nn.Module):
+        if not self.enabled:
+            return
+
+        if not self.delayed_scaling:
+            return
+
+        from torchao.float8 import sync_float8_amax_and_scale_history
+
+        # TODO(vkuzo): see if precalculating the modules to sync over is going to
+        # meaningfully help performance
+
+        if self._sync_float8_amax_and_scale_history is None:
+            if self.compile:
+                self._sync_float8_amax_and_scale_history = torch.compile(
+                    sync_float8_amax_and_scale_history
+                )
+            else:
+                self._sync_float8_amax_and_scale_history = (
+                    sync_float8_amax_and_scale_history
+                )
+
+        self._sync_float8_amax_and_scale_history(model)
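Because `__init__` only sets `self.enabled = True` after every check passes, the handler can be called unconditionally from estimation.py and the training script; each method degrades to an early return when fp8 is off. A hedged usage sketch follows: `SimpleNamespace` merely mimics the JobConfig/ParallelDims attribute layout used above, and the actual conversion requires torchao plus an SM90-class GPU (otherwise the handler stays disabled and every call is a no-op).

from types import SimpleNamespace

import torch.nn as nn

from torchtitan.float8_linear import Float8Handler

job_config = SimpleNamespace(
    float8=SimpleNamespace(
        enable_float8_linear=True,
        enable_fsdp_float8_all_gather=False,
        precompute_float8_dynamic_scale_for_fsdp=False,
        scaling_type_input="dynamic",
        scaling_type_weight="dynamic",
        scaling_type_grad_output="dynamic",
    ),
    training=SimpleNamespace(compile=False),
)
parallel_dims = SimpleNamespace(dp_enabled=False, dp_type="fsdp")

model = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32))

handler = Float8Handler(job_config, parallel_dims)
handler.convert_to_float8_training(model)  # swaps nn.Linear children (except a module named "output")
handler.sync_float8_amax_and_scale_history(model)  # no-op here: no delayed scaling configured
handler.precompute_fp8_dynamic_scale_for_fsdp(model)  # no-op here: fsdp float8 all-gather is off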
2 changes: 1 addition & 1 deletion torchtitan/parallelisms/parallelize_llama.py
@@ -541,7 +541,7 @@ def parallelize_llama(
         model,
         world_mesh["tp"],
         loss_parallel=parallel_dims.loss_parallel_enabled,
-        enable_float8=job_config.training.enable_float8_linear,
+        enable_float8=job_config.float8.enable_float8_linear,
         enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
     )

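parallelize_llama itself only changes where it reads the flag; inside apply_tp (not shown in this diff) that flag chooses between the regular tensor-parallel styles and their float8-aware counterparts. A rough sketch of such a selection, assuming torchao exposes the float8 TP classes under torchao.float8.float8_tensor_parallel; the real torchtitan code differs in detail:

from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
)


def pick_tp_styles(enable_float8: bool):
    """Return (rowwise, colwise, prepare_input) styles for building the TP plan."""
    if enable_float8:
        # assumed torchao API: float8 variants that communicate float8 tensors instead of bf16
        from torchao.float8.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput
    return RowwiseParallel, ColwiseParallel, PrepareModuleInput


# e.g. rowwise, colwise, prepare_input = pick_tp_styles(job_config.float8.enable_float8_linear)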