Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

disable all_to_all_fp8 in intranode #6045

Merged
merged 6 commits into from
Sep 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 73 additions & 6 deletions colossalai/quantization/fp8.py
BurkeHulk marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Any, Optional, Tuple

import numpy as np
Expand All @@ -23,6 +24,24 @@ def wait(self):
self.remain_ops()


def process_group_is_intranode(pg):
if pg is None:
from torch.distributed.distributed_c10d import _get_default_group

pg = _get_default_group()

local_world_size = None
for var in ["LOCAL_WORLD_SIZE", "OMPI_COMM_WORLD_LOCAL_SIZE", "SLURM_TASKS_PER_NODE"]:
if var in os.environ:
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
if local_world_size is None:
local_world_size = torch.cuda.device_count()

group_ranks = dist.get_process_group_ranks(pg)
group_ranks_node_ids = [rank // local_world_size for rank in group_ranks]
return min(group_ranks_node_ids) == max(group_ranks_node_ids)


def cast_to_fp8(
inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False, out=None
) -> Tuple[torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -92,7 +111,7 @@ def cast_from_fp8(
return ret.to(ret_type)


def all_reduce_fp8(
def _all_reduce_fp8(
tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
) -> Optional[Handle]:
r"""
Expand Down Expand Up @@ -159,7 +178,15 @@ def cat_op():
cat_op()


def all_to_all_single_fp8(
def all_reduce_fp8(
tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
) -> Optional[Handle]:
# fall back to default op due to performance issue
return dist.all_reduce(tensor, op=op, group=group, async_op=async_op)


@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
def _all_to_all_single_fp8(
output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
) -> Optional[Handle]:
r"""
Expand Down Expand Up @@ -222,6 +249,33 @@ def cast_op():
cast_op()


def all_to_all_single_fp8(
output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
) -> Optional[Handle]:
r"""
This is wrapper for _all_to_all_single_fp8.
"""
if process_group_is_intranode(group):
return dist.all_to_all_single(
output,
input,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
async_op=async_op,
)
else:
return _all_to_all_single_fp8(
output,
input,
fp8_format=fp8_format,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
async_op=async_op,
)


def cast_to_fp8_pipeline(inp: Any) -> None:
"""
Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline.
Expand Down Expand Up @@ -293,7 +347,7 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
del inp["dtype"]


def reduce_scatter_fp8(
def _reduce_scatter_fp8(
output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
) -> Optional[Handle]:
r"""
Expand Down Expand Up @@ -338,6 +392,13 @@ def cast_op():
cast_op()


def reduce_scatter_fp8(
output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
) -> Optional[Handle]:
# fall back to default op due to performance issue
return dist.reduce_scatter(output, input_list, group=group, async_op=async_op)


def fp8_compress_ddp_grad_comm_hook_async(
process_group: dist.ProcessGroup,
bucket: dist.GradBucket,
Expand Down Expand Up @@ -617,10 +678,9 @@ def cast_op():
cast_op()


def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):

@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
def _all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
world_size = dist.get_world_size(group)

input_type = input_list[0].dtype
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
scale_list = []
Expand Down Expand Up @@ -651,6 +711,13 @@ def cast_op():
cast_op()


def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
if process_group_is_intranode(group):
return dist.all_to_all(output_list, input_list, group=group, async_op=async_op)
else:
return _all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format, async_op=async_op)


def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:

world_size = dist.get_world_size(group)
Expand Down
Loading