[FP8] unsqueeze scale to make it compatible with torch.compile (hpcai…
GuangyaoZhang authored Aug 29, 2024
1 parent 0d3a85d commit e96a076
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion colossalai/quantization/fp8.py
@@ -56,7 +56,7 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
     scale_inv = 1.0 / scale
 
     ret = (scale * inp.float()).to(fp8_type)
-    return ret, scale_inv
+    return ret, torch.unsqueeze(scale_inv, dim=0)
 
 
 def cast_from_fp8(
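The change above keeps `cast_to_fp8` from returning a 0-dim scalar tensor: `torch.unsqueeze(scale_inv, dim=0)` gives the inverse scale shape `(1,)`, which `torch.compile` traces more reliably than a 0-dim tensor. A minimal sketch of the patched behavior is below; the scale computation is reconstructed from the visible lines (the e4m3 max of 448 and the clamp are assumptions, not taken from the diff), and the actual FP8 dtype cast is skipped so the sketch runs on any backend.

```python
import torch

def cast_to_fp8_sketch(inp: torch.Tensor):
    # Stand-in for colossalai's cast_to_fp8: computes a per-tensor scale,
    # but returns a float tensor instead of casting to an FP8 dtype.
    amax = inp.abs().max().clamp(min=1e-12)
    scale = 448.0 / amax          # 448 = e4m3 max value (assumed here)
    scale_inv = 1.0 / scale
    ret = scale * inp.float()
    # The patched line: unsqueeze turns the 0-dim scale_inv into shape (1,),
    # avoiding a scalar tensor in the graph that torch.compile mishandles.
    return ret, torch.unsqueeze(scale_inv, dim=0)

x = torch.randn(4, 8)
_, s_inv = cast_to_fp8_sketch(x)
print(tuple(s_inv.shape))  # (1,)
```

Callers that previously multiplied by a scalar `scale_inv` still work unchanged, since a shape-`(1,)` tensor broadcasts the same way.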
