diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 5b606616e97a..c022fab158c8 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -56,7 +56,7 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) - scale_inv = 1.0 / scale ret = (scale * inp.float()).to(fp8_type) - return ret, scale_inv + return ret, torch.unsqueeze(scale_inv, dim=0) def cast_from_fp8(