[fp8] optimize all-gather #6043

Merged · 4 commits · Sep 3, 2024

Changes from all commits

106 changes: 101 additions & 5 deletions colossalai/quantization/fp8.py
@@ -8,6 +8,7 @@
 from torch.distributed import ReduceOp
 
 SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
+SCALE_BYTES = 4
 
 
 class Handle:
@@ -22,7 +23,9 @@ def wait(self):
             self.remain_ops()
 
 
-def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
+def cast_to_fp8(
+    inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False, out=None
+) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
     Args:
@@ -55,12 +58,15 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
         scale = fp8_max / per_tensor_max
         scale_inv = 1.0 / scale
 
-    ret = (scale * inp.float()).to(fp8_type)
+    if out is not None:
+        ret = torch.mul(scale, inp.float(), out=out)
+    else:
+        ret = (scale * inp.float()).to(fp8_type)
     return ret, torch.unsqueeze(scale_inv, dim=0)
 
 
 def cast_from_fp8(
-    inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False
+    inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False, out=None
 ) -> torch.Tensor:
     r"""
     Args:
@@ -74,9 +80,15 @@ def cast_from_fp8(
         raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.")
 
     if per_channel_scale:
-        ret = scale_inv[:, None] * inp.float()
+        if out is not None:
+            return torch.mul(scale_inv[:, None], inp.float(), out=out)
+        else:
+            ret = scale_inv[:, None] * inp.float()
     else:
-        ret = scale_inv * inp.float()
+        if out is not None:
+            return torch.mul(scale_inv, inp.float(), out=out)
+        else:
+            ret = scale_inv * inp.float()
     return ret.to(ret_type)
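
The change to both cast helpers is the new optional `out=` argument, which writes the quantized or dequantized values into a caller-provided buffer instead of allocating a temporary. A minimal round-trip sketch, not part of the diff, assuming a CUDA build of PyTorch that exposes the float8 dtypes and an importable colossalai:

import torch

from colossalai.quantization.fp8 import cast_from_fp8, cast_to_fp8

x = torch.randn(4, 8, device="cuda")

# Per-tensor quantization: returns the fp8 payload and the float32 inverse scale.
x_fp8, scale_inv = cast_to_fp8(x.view(-1), fp8_format="e4m3")

# Dequantize directly into a preallocated float32 buffer via the new out= path.
buf = torch.empty(x.numel(), device="cuda", dtype=torch.float32)
cast_from_fp8(x_fp8, scale_inv, torch.float32, out=buf)

print((buf.view(4, 8) - x).abs().max())  # small, fp8-sized rounding error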


@@ -664,6 +676,90 @@ def cast_op():
         cast_op()
 
 
+@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
+def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
+    world_size = dist.get_world_size(group)
+    shape = input_.shape
+    input_type = input_.dtype
+    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
+
+    combined_buffer = torch.empty(world_size * (SCALE_BYTES + input_.numel()), dtype=torch.uint8, device=input_.device)
+    combined_buffers = list(combined_buffer.chunk(world_size, dim=0))
+    cur_buffer = combined_buffers[dist.get_rank(group)]
+    ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
+    ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
+    cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
+    # cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
+    dist.all_gather(combined_buffers, cur_buffer, group=group, async_op=async_op)
+    for out, buf in zip(output_list, combined_buffers):
+        scale = buf[:SCALE_BYTES].clone().view(scale.dtype)
+        output = buf[SCALE_BYTES:].view(fp8_type)
+        cast_from_fp8(output.view(shape), scale, input_type, out=out)
+    # output = combined_buffer.view(world_size, -1)[:, SCALE_BYTES:].view(fp8_type)
+    # scales = combined_buffer.view(world_size, -1)[:, :SCALE_BYTES].view(torch.float)
+    # output = output.float() * scales
+    # for i, out in enumerate(output_list):
+    #     out.copy_(output[i].view(shape))
+
+
+@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
+def all_gather_fp8_ring(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
+    world_size = dist.get_world_size(group)
+    rank = dist.get_rank(group)
+
+    send_rank = (rank + 1) % world_size
+    recv_rank = (rank - 1) % world_size
+
+    shape = input_.shape
+    input_type = input_.dtype
+    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
+
+    combined_buffer = torch.empty(world_size * (SCALE_BYTES + input_.numel()), dtype=torch.uint8, device=input_.device)
+    combined_buffers = list(combined_buffer.chunk(world_size, dim=0))
+    cur_buffer = combined_buffers[dist.get_rank(group)]
+    ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
+    ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
+    # cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
+    cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
+
+    def send_recv(idx):
+        send_idx = (rank - idx) % world_size
+        recv_idx = (rank - idx - 1) % world_size
+        ops = dist.batch_isend_irecv(
+            [
+                dist.P2POp(dist.isend, combined_buffers[send_idx], send_rank, group=group),
+                dist.P2POp(dist.irecv, combined_buffers[recv_idx], recv_rank, group=group),
+            ]
+        )
+        return ops
+
+    def cast(idx):
+        cast_idx = (rank - idx - 1) % world_size
+        scale = combined_buffers[cast_idx][:SCALE_BYTES].clone().view(torch.float)
+        output = combined_buffers[cast_idx][SCALE_BYTES:].view(fp8_type)
+        cast_from_fp8(output.view(shape), scale, input_type, out=output_list[cast_idx])
+
+    # warmup
+    ops = send_recv(0)
+    output_list[rank].copy_(input_)
+    for op in ops:
+        op.wait()
+    ops = []
+
+    # 1p-1c
+    for i in range(1, world_size - 1):
+        new_ops = send_recv(i)
+        for op in ops:
+            op.wait()
+        cast(i - 1)
+        ops = new_ops
+
+    # cooldown
+    for op in ops:
+        op.wait()
+    cast(world_size - 2)
+
+
 class _LinearFp8(torch.autograd.Function):
     @staticmethod
     def forward(
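
For reference, a hedged usage sketch of all_gather_fp8, not part of the diff. Each rank packs its contribution as SCALE_BYTES (a single float32 inverse scale) followed by input_.numel() fp8 bytes, so one uint8 all-gather of world_size * (4 + numel) bytes moves both scales and payloads; the loop afterwards dequantizes each received chunk into the matching output tensor. The launcher command, file name, tensor size, and dtype below are illustrative, and the sketch assumes NCCL, one GPU per rank, and a PyTorch/ColossalAI install with fp8 support:

# run with e.g.: torchrun --nproc_per_node=2 demo_all_gather_fp8.py
import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_gather_fp8

dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(rank)

# Each rank contributes a tensor filled with its own rank id.
x = torch.full((1024,), float(rank), device="cuda", dtype=torch.bfloat16)
outputs = [torch.empty_like(x) for _ in range(world_size)]

all_gather_fp8(outputs, x, fp8_format="e4m3")

# outputs[r] should now hold an fp8-rounded copy of rank r's tensor.
for r, out in enumerate(outputs):
    print(f"rank {rank}: chunk {r} max error {(out - float(r)).abs().max().item():.3e}")

dist.destroy_process_group()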
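
The ring variant replaces the single collective with world_size - 1 point-to-point steps and overlaps one transfer with one dequantization per iteration (the "1p-1c" loop). The pure-Python trace below is illustrative only, with no real communication; it prints, for a hypothetical world_size of 4, which chunk each rank sends, receives, and dequantizes at every step, following the index arithmetic of send_recv() and cast() above:

world_size = 4  # illustrative

for rank in range(world_size):
    trace = [f"copy own chunk {rank}"]
    for i in range(world_size - 1):
        send_idx = (rank - i) % world_size      # chunk forwarded to the right neighbour
        recv_idx = (rank - i - 1) % world_size  # chunk arriving from the left neighbour
        trace.append(f"step {i}: send {send_idx}, recv {recv_idx}")
        if i > 0:
            # cast(i - 1) dequantizes the chunk received at step i - 1
            trace.append(f"step {i}: cast {(rank - i) % world_size}")
    # cooldown: cast(world_size - 2) handles the final received chunk
    trace.append(f"cooldown: cast {(rank + 1) % world_size}")
    print(f"rank {rank}: " + " | ".join(trace))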