
Commit

all_gather only internode, fix pytest
GuangyaoZhang committed Sep 11, 2024
1 parent 13946c4 commit 01e7f59
Showing 4 changed files with 21 additions and 19 deletions.
15 changes: 12 additions & 3 deletions colossalai/quantization/fp8.py
@@ -718,8 +718,8 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async
return _all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format, async_op=async_op)


def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:

@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
def _all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
world_size = dist.get_world_size(group)

input_type = input_.dtype
@@ -743,8 +743,17 @@ def cast_op():
cast_op()


@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
if process_group_is_intranode(group):
return dist.all_gather(output_list, input_, group=group, async_op=async_op)
else:
return _all_gather_fp8(output_list, input_, group=group, fp8_format=fp8_format, async_op=async_op)


@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
def all_gather_fp8_lagacy(
output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False
) -> Optional[Handle]:
world_size = dist.get_world_size(group)
shape = input_.shape
input_type = input_.dtype
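
For orientation, the dispatch added above means `all_gather_fp8` now falls back to a plain `dist.all_gather` for intranode groups and only routes through the compiled FP8 path (`_all_gather_fp8`) for internode groups. The snippet below is a minimal, illustrative usage sketch, not part of the commit: it assumes the process group has already been initialized (e.g. via colossalai's launcher), and the shape, dtype, and "e4m3" format are arbitrary example choices.

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_gather_fp8

# Assumes torch.distributed is already initialized (e.g. through colossalai.launch).
world_size = dist.get_world_size()
x = torch.rand(3, 7, 16, dtype=torch.bfloat16, device=get_accelerator().get_current_device())

# One receive buffer per rank, matching the input's shape and dtype.
output_list = [torch.empty_like(x) for _ in range(world_size)]

# Intranode group: plain dist.all_gather. Internode group: cast to FP8, gather, cast back.
handle = all_gather_fp8(output_list, x, group=_get_default_group(), fp8_format="e4m3", async_op=True)
if handle is not None:
    handle.wait()
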
4 changes: 2 additions & 2 deletions colossalai/shardformer/layer/_operation.py
@@ -17,10 +17,10 @@
_grad_accum_fusion_available = False

from colossalai.quantization.fp8 import (
all_gather_fp8,
all_reduce_fp8,
all_to_all_fp8,
all_to_all_single_fp8,
gather_fp8,
reduce_scatter_fp8,
)

@@ -961,7 +961,7 @@ def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_for
input_ = input_.contiguous()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
if fp8_communication:
gather_fp8(tensor_list, input_, fp8_format=fp8_format, group=process_group)
all_gather_fp8(tensor_list, input_, fp8_format=fp8_format, group=process_group)
else:
dist.all_gather(tensor_list, input_, group=process_group)

4 changes: 2 additions & 2 deletions tests/test_fp8/test_fp8_all_to_all.py
@@ -5,7 +5,7 @@

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_to_all_fp8
from colossalai.quantization.fp8 import _all_to_all_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@@ -20,7 +20,7 @@ def check_4gpu(shape, scatter_dim, dtype, fp8_format):
input_tensor_list = [x.contiguous() for x in input_tensor_list]
output_tensor_list_fp8 = [torch.empty_like(x) for x in input_tensor_list]
output_tensor_list = [torch.empty_like(x) for x in input_tensor_list]
all_to_all_fp8(output_tensor_list_fp8, input_tensor_list, group=_get_default_group(), fp8_format=fp8_format)
_all_to_all_fp8(output_tensor_list_fp8, input_tensor_list, group=_get_default_group(), fp8_format=fp8_format)
dist.all_to_all(output_tensor_list, input_tensor_list, group=_get_default_group())
assert_close(output_tensor_list_fp8, output_tensor_list, rtol=0.1, atol=0.1)

17 changes: 5 additions & 12 deletions
@@ -5,22 +5,13 @@

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import gather_fp8
from colossalai.quantization.fp8 import _all_gather_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@parameterize(
"shape",
[
(3, 7),
(2, 1),
(1, 2),
(2, 2),
(4, 2),
(5,),
(4,),
(2,),
],
[(3, 7, 16)],
)
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
@@ -30,7 +21,9 @@ def check_4gpu(shape, dtype, fp8_format, async_op):
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
output_list = [torch.empty_like(x) for _ in range(world_size)]
output_list_fp8 = [torch.empty_like(x) for _ in range(world_size)]
fp8_handle = gather_fp8(output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op)
fp8_handle = _all_gather_fp8(
output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op
)
origin_hanle = dist.all_gather(output_list, x, group=_get_default_group(), async_op=async_op)
if async_op:
fp8_handle.wait()
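
The tail of this test file is collapsed in the view above. Purely for orientation, ColossalAI FP8 tests of this kind typically wire the 4-GPU check into the project's spawn harness roughly as sketched below; the names run_dist and test_all_gather_fp8 and the exact launch arguments are assumptions, not taken from the diff. The sketch relies only on the imports already visible at the top of this file (launch, spawn, rerun_if_address_is_in_use).

# Hypothetical harness sketch -- names and launch arguments are assumptions, not from the diff.
def run_dist(rank, world_size, port):
    # Bring up one process per GPU and run the parameterized check on each rank.
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_4gpu()


@rerun_if_address_is_in_use()
def test_all_gather_fp8():
    spawn(run_dist, 4)


if __name__ == "__main__":
    test_all_gather_fp8()
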
