[fp8] Disable all_gather intranode. Disable Redundant all_gather fp8 #6059

Merged: 8 commits, merged on Sep 14, 2024
Changes from 7 commits
97 changes: 19 additions & 78 deletions colossalai/quantization/fp8.py
@@ -10,6 +10,10 @@

SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
SCALE_BYTES = 4
try:
cuda_arch = int("".join(str(i) for i in torch.cuda.get_device_capability()))
except:
cuda_arch = 0


class Handle:
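
The cuda_arch probe added at the top of the file gates every @torch.compile decorator in this module: concatenating the capability tuple turns (9, 0) into 90 and (8, 6) into 86, the bare except falls back to 0 when no CUDA device is present, and disable=cuda_arch < 89 then turns the decorator into a no-op on GPUs older than Ada/Hopper, presumably because the Triton-compiled FP8 conversions need compute capability 8.9 or newer. A minimal sketch of the same gating pattern, assuming only that torch.compile accepts the disable keyword (the quantize helper itself is illustrative, not part of this PR):

import torch

# Probe the device: (9, 0) -> 90 on H100, (8, 0) -> 80 on A100, 0 if CUDA is unavailable.
try:
    major, minor = torch.cuda.get_device_capability()
    cuda_arch = major * 10 + minor
except Exception:
    cuda_arch = 0

# disable=True makes torch.compile a no-op, so pre-Ada GPUs keep the eager code path.
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89)
def quantize_e4m3(x: torch.Tensor) -> torch.Tensor:
    # Scale into the representable range of e4m3 and cast down.
    scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max().clamp(min=1e-12)
    return (x * scale).to(torch.float8_e4m3fn)
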
@@ -185,7 +189,7 @@ def all_reduce_fp8(
return dist.all_reduce(tensor, op=op, group=group, async_op=async_op)


@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89)
def _all_to_all_single_fp8(
output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
) -> Optional[Handle]:
@@ -606,79 +610,7 @@ def split_chunk_by_channel(
return chunk.split(sizes)


def all_gather_into_tensor_flat_fp8(
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
output_shape: torch.Size,
group: dist.ProcessGroup,
fp8_format: str = "e4m3",
async_op: bool = False,
) -> Optional[Handle]:
"""all gather into tensor in fp8 format

Args:
output_tensor (torch.Tensor): output tensor, which is flattened
input_tensor (torch.Tensor): input tensor, which is flattened
group (dist.ProcessGroup): process group
fp8_format (str, optional): fp8 format, e4m3 or e5m2. Defaults to "e4m3".
"""
assert input_tensor.dim() == 1 and output_tensor.dim() == 1, "input/output tensor should be flattened"
world_size = dist.get_world_size(group)
assert (
output_tensor.numel() == input_tensor.numel() * world_size
), "output tensor size should be world_size times of input tensor size"

input_type = output_tensor.dtype

fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
fp8_max = torch.finfo(fp8_type).max

if len(output_shape) == 2:
per_channel_max = torch.zeros(output_shape[0], device=output_tensor.device, dtype=torch.float)
num_channels, channel_size = output_shape
rank = dist.get_rank(group)
channel_start_idx = (input_tensor.numel() * rank) // channel_size
per_channel_splits = split_chunk_by_channel(input_tensor, channel_size, num_channels, rank, world_size)
for i, per_channel_split in enumerate(per_channel_splits):
idx = i + channel_start_idx
if idx < num_channels:
per_channel_max[idx] = per_channel_split.abs().max().float()
dist.all_reduce(per_channel_max, op=dist.ReduceOp.MAX, group=group)
per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
scale = fp8_max / per_channel_max
fp8_input = input_tensor.float()
fp8_per_channel_splits = split_chunk_by_channel(fp8_input, channel_size, num_channels, rank, world_size)
for i, per_channel_split in enumerate(fp8_per_channel_splits):
idx = i + channel_start_idx
if idx < num_channels:
per_channel_split.mul_(scale[idx])
fp8_input = fp8_input.to(fp8_type)
else:
per_tensor_max = input_tensor.abs().max().float()
dist.all_reduce(per_tensor_max, op=dist.ReduceOp.MAX, group=group)
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
scale = fp8_max / per_tensor_max
fp8_input = (scale * input_tensor.float()).to(fp8_type)
scale_inv = 1.0 / scale

buffer = torch.empty_like(output_tensor, dtype=fp8_type)
tensor_handle = dist.all_gather_into_tensor(
buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group, async_op=async_op
)

def cast_op():
numel = output_shape.numel()
valid_buffer = buffer[:numel].reshape(output_shape)
valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
output_tensor[:numel].copy_(valid_buffer.view(-1))

if async_op:
return Handle([tensor_handle], cast_op)
else:
cast_op()


@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89)
def _all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
world_size = dist.get_world_size(group)
input_type = input_list[0].dtype
@@ -718,8 +650,8 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async
return _all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format, async_op=async_op)


def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:

@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89)
def _all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
world_size = dist.get_world_size(group)

input_type = input_.dtype
@@ -743,8 +675,17 @@ def cast_op():
cast_op()


@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
if process_group_is_intranode(group):
return dist.all_gather(output_list, input_, group=group, async_op=async_op)
else:
return _all_gather_fp8(output_list, input_, group=group, fp8_format=fp8_format, async_op=async_op)


@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89)
def all_gather_fp8_lagacy(
output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False
) -> Optional[Handle]:
world_size = dist.get_world_size(group)
shape = input_.shape
input_type = input_.dtype
@@ -769,7 +710,7 @@ def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op:
# out.copy_(output[i].view(shape))


@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89)
def all_gather_fp8_ring(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
world_size = dist.get_world_size(group)
rank = dist.get_rank(group)
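
Taken together, the fp8.py changes collapse the deleted all_gather_into_tensor_flat_fp8 into a single public entry point, all_gather_fp8, which calls plain dist.all_gather whenever process_group_is_intranode(group) holds (per the PR title, intranode FP8 all-gather is disabled, presumably because the quantize/dequantize overhead is not worth it over intra-node links) and otherwise dispatches to the compiled _all_gather_fp8. A hedged usage sketch of the new wrapper; the caller shown here is illustrative, only all_gather_fp8 and its signature come from the diff:

import torch
import torch.distributed as dist
from colossalai.quantization.fp8 import all_gather_fp8

def gather_activations(x: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    # Gather one tensor per rank and concatenate, as a tensor-parallel layer would.
    world_size = dist.get_world_size(group)
    out_list = [torch.empty_like(x) for _ in range(world_size)]
    # Intranode groups take the uncompressed path; internode groups cast to fp8,
    # all-gather the uint8 payload plus scales, and cast back on the receiving side.
    all_gather_fp8(out_list, x, group=group, fp8_format="e5m2")
    return torch.cat(out_list, dim=0)
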
4 changes: 2 additions & 2 deletions colossalai/shardformer/layer/_operation.py
@@ -17,10 +17,10 @@
_grad_accum_fusion_available = False

from colossalai.quantization.fp8 import (
all_gather_fp8,
all_reduce_fp8,
all_to_all_fp8,
all_to_all_single_fp8,
gather_fp8,
reduce_scatter_fp8,
)

@@ -961,7 +961,7 @@ def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_for
input_ = input_.contiguous()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
if fp8_communication:
gather_fp8(tensor_list, input_, fp8_format=fp8_format, group=process_group)
all_gather_fp8(tensor_list, input_, fp8_format=fp8_format, group=process_group)
else:
dist.all_gather(tensor_list, input_, group=process_group)

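
At this call site the change is a pure rename: the list-of-empty-tensors pattern is unchanged, so tensor-parallel gathers pick up the new intranode fallback without any other edits. A sketch of the call pattern with the surrounding lines from _gather; the final concatenation along dim is assumed from the function's purpose rather than shown in the diff:

world_size = dist.get_world_size(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
if fp8_communication:
    # Renamed entry point from this PR; falls back to dist.all_gather for intranode groups.
    all_gather_fp8(tensor_list, input_, fp8_format=fp8_format, group=process_group)
else:
    dist.all_gather(tensor_list, input_, group=process_group)
output = torch.cat(tensor_list, dim=dim).contiguous()
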
12 changes: 7 additions & 5 deletions colossalai/zero/gemini/chunk/chunk.py
@@ -7,6 +7,7 @@
from torch.distributed import ProcessGroup

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_gather_fp8


class TensorState(Enum):
@@ -523,11 +524,12 @@ def __gather(self, async_op: bool = False) -> Optional[dist.Work]:
alloc_storage(self.cuda_global_chunk)
assert self.cuda_global_chunk.is_contiguous()
if self.fp8_communication:
assert async_op == False, "fp8 all-gather does not support async_op!"
from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8

work = all_gather_into_tensor_flat_fp8(
self.cuda_global_chunk, self.cuda_shard, self.cuda_global_chunk.shape, self.torch_pg
work = all_gather_fp8(
self.cuda_global_chunk.chunk(self.pg_size),
self.cuda_shard,
self.torch_pg,
fp8_format="e4m3",
async_op=async_op,
)
else:
work = dist.all_gather_into_tensor(
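
The Gemini chunk used the flat-tensor helper because its destination is one contiguous cuda_global_chunk; the replacement achieves the same result by passing cuda_global_chunk.chunk(self.pg_size), that is, pg_size equal-sized views into that buffer, as the output list, so the gathered shards land directly in the flat chunk with no extra copy, and async_op can now be forwarded instead of being asserted away. A minimal sketch of the chunk-view trick; names are illustrative, only all_gather_fp8 comes from the PR:

import torch
import torch.distributed as dist
from colossalai.quantization.fp8 import all_gather_fp8

def gather_flat_shard(shard: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    # One flat destination; chunk() returns world_size views into it, so writing
    # into the views fills the buffer in place.
    flat = torch.empty(shard.numel() * world_size, dtype=shard.dtype, device=shard.device)
    all_gather_fp8(list(flat.chunk(world_size)), shard, group, fp8_format="e4m3")
    return flat  # same layout as dist.all_gather_into_tensor(flat, shard, group=group)
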
4 changes: 2 additions & 2 deletions colossalai/zero/low_level/bookkeeping/tensor_bucket.py
@@ -4,7 +4,7 @@
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8
from colossalai.quantization.fp8 import all_gather_fp8


class TensorBucket:
@@ -67,7 +67,7 @@ def all_gather(self, group=None, fp8_communication: bool = False):
flat = self.flatten()
buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype)
if fp8_communication:
all_gather_into_tensor_flat_fp8(buffer, flat, output_shape=buffer.shape, group=group)
all_gather_fp8(buffer.chunk(dist.get_world_size(group)), flat, group=group, fp8_format="e4m3")
else:
dist.all_gather_into_tensor(buffer, flat, group=group)
unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))]
9 changes: 6 additions & 3 deletions colossalai/zero/low_level/low_level_optim.py
@@ -20,7 +20,7 @@
)
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8, all_reduce_fp8, reduce_scatter_fp8
from colossalai.quantization.fp8 import all_gather_fp8, all_reduce_fp8, reduce_scatter_fp8
from colossalai.tensor.moe_tensor.api import is_moe_tensor

from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
@@ -580,8 +580,11 @@ def step(self, closure=None):
else:
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
if self._fp8_communication:
all_gather_into_tensor_flat_fp8(
padded_working_param, param_to_gather, pg, fp8_format="e4m3"
all_gather_fp8(
padded_working_param.chunk(dist.get_world_size(pg)),
param_to_gather,
pg,
fp8_format="e4m3",
)
else:
dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
4 changes: 2 additions & 2 deletions tests/test_fp8/test_fp8_all_to_all.py
@@ -5,7 +5,7 @@

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_to_all_fp8
from colossalai.quantization.fp8 import _all_to_all_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@@ -20,7 +20,7 @@ def check_4gpu(shape, scatter_dim, dtype, fp8_format):
input_tensor_list = [x.contiguous() for x in input_tensor_list]
output_tensor_list_fp8 = [torch.empty_like(x) for x in input_tensor_list]
output_tensor_list = [torch.empty_like(x) for x in input_tensor_list]
all_to_all_fp8(output_tensor_list_fp8, input_tensor_list, group=_get_default_group(), fp8_format=fp8_format)
_all_to_all_fp8(output_tensor_list_fp8, input_tensor_list, group=_get_default_group(), fp8_format=fp8_format)
dist.all_to_all(output_tensor_list, input_tensor_list, group=_get_default_group())
assert_close(output_tensor_list_fp8, output_tensor_list, rtol=0.1, atol=0.1)

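
The test now imports the private _all_to_all_fp8 instead of the public wrapper, presumably so the quantized path is exercised even when all four test GPUs sit on one node and the public API would fall back to a plain collective. The validation pattern is unchanged: run the FP8 collective and the exact one on identical inputs and compare with loose tolerances. A self-contained restatement of that pattern, assuming a 4-rank group is already initialized:

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close
from colossalai.quantization.fp8 import _all_to_all_fp8

def compare_all_to_all(x: torch.Tensor, scatter_dim: int, fp8_format: str = "e5m2") -> None:
    world_size = dist.get_world_size()
    inputs = [t.contiguous() for t in x.chunk(world_size, dim=scatter_dim)]
    out_fp8 = [torch.empty_like(t) for t in inputs]
    out_ref = [torch.empty_like(t) for t in inputs]
    _all_to_all_fp8(out_fp8, inputs, group=_get_default_group(), fp8_format=fp8_format)
    dist.all_to_all(out_ref, inputs, group=_get_default_group())
    # rtol/atol of 0.1 absorb e4m3/e5m2 rounding error, as in the test above.
    assert_close(out_fp8, out_ref, rtol=0.1, atol=0.1)
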
@@ -5,22 +5,13 @@

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import gather_fp8
from colossalai.quantization.fp8 import _all_gather_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@parameterize(
"shape",
[
(3, 7),
(2, 1),
(1, 2),
(2, 2),
(4, 2),
(5,),
(4,),
(2,),
],
[(3, 7, 16)],
)
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
@@ -30,7 +21,9 @@ def check_4gpu(shape, dtype, fp8_format, async_op):
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
output_list = [torch.empty_like(x) for _ in range(world_size)]
output_list_fp8 = [torch.empty_like(x) for _ in range(world_size)]
fp8_handle = gather_fp8(output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op)
fp8_handle = _all_gather_fp8(
output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op
)
origin_hanle = dist.all_gather(output_list, x, group=_get_default_group(), async_op=async_op)
if async_op:
fp8_handle.wait()
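
Beyond the rename to _all_gather_fp8, this test also covers async_op: judging from how Handle objects are built in fp8.py, handle.wait() blocks on the underlying all-gather and then runs the deferred cast back from fp8, so both handles must be waited on before comparing. A sketch of the async pattern; the tolerances are an assumption mirroring the all-to-all test, since the assertion is not visible in the diff:

fp8_handle = _all_gather_fp8(
    output_list_fp8, x, group=_get_default_group(), fp8_format="e4m3", async_op=True
)
origin_handle = dist.all_gather(output_list, x, group=_get_default_group(), async_op=True)
fp8_handle.wait()     # waits on the collective, then casts the gathered fp8 buffers back
origin_handle.wait()
assert_close(output_list_fp8, output_list, rtol=0.1, atol=0.1)
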
43 changes: 0 additions & 43 deletions tests/test_fp8/test_fp8_allgather_flat.py

This file was deleted.
