Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[FSDP2] precompute scale after optimizer.step for dynamic scaling #266

Closed
wants to merge 32 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
9d5595c
[FSDP2] set vocab_size=32 to avoid must be divisible by 16 error
weifengpy May 21, 2024
e7005c2
precast after optimizer.step and dump profiler traces
weifengpy May 21, 2024
e41d589
Merge branch 'main' into fsdp2
weifengpy May 21, 2024
e0bee10
precast and preamax unit test
weifengpy May 24, 2024
c0ba5a2
remove duplicate vocab
weifengpy May 24, 2024
8da238e
fused amax
weifengpy May 30, 2024
ffff5ed
Merge branch 'main' into fsdp2
weifengpy Jun 6, 2024
aefa21b
use FP8_TYPES and max
weifengpy Jun 6, 2024
d4a1db7
commit all changes before cleaning
weifengpy Jun 6, 2024
d36e79b
pre_compute and flatten / unflatten
weifengpy Jun 6, 2024
6f244a2
remove unused constant
weifengpy Jun 6, 2024
dc5eab0
torch.compile works
weifengpy Jun 6, 2024
546e979
eager ready
weifengpy Jun 6, 2024
229ede6
linter
weifengpy Jun 6, 2024
d5b3ff6
linter
weifengpy Jun 6, 2024
4f05e04
flatten tensor
weifengpy Jun 25, 2024
3de59af
commit all changes for review before rebasing
weifengpy Jul 8, 2024
ffcd197
rebase on unified float8linear
weifengpy Jul 9, 2024
6b18947
Merge branch 'pytorch-labs:main' into fsdp2
weifengpy Jul 9, 2024
562424c
move precompute to fsdp_utils.py
weifengpy Jul 9, 2024
75e0e45
simplify amax calc
weifengpy Jul 9, 2024
fe95f8b
explain _pre_computed_amax
weifengpy Jul 9, 2024
1cbaa13
fix linter
weifengpy Jul 9, 2024
fe2e0a0
document precompute_float8_amax_for_fsdp
weifengpy Jul 9, 2024
e4eaa2a
rename pre_compute to precompute
weifengpy Jul 9, 2024
e4245e4
Merge branch 'main' into fsdp2
weifengpy Jul 10, 2024
e12c973
remove clamp_amax=True/False
weifengpy Jul 10, 2024
9ef67fb
precompute scale
weifengpy Jul 10, 2024
fa2f08a
unit test for precomputing scales
weifengpy Jul 10, 2024
ba085e5
add precompute scale in README
weifengpy Jul 10, 2024
ac0afb0
rename to precompute_float8_dynamic_scale_for_fsdp
weifengpy Jul 11, 2024
8e56dfc
rename to precompute_float8_dynamic_scale_for_fsdp
weifengpy Jul 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ This is the most accurate recipe as every tensor is scaled dynamically.
from float8_experimental.float8_linear_utils import (
swap_linear_with_float8_linear,
)
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
from float8_experimental.float8_linear import Float8Linear

# create model
Expand All @@ -51,7 +52,18 @@ model = FSDP(model, use_orig_params=True)
# optional: enable torch.compile for improved performance
m = torch.compile(m)

# train/finetune (not shown)
# toy training loop
for _ in range(N_ITER):
optimizer.zero_grad()
y = m(x)
y.sum().backward()
optimizer.step()

# specific to fsdp2 + dynamic scaling, when fp8 all-gather is turned on
# this method is optional but is highly recommended for performance
# it calcuclates scales for all parameters in a single all-reduce
precompute_float8_dynamic_scale_for_fsdp(model)

```

## float8 linear with delayed scaling
Expand Down
43 changes: 36 additions & 7 deletions float8_experimental/float8_dynamic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,12 @@ def cast_to_float8_e5m2_dynamic_bw(

class WeightWithDynamicFloat8CastTensor(torch.Tensor):
@staticmethod
def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
def __new__(
cls,
tensor: torch.Tensor,
mm_config: ScaledMMConfig,
precomputed_scale: Optional[torch.Tensor] = None,
):
return torch.Tensor._make_wrapper_subclass(
cls,
tensor.size(),
Expand All @@ -96,9 +101,18 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
requires_grad=tensor.requires_grad,
)

def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
def __init__(
self,
tensor: torch.Tensor,
mm_config: ScaledMMConfig,
precomputed_scale: Optional[torch.Tensor] = None,
):
self._tensor = tensor
self._mm_config = mm_config
# for dynamic scaling
# `precompute_float8_dynamic_scale_for_fsdp` calculates scales
# for all float8 parameters after optimizer step
self._precomputed_scale = precomputed_scale

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
Expand Down Expand Up @@ -127,20 +141,35 @@ def unwrap(t):
)

def __tensor_flatten__(self):
return ["_tensor"], self._mm_config
if self._precomputed_scale:
return ["_tensor", "_precomputed_scale"], self._mm_config
else:
return ["_tensor"], self._mm_config

@staticmethod
def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
mm_config = flatten_spec
return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)
return WeightWithDynamicFloat8CastTensor(
inner_tensors["_tensor"],
mm_config,
getattr(inner_tensors, "_precomputed_scale", None),
)

def __repr__(self):
return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"

def fsdp_pre_all_gather(self, mesh):
float8_tensor = cast_to_float8_e4m3_dynamic(
self._tensor, self._mm_config, reduce_amax=True
)
if self._precomputed_scale is not None:
float8_tensor = Float8Tensor.to_float8(
self._tensor,
self._precomputed_scale,
torch.float8_e4m3fn,
mm_config=self._mm_config,
)
else:
float8_tensor = cast_to_float8_e4m3_dynamic(
self._tensor, self._mm_config, reduce_amax=True
)
return (float8_tensor._data,), (float8_tensor._scale,)

def fsdp_post_all_gather(
Expand Down
52 changes: 52 additions & 0 deletions float8_experimental/fsdp_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import math
from typing import List

import torch
import torch.nn as nn
from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_utils import EPS


@torch.no_grad()
def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
"""
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

improve docstring with example API usage

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice! can we add this to the README?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just added API usage to README

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe we can make sure dynamic is in the name, since this is specific to dynamic scaling?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

renaming to precompute_float8_dynamic_scale_for_fsdp

Calculate scale dynamically for all float8 parameters.
This should be run after the optimizer step. It performs a single all-reduce to compute the
scales for all float8 weights.
Example usage:
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)
"""
from torch.distributed._tensor import DTensor

if any(
weifengpy marked this conversation as resolved.
Show resolved Hide resolved
isinstance(m, Float8Linear) and m.scaling_type_w is TensorScalingType.DELAYED
for m in module.modules()
):
raise NotImplementedError("Only supports delayed scaling")
float8_linears: List[Float8Linear] = [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this expensive for real models? if yes, maybe we can offer option to precompute this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My intuition is that this should be pretty fast as the number of nn.Modules in the model is usually at most in the thousands and this is pure Python overhead. @weifengpy you can check the traces you have if you see any noticeable gaps from this.

Copy link
Contributor Author

@weifengpy weifengpy Jul 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just checked the profiler traces. it's roughly 0.15ms cpu overhead (5% of precompute_float8_dynamic_scale_for_fsdp and is tiny portion of 1 training loop). no cuda are launched

thus I am keeping it as is now for simplicity
Screenshot 2024-07-11 at 2 45 17 PM

m
for m in module.modules()
if isinstance(m, Float8Linear)
and isinstance(m.weight, DTensor)
and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor)
]
weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]

if not weights:
return

# inf-norm is equivalent to max(abs(w))
max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial
amax_tensor = torch.vstack(max_weights) # Partial
# clamp is dispatched through DTensor
# it will issue a single all-reduce
amax_tensor = torch.clamp(amax_tensor, EPS) # Replicate
scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate
if amax_tensor.dtype is torch.float16:
scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
scales = torch.split(scale_tensor, 1) # Replicate
for scale, float8_linear in zip(scales, float8_linears):
float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
4 changes: 4 additions & 0 deletions test/test_fsdp2/test_fsdp2_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp


def check_parity_no_mp(
Expand All @@ -15,6 +16,7 @@ def check_parity_no_mp(
fsdp_model: nn.Module,
fsdp_optim: torch.optim.Optimizer,
local_inp: torch.Tensor,
precompute: bool = False,
):
for iter_idx in range(10):
losses: List[torch.Tensor] = []
Expand All @@ -28,6 +30,8 @@ def check_parity_no_mp(
param.grad.div_(dist.get_world_size())
# TODO(future): add amax syncing once delayed scaling is supported
optim.step()
if model is fsdp_model and precompute:
precompute_float8_dynamic_scale_for_fsdp(model)
test_cls.assertEqual(losses[0], losses[1])


Expand Down
21 changes: 17 additions & 4 deletions test/test_fsdp2/test_fsdp2_eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,21 @@ def world_size(self) -> int:

@skip_if_lt_x_gpu(2)
def test_transformer_parity_dynamic(self):
for enable_fsdp_fp8_all_gather in [False, True]:
self._test_transformer_parity_dynamic(enable_fsdp_fp8_all_gather)
self.run_subtests(
{
"enable_fsdp_fp8_all_gather": [False, True],
"precompute": [False, True],
},
self._test_transformer_parity_dynamic,
)

def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
def _test_transformer_parity_dynamic(
self,
enable_fsdp_fp8_all_gather: bool,
precompute: bool,
):
if not enable_fsdp_fp8_all_gather and precompute:
return
# NOTE: Weight-tying does not compose with fp8 all-gather because the
# embedding weight and output linear weight are tied but only the
# latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
Expand All @@ -109,7 +120,9 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
local_inp = torch.randint(
0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda"
)
check_parity_no_mp(self, ref_module, ref_optim, module, optim, local_inp)
check_parity_no_mp(
self, ref_module, ref_optim, module, optim, local_inp, precompute
)

@skip_if_lt_x_gpu(2)
def test_transformer_memory(self):
Expand Down
Loading