Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BACKEND] Support mma -> mma conversion on Ampere and Turing architectures #2627

Merged
merged 3 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ getCvtOrder(Attribute srcLayout, Attribute dstLayout) {
auto srcDotLayout = srcLayout.dyn_cast<DotOperandEncodingAttr>();
auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
auto dstDotLayout = dstLayout.dyn_cast<DotOperandEncodingAttr>();
assert(!(srcMmaLayout && dstMmaLayout) &&
"Unexpected mma -> mma layout conversion");

assert(!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere()) &&
"mma -> mma layout conversion is only supported on Ampere");

// mma or dot layout does not have an order, so the order depends on the
// layout of the other operand.
auto inOrd = (srcMmaLayout || srcDotLayout) ? getOrder(dstLayout)
Expand Down
109 changes: 101 additions & 8 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,14 @@ class MmaLayout:

def __init__(self, version, warps_per_cta, ctas_per_cga, cta_split_num, cta_order, instr_shape):
self.version = version
self.warps_per_cta = str(warps_per_cta)
self.warps_per_cta = warps_per_cta
self.ctas_per_cga = str(ctas_per_cga)
self.cta_split_num = str(cta_split_num)
self.cta_order = str(cta_order)
self.instr_shape = str(instr_shape)

def __str__(self):
return f"#{GPU_DIALECT}.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>"
return f"#{GPU_DIALECT}.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={str(self.warps_per_cta)}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>"
Jokeren marked this conversation as resolved.
Show resolved Hide resolved


class BlockedLayout:
Expand Down Expand Up @@ -3714,10 +3714,6 @@ def kernel(Out):
# TODO: backend should be tested separately

layouts = [
# MmaLayout(1, [1, 4], [1, 1], [0, 1]),
# MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]),
# MmaLayout(1, [4, 1], [1, 1], [0, 1]),
# MmaLayout((2, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]),
BlockedLayout([1, 16], [8, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
Expand Down Expand Up @@ -3748,8 +3744,6 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
pytest.skip("Out of bound access when maxPhase > 1")
if str(src_layout) == str(dst_layout):
pytest.skip()
if 'mma' in str(src_layout) and 'mma' in str(dst_layout):
pytest.skip()

layouts = f"""
#src = {src_layout}
Expand Down Expand Up @@ -3811,6 +3805,105 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
assert torch.equal(z, x)


mma_pairs = [
[
MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]),
MmaLayout((2, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]),
],
[
MmaLayout((2, 0), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8]),
MmaLayout((2, 0), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8]),
],
[
MmaLayout((2, 1), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]),
MmaLayout((2, 1), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]),
],
[
MmaLayout((2, 1), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8]),
MmaLayout((2, 1), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8]),
],
# Mma -> mma support is TODO on Hopper (and Volta)
# [
# MmaLayout((3, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8, 16]),
# MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8, 16]),
# ],
# [
# MmaLayout((3, 0), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8, 16]),
# MmaLayout((3, 0), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8, 16]),
# ],
# [
# MmaLayout((3, 1), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8, 16]),
# MmaLayout((3, 1), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8, 16]),
# ],
# [
# MmaLayout((3, 1), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8, 16]),
# MmaLayout((3, 1), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8, 16]),
# ],
]


@pytest.mark.parametrize("M, N", [[64, 1], [1, 64], [64, 64], [128, 128], [256, 256]])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("mma_pair", mma_pairs)
def test_convertmma2mma(M, N, mma_pair, dtype, device):
if is_hip():
pytest.skip("test_mma2mma is not supported in HIP")

src_layout, _ = mma_pair
num_warps = np.cumprod(src_layout.warps_per_cta)[-1]

def do_test(src_layout, dst_layout):
layouts = f"""
#src = {src_layout}
#dst = {dst_layout}
"""

conversion = f"""
%12 = triton_gpu.convert_layout %9 : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #dst>
%13 = triton_gpu.convert_layout %11 : (tensor<{M}x{N}xf16, #src>) -> tensor<{M}x{N}xf16, #dst>
"""

ir = layouts + f"""
module attributes {{"triton_gpu.num-warps" = {num_warps} : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16, 1> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f16, 1> {{tt.divisibility = 16 : i32}}) {{
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
%2 = tt.splat %arg0 : (!tt.ptr<f16, 1>) -> tensor<{M}x{N}x!tt.ptr<f16, 1>, #src>
%4 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
%5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src>
%6 = tt.expand_dims %1 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
%7 = tt.broadcast %6 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%8 = tt.broadcast %5 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src>
%10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<f16, 1>, #src>, tensor<{M}x{N}xi32, #src>
%11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf16, #src>
%3 = tt.splat %arg1 : (!tt.ptr<f16, 1>) -> tensor<{M}x{N}x!tt.ptr<f16, 1>, #dst>
""" + conversion + f"""
%14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr<f16, 1>, #dst>, tensor<{M}x{N}xi32, #dst>
tt.store %14, %13 : tensor<{M}x{N}xf16, #dst>
tt.return
}}
}}
"""

x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device)
z = torch.empty_like(x)

# write the IR to a temporary file using mkstemp
import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
kernel = triton.compile(f.name)
kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr())

assert torch.equal(z, x)

do_test(mma_pair[0], mma_pair[1])
do_test(mma_pair[1], mma_pair[0])


def test_load_scalar_with_mask(device):

@triton.jit
Expand Down
Loading