From 91445cf08780920a4098338b965f7be0406294cd Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 13 Jan 2022 10:49:48 +0900 Subject: [PATCH] [CUTLASS] Support more kernels: int8, tf32, and 3xtf32 (#9899) * add int8 type in library * wip * adding test and plumbing data and weight dtype * adding 3xtf32 support and refactor tile description enum * add 3xtf32 test * update gemm generator too * int8 test worked * 3xtf32 also works * int8 and 3xtf32 gemm works * clean up test * support int8 in sm75 * refined int8 alignment constraints * black * support 3xtf32 in default kernel * remove log * refine dtype check * support tf32 * leave TODO for alignment modification on int8 kernels * tf32 test working * fix default kernel for tf32 * workaround for compilation failure * lint --- python/tvm/contrib/cutlass/build.py | 91 +++++++- python/tvm/contrib/cutlass/gen_conv2d.py | 29 ++- python/tvm/contrib/cutlass/gen_gemm.py | 66 ++++-- python/tvm/contrib/cutlass/gen_tensor_op.py | 210 +++++++++++++------ python/tvm/contrib/cutlass/library.py | 17 ++ python/tvm/relay/op/contrib/cutlass.py | 7 +- src/relay/backend/contrib/cutlass/codegen.cc | 6 +- tests/python/contrib/test_cutlass.py | 135 ++++++++++-- 8 files changed, 445 insertions(+), 116 deletions(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index e921302eafced..c919ff2833436 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -94,12 +94,25 @@ def visit_call(self, call): def select_gemm_kernel( - cutlass_profiler, op_type, MM, KK, NN, out_dtype, batched, profile_all, use_multiprocessing + cutlass_profiler, + op_type, + MM, + KK, + NN, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, + batched, + profile_all, + use_multiprocessing, ): """Run CUTLASS profiler to select the best kernel, or return the default one for dynamic workloads.""" if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]): - out = 
cutlass_profiler.get_default(op_type, out_dtype, batched=batched) + out = cutlass_profiler.get_default( + op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32, batched=batched + ) name, cutlass_op_def = out["name"], out["opdef"] logger.info("Picked the default kernel %s", name) else: @@ -109,6 +122,9 @@ def select_gemm_kernel( NN, KK, out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, batched=batched, profile_all=profile_all, use_multiprocessing=use_multiprocessing, @@ -122,7 +138,16 @@ def select_gemm_kernel( def handle_batch_matmul( - cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all, use_multiprocessing + cutlass_profiler, + op_type, + arg0_shape, + arg1_shape, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, + profile_all, + use_multiprocessing, ): """Profile and select a kernel for batch_matmul op workload.""" MM = arg0_shape[1] @@ -130,7 +155,18 @@ def handle_batch_matmul( NN = arg1_shape[1] name, cutlass_op_def = select_gemm_kernel( - cutlass_profiler, op_type, MM, KK, NN, out_dtype, True, profile_all, use_multiprocessing + cutlass_profiler, + op_type, + MM, + KK, + NN, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, + True, + profile_all, + use_multiprocessing, ) return { @@ -147,7 +183,16 @@ def handle_batch_matmul( def handle_dense( - cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all, use_multiprocessing + cutlass_profiler, + op_type, + arg0_shape, + arg1_shape, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, + profile_all, + use_multiprocessing, ): """Profile and select a kernel for dense op workload.""" MM = arg0_shape[0] @@ -155,7 +200,18 @@ def handle_dense( NN = arg1_shape[0] name, cutlass_op_def = select_gemm_kernel( - cutlass_profiler, op_type, MM, KK, NN, out_dtype, False, profile_all, use_multiprocessing + cutlass_profiler, + op_type, + MM, + KK, + NN, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, + False, + profile_all, + use_multiprocessing, ) assert "tn_align" in 
name, "Only supports (row_major, col_major) input layout for now." @@ -178,12 +234,15 @@ def handle_conv2d( strides, dilation, out_dtype, + data_dtype, + weight_dtype, + use_3xtf32, profile_all, use_multiprocessing, ): """Profile and select a kernel for conv2d op workload.""" if any(isinstance(s, tvm.tir.Any) for s in d_shape): - out = cutlass_profiler.get_default(op_type, out_dtype) + out = cutlass_profiler.get_default(op_type, out_dtype, data_dtype, weight_dtype, use_3xtf32) name, cutlass_op_def = out["name"], out["opdef"] logger.info("Picked the default kernel %s", name) else: @@ -195,6 +254,9 @@ def handle_conv2d( strides, dilation, out_dtype, + data_dtype, + weight_dtype, + use_3xtf32, profile_all=profile_all, use_multiprocessing=use_multiprocessing, ) @@ -209,7 +271,9 @@ def handle_conv2d( } -def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"): +def tune_cutlass_kernels( + mod, sm, use_3xtf32=True, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp" +): """Given a module partitioned for CUTLASS offloading, profile each workload to select which kernels to emit. 
@@ -258,6 +322,8 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t new_attrs.update(func.attrs) arg0_shape = new_attrs["arg0_shape"] arg1_shape = new_attrs["arg1_shape"] + arg0_dtype = new_attrs["arg0_dtype"] + arg1_dtype = new_attrs["arg1_dtype"] if "conv2d" in op_type: new_attrs["padding"] = annotator.op_attrs.padding @@ -273,6 +339,9 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t annotator.op_attrs.strides, annotator.op_attrs.dilation, out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, profile_all, use_multiprocessing, ) @@ -285,6 +354,9 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t arg0_shape, arg1_shape, out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, profile_all, use_multiprocessing, ) @@ -297,6 +369,9 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t arg0_shape, arg1_shape, out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, profile_all, use_multiprocessing, ) diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index 39db9fd013194..c09017adfd956 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -153,8 +153,13 @@ def __init__(self, sm, cutlass_path, binary_path): self.engine = ProfilerEngine(sm, cutlass_path, binary_path) self.cache = {} - def get_default(self, op_type, out_dtype): - gemm_profile_result = self.gemm_profiler.get_default(op_type, out_dtype) + def get_default(self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32): + """Return the default kernel for the requested architecture. For now, the default kernel was picked arbitrarily. 
+ """ + gemm_profile_result = self.gemm_profiler.get_default( + op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32 + ) tile_description = gemm_profile_result["tile_description"] alignment = gemm_profile_result["alignment"] data_type = gemm_profile_result["data_type"] @@ -165,9 +170,10 @@ def get_default(self, op_type, out_dtype): def check_align(self, op_name, C, K): """Filter out kernels that cannot be supported.""" - aligns = re.findall(r"align[1|2|4|8]", op_name) - assert len(aligns) == 1 - align = int(aligns[0][-1]) + match = re.match(".*_align([1-9]+)", op_name) + assert match is not None and len(match.groups()) == 1 + # The same alignment is used for all axes + align = int(match.groups()[0]) return all([dim % align == 0 for dim in [C, K]]) def select_op( @@ -178,6 +184,9 @@ def select_op( stride, dilation, out_dtype, + data_dtype, + weight_dtype, + use_3xtf32, profile_all=True, use_multiprocessing=False, ): @@ -207,9 +216,9 @@ def select_op( return self.cache[workload] ops = GENERATOR_FUNC_TABLE[self.sm]( - out_dtype, - op_creator=enumerate_conv2d_operators, + out_dtype, data_dtype, weight_dtype, enumerate_conv2d_operators, use_3xtf32 ) + ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops)) if profile_all: @@ -240,6 +249,9 @@ def profile( stride, dilation, out_dtype, + data_dtype, + weight_dtype, + use_3xtf32=True, profile_all=True, use_multiprocessing=False, ): @@ -254,6 +266,9 @@ def profile( stride, dilation, out_dtype, + data_dtype, + weight_dtype, + use_3xtf32, profile_all=profile_all, use_multiprocessing=use_multiprocessing, ) diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index 9159ed881c745..445acb9305c84 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -125,13 +125,18 @@ def enumerate_gemm_operators( # TODO(masahi): A sensible way to pick reasonable default kernels DEFAULT_KERNELS = { 75: { - "float16": 
"cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1", - "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1", + ("float16", "float16"): "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1", + ("float16", "float32"): "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1", }, # align1 variants do not seem to be available for sm80 80: { - "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1", - "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1", + ("float16", "float16"): "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1", + ("float16", "float32"): "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1", + # two kernels for tf32 and 3xtf32 + ("float32", "float32"): ( + "cutlass_tensorop_s1688gemm_128x64_32x3_tn_align1", + "cutlass_tensorop_s1688gemm_64x64_16x3_tn_align1", + ), }, } @@ -147,21 +152,31 @@ def __init__(self, sm, cutlass_path, binary_path): def check_align(self, op_name, M, N, K): """Filter out kernels that cannot be supported.""" - aligns = re.findall(r"align[1|2|4|8]", op_name) - assert len(aligns) == 1 + match = re.match(".*_align([1-9]+)", op_name) + assert match is not None and len(match.groups()) == 1 # The same alignment is used for all axes - align = int(aligns[0][-1]) + align = int(match.groups()[0]) # TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive. # See https://github.com/NVIDIA/cutlass/issues/362. # When the above issue is resolved, we can remove the alignment check on M below. return all([dim % align == 0 for dim in [M, N, K]]) - def get_default(self, op_type, out_dtype, batched=False): + def get_default( + self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32=True, batched=False + ): """Return the default kernel for the requested architecture. For now, the default kernel was picked arbitrary. 
""" - ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=enumerate_gemm_operators) - default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype] + ops = GENERATOR_FUNC_TABLE[self.sm]( + out_dtype, arg0_dtype, arg1_dtype, enumerate_gemm_operators, use_3xtf32 + ) + default_kernel_name = DEFAULT_KERNELS[self.sm][(arg0_dtype, out_dtype)] + + if arg0_dtype == "float32": + default_kernel_name = ( + default_kernel_name[0] if not use_3xtf32 else default_kernel_name[1] + ) + filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops)) assert len(filtered) == 1 op = filtered[0] @@ -176,7 +191,18 @@ def get_default(self, op_type, out_dtype, batched=False): op.update({"name": name, "opdef": opdef}) return op - def select_op(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False): + def select_op( + self, + M, + N, + K, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, + profile_all=True, + use_multiprocessing=False, + ): """ Profile and select the best kernel from candidate kernels. See the documentation for the profile method below. @@ -187,7 +213,10 @@ def select_op(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=Fa ops = GENERATOR_FUNC_TABLE[self.sm]( out_dtype, - op_creator=enumerate_gemm_operators, + arg0_dtype, + arg1_dtype, + enumerate_gemm_operators, + use_3xtf32=use_3xtf32, ) ops = list(filter(lambda op: self.check_align(op["name"], M, N, K), ops)) @@ -212,6 +241,9 @@ def profile( N, K, out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32=True, profile_all=True, use_multiprocessing=False, batched=False, @@ -221,7 +253,15 @@ def profile( If use_multiprocessing is True, compile all profiler executables in parallel. 
""" op = self.select_op( - M, N, K, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing + M, + N, + K, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, + profile_all=profile_all, + use_multiprocessing=use_multiprocessing, ) name, opdef = create_gemm_operator_with_epilogue( diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 6632b159febdd..6bb4f290233e4 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -33,6 +33,14 @@ logger = logging.getLogger("cutlass") +dtype_map = { + "int8": DataType.s8, + "uint8": DataType.u8, + "float32": DataType.f32, + "float16": DataType.f16, +} + + def generate_tensor_op_common( math_instructions, alignment_constraints, get_tile_descriptions, op_creator ): @@ -54,45 +62,62 @@ def generate_tensor_op_common( return ops -def generate_sm75_tensor_op_1688(out_dtype, op_creator): +def generate_sm75_tensor_op_1688(out_dtype, arg0_dtype, arg1_dtype, op_creator): """Generate GEMM or Conv2D kernels for Turing.""" - assert out_dtype in ["float32", "float16"] - math_instructions = { - "float32": [ + assert out_dtype in ["float32", "float16", "int32"] + min_cc = 75 + max_cc = 1024 + + if arg0_dtype == "float16" and arg1_dtype == "float16": + math_instructions = [ MathInstruction( [16, 8, 8], DataType.f16, DataType.f16, - DataType.f32, + dtype_map[out_dtype], OpcodeClass.TensorOp, MathOperation.multiply_add, ) - ], - "float16": [ + ] + alignment_constraints = [8, 4, 2, 1] + tile_descriptions = [ + ([256, 128, 32], 2, [4, 2, 1], min_cc, max_cc), + ([128, 256, 32], 2, [2, 4, 1], min_cc, max_cc), + ([128, 128, 32], 2, [2, 2, 1], min_cc, max_cc), + ([64, 128, 32], 2, [2, 2, 1], min_cc, max_cc), + ([128, 64, 32], 2, [2, 2, 1], min_cc, max_cc), + ([64, 64, 32], 2, [2, 2, 1], min_cc, max_cc), + ([64, 128, 64], 2, [1, 2, 2], min_cc, max_cc), + ] + + else: + assert out_dtype == "int32" + math_instructions = [ 
MathInstruction( - [16, 8, 8], - DataType.f16, - DataType.f16, - DataType.f16, + [8, 8, 16], + dtype_map[arg0_dtype], + dtype_map[arg1_dtype], + DataType.s32, OpcodeClass.TensorOp, - MathOperation.multiply_add, - ) - ], - }[out_dtype] - - alignment_constraints = [8, 4, 2, 1] + MathOperation.multiply_add_saturate, + ), + ] + alignment_constraints = [16, 8, 4, 2, 1] + tile_descriptions = [ + ([256, 128, 64], 2, [4, 2, 1], min_cc, max_cc), + ([128, 256, 64], 2, [2, 4, 1], min_cc, max_cc), + ([128, 128, 64], 2, [2, 2, 1], min_cc, max_cc), + ([64, 256, 64], 2, [1, 4, 1], min_cc, max_cc), + ([256, 64, 64], 2, [4, 1, 1], min_cc, max_cc), + ([64, 128, 64], 2, [2, 2, 1], min_cc, max_cc), + ([128, 64, 64], 2, [2, 2, 1], min_cc, max_cc), + ([64, 64, 64], 2, [2, 2, 1], min_cc, max_cc), + ] def get_tile_descriptions(math_inst): - min_cc = 75 - max_cc = 1024 return [ - TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 64], 2, [1, 2, 2], math_inst, min_cc, max_cc), + TileDescription(threadblock_shape, stages, warp_count, math_inst, min_cc, max_cc) + for threadblock_shape, stages, warp_count, min_cc, max_cc in tile_descriptions ] return generate_tensor_op_common( @@ -100,63 +125,117 @@ def get_tile_descriptions(math_inst): ) -def generate_sm80_tensor_op_16816(out_dtype, op_creator): +def generate_sm80_tensor_op_16816(out_dtype, arg0_dtype, arg1_dtype, op_creator, use_3xtf32=True): """Generate GEMM or Conv2D kernels for Ampere.""" - assert out_dtype in ["float32", "float16"] - math_instructions = { - "float32": [ + min_cc = 80 + max_cc = 1024 + 
max_cc_smem_limited = 80 + + def get_default_tile_descriptions(block_k_factor): + return [ + ([256, 128, int(32 * block_k_factor)], 3, [4, 2, 1], min_cc, max_cc), + ([128, 256, int(32 * block_k_factor)], 3, [2, 4, 1], min_cc, max_cc), + ([256, 64, int(32 * block_k_factor)], 4, [4, 1, 1], min_cc, max_cc), + ([64, 256, int(32 * block_k_factor)], 4, [1, 4, 1], min_cc, max_cc), + ([128, 128, int(32 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc), + ([128, 128, int(32 * block_k_factor)], 4, [2, 2, 1], min_cc, max_cc), + ([128, 128, int(32 * block_k_factor)], 5, [2, 2, 1], min_cc, max_cc), + ([128, 64, int(32 * block_k_factor)], 6, [2, 2, 1], min_cc, max_cc), + ([64, 128, int(32 * block_k_factor)], 6, [2, 2, 1], min_cc, max_cc), + ([64, 64, int(32 * block_k_factor)], 10, [2, 2, 1], min_cc, max_cc), + ([256, 128, int(64 * block_k_factor)], 3, [4, 2, 1], min_cc, max_cc_smem_limited), + ([128, 256, int(64 * block_k_factor)], 3, [2, 4, 1], min_cc, max_cc_smem_limited), + ([256, 64, int(64 * block_k_factor)], 4, [4, 1, 1], min_cc, max_cc_smem_limited), + ([64, 256, int(64 * block_k_factor)], 4, [1, 4, 1], min_cc, max_cc_smem_limited), + ([128, 128, int(64 * block_k_factor)], 4, [2, 2, 1], min_cc, max_cc), + ([128, 64, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc), + ([64, 128, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc), + ([64, 64, int(64 * block_k_factor)], 5, [2, 2, 1], min_cc, max_cc), + ] + + if arg0_dtype == "float16" and arg1_dtype == "float16": + math_instructions = [ MathInstruction( [16, 8, 16], DataType.f16, DataType.f16, - DataType.f32, + dtype_map[out_dtype], OpcodeClass.TensorOp, MathOperation.multiply_add, ) - ], - "float16": [ + ] + alignment_constraints = [8, 4, 2] + tile_descriptions = get_default_tile_descriptions(1) + elif arg0_dtype == "float32" and arg1_dtype == "float32": + math_instructions = [ MathInstruction( - [16, 8, 16], - DataType.f16, - DataType.f16, - DataType.f16, + [16, 8, 8], + DataType.f32, + DataType.f32, + 
DataType.f32, OpcodeClass.TensorOp, - MathOperation.multiply_add, - ) - ], - }[out_dtype] + MathOperation.multiply_add_fast_f32 if use_3xtf32 else MathOperation.multiply_add, + ), + ] + alignment_constraints = [4, 2, 1] - alignment_constraints = [8, 4, 2] + if use_3xtf32: + tile_descriptions = [ + ([128, 128, 16], 4, [4, 2, 1], min_cc, max_cc), + ([128, 128, 16], 3, [4, 2, 1], min_cc, max_cc), + ([256, 64, 16], 3, [4, 2, 1], min_cc, max_cc), + ([64, 256, 16], 3, [2, 4, 1], min_cc, max_cc), + ([128, 64, 16], 4, [2, 2, 1], min_cc, max_cc), + ([64, 128, 16], 4, [2, 2, 1], min_cc, max_cc), + ([64, 64, 16], 3, [2, 2, 1], min_cc, max_cc), + ([128, 128, 32], 3, [4, 2, 1], min_cc, max_cc), + ([256, 64, 32], 3, [4, 2, 1], min_cc, max_cc_smem_limited), + ([64, 256, 32], 3, [2, 4, 1], min_cc, max_cc_smem_limited), + ([128, 64, 32], 3, [2, 2, 1], min_cc, max_cc), + ([64, 128, 32], 3, [2, 2, 1], min_cc, max_cc), + ([64, 64, 32], 3, [2, 2, 1], min_cc, max_cc), + ] + else: + tile_descriptions = get_default_tile_descriptions(0.5) + else: + assert out_dtype == "int32" + math_instructions = [ + MathInstruction( + [16, 8, 32], + dtype_map[arg0_dtype], + dtype_map[arg1_dtype], + DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_saturate, + ), + ] + alignment_constraints = [16, 8, 4] + tile_descriptions = get_default_tile_descriptions(2) def get_tile_descriptions(math_inst): - min_cc = 80 - max_cc = 1024 - max_cc_smem_limited = 80 return [ - TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), - 
TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription(threadblock_shape, stages, warp_count, math_inst, min_cc, max_cc) + for threadblock_shape, stages, warp_count, min_cc, max_cc in tile_descriptions ] - sm75_kernels = generate_sm75_tensor_op_1688(out_dtype, op_creator) + if arg0_dtype != "float32" and arg1_dtype != "float32": + sm75_kernels = generate_sm75_tensor_op_1688(out_dtype, arg0_dtype, arg1_dtype, op_creator) + else: + # TF32 (float32 + float32 case) is only supported on sm80 + sm75_kernels = [] + sm80_kernels = generate_tensor_op_common( math_instructions, alignment_constraints, get_tile_descriptions, op_creator ) + + # TODO(masahi): For int8 kernels, The CUTLASS generator modifies the output tensor alignment + # after ops are created. Revisit how important this modification is. 
+ # for op in operations: + # if op.tile_description.threadblock_shape[1] >= 128: + # op.C.alignment = 16 + # else: + # op.C.alignment = 8 + return sm75_kernels + sm80_kernels @@ -227,6 +306,9 @@ def evaluate(self, op, args): opath = os.path.join(self.binary_prefix, op_name) if not os.path.exists(opath): self._compile(op) + if not os.path.exists(opath): + # Bail out if compilation fails for whatever reason (e.g. static assert failure) + return float("inf") cmd = [opath] for arg in args: cmd.append(str(arg)) diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py index 08cdb323c126b..5d986f4d03a71 100644 --- a/python/tvm/contrib/cutlass/library.py +++ b/python/tvm/contrib/cutlass/library.py @@ -28,36 +28,53 @@ class GeneratorTarget(enum.Enum): class DataType(enum.Enum): f16 = enum_auto() f32 = enum_auto() + s8 = enum_auto() + u8 = enum_auto() + s32 = enum_auto() ShortDataTypeNames = { DataType.f16: "h", DataType.f32: "s", + DataType.s32: "i", } DataTypeNames = { DataType.f16: "f16", DataType.f32: "f32", + DataType.s8: "s8", + DataType.u8: "u8", + DataType.s32: "s32", } DataTypeTag = { DataType.f16: "cutlass::half_t", DataType.f32: "float", + DataType.s8: "int8_t", + DataType.s32: "int32_t", + DataType.u8: "uint8_t", } DataTypeSize = { DataType.f16: 16, DataType.f32: 32, + DataType.u8: 8, + DataType.s8: 8, + DataType.s32: 32, } class MathOperation(enum.Enum): multiply_add = enum_auto() + multiply_add_saturate = enum_auto() + multiply_add_fast_f32 = enum_auto() MathOperationTag = { MathOperation.multiply_add: "cutlass::arch::OpMultiplyAdd", + MathOperation.multiply_add_saturate: "cutlass::arch::OpMultiplyAddSaturate", + MathOperation.multiply_add_fast_f32: "cutlass::arch::OpMultiplyAddFastF32", } diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py index 31f0408c0f042..2cc61923d4b27 100644 --- a/python/tvm/relay/op/contrib/cutlass.py +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -105,8 
+105,11 @@ def make_residual_block_pattern(tensor_op_out, binary_op="add", with_act="relu") def check_dtype(lhs, rhs): """Check if dtypes in the given workload are supported by CUTLASS.""" - # Only fp16 inputs are supported for now. - return lhs.dtype == rhs.dtype and lhs.dtype == "float16" and rhs.dtype == "float16" + return ( + (lhs.dtype == "float16" and rhs.dtype == "float16") + or (lhs.dtype == "float32" and rhs.dtype == "float32") + or (lhs.dtype in ["int8", "uint8"] and rhs.dtype in ["int8", "uint8"]) + ) def get_root_call(call, root_op_name): diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc index dc03eea014abe..0a945793b7753 100644 --- a/src/relay/backend/contrib/cutlass/codegen.cc +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -43,7 +43,11 @@ namespace contrib { using namespace backend; using Str2StrMap = std::unordered_map; -static Str2StrMap dtype_map = {{"float16", "cutlass::half_t"}, {"float32", "float"}}; +static Str2StrMap dtype_map = {{"float16", "cutlass::half_t"}, + {"float32", "float"}, + {"int8", "int8_t"}, + {"uint8", "uint8_t"}, + {"int32", "int32_t"}}; constexpr const char* kAnyDim = "Any"; diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index 54738ddd772b6..57f2f39c641b4 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -68,14 +68,16 @@ def get_output_vm(vm, names, inputs): return vm.invoke("main", **params).numpy() -def get_dense_with_shape(data_shape, weight_shape, out_dtype="float16"): - data = relay.var("data", shape=data_shape, dtype="float16") - weight = relay.var("weight", shape=weight_shape, dtype="float16") +def get_dense_with_shape( + data_shape, weight_shape, out_dtype="float16", data_dtype="float16", weight_dtype="float16" +): + data = relay.var("data", shape=data_shape, dtype=data_dtype) + weight = relay.var("weight", shape=weight_shape, dtype=weight_dtype) return 
relay.nn.dense(data, weight, out_dtype=out_dtype) -def get_dense(M, N, K, out_dtype="float16"): - return get_dense_with_shape((M, K), (N, K), out_dtype) +def get_dense(M, N, K, out_dtype="float16", data_dtype="float16", weight_dtype="float16"): + return get_dense_with_shape((M, K), (N, K), out_dtype, data_dtype, weight_dtype) def get_dense_bias(M, N, K, out_dtype="float16"): @@ -110,9 +112,11 @@ def get_batch_matmul(batch, M, N, K, out_dtype="float16"): return get_batch_matmul_with_shape((batch, M, K), (batch, N, K), out_dtype="float16") -def get_conv2d_nchw(d_shape, w_shape, padding, out_dtype="float16"): - data = relay.var("data", shape=d_shape, dtype="float16") - weight = relay.var("weight", shape=w_shape, dtype="float16") +def get_conv2d_nchw( + d_shape, w_shape, padding, out_dtype="float16", data_dtype="float16", weight_dtype="float16" +): + data = relay.var("data", shape=d_shape, dtype=data_dtype) + weight = relay.var("weight", shape=w_shape, dtype=weight_dtype) out_channel = w_shape[0] return relay.nn.conv2d( data=data, @@ -176,10 +180,17 @@ def get_conv2d_nchw_bias_residual(d_shape, w_shape, padding, out_dtype="float16" return bias_add, data -def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so", use_fast_math=False): +def profile_and_build( + mod, params, sm, tmp_dir="./tmp", lib_path="compile.so", use_fast_math=False, use_3xtf32=True +): mod = partition_for_cutlass(mod) mod, num_cutlass_partition = tune_cutlass_kernels( - mod, sm, profile_all=False, use_multiprocessing=False, tmp_dir=tmp_dir + mod, + sm, + use_3xtf32=use_3xtf32, + profile_all=False, + use_multiprocessing=False, + tmp_dir=tmp_dir, ) with tvm.transform.PassContext(opt_level=3): lib = relay.build(mod, target="cuda", params=params) @@ -197,9 +208,12 @@ def profile_and_build_vm( lib_path="compile.so", vmcode_path="vmcode.ro", use_fast_math=False, + use_3xtf32=True, ): mod = partition_for_cutlass(mod) - mod, num_cutlass_partition = tune_cutlass_kernels(mod, sm, 
tmp_dir=tmp_dir) + mod, num_cutlass_partition = tune_cutlass_kernels( + mod, sm, use_3xtf32=use_3xtf32, tmp_dir=tmp_dir + ) with tvm.transform.PassContext(opt_level=3): vm_exec = relay.vm.compile(mod, target="cuda", params=params) vm_exec = build_cutlass_kernels_vm( @@ -210,7 +224,18 @@ def profile_and_build_vm( def verify_dense( - func, M, N, K, ref_target="cuda", sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False + func, + M, + N, + K, + ref_target="cuda", + sm=80, + atol=1e-5, + rtol=1e-5, + run_benchmark=False, + data_dtype="float16", + weight_dtype="float16", + use_3xtf32=True, ): if not has_cutlass(): return @@ -218,9 +243,9 @@ def verify_dense( typ = relay.transform.InferType()(mod)["main"].body.checked_type out_dtype = typ.dtype use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape) - np_data = np.random.uniform(-1, 1, (M, K)).astype("float16") - np_weight = np.random.uniform(-1, 1, (N, K)).astype("float16") - np_bias = np.random.uniform(-1, 1, (N,)).astype(out_dtype) + np_data = get_random_ndarray((M, K), data_dtype) + np_weight = get_random_ndarray((N, K), weight_dtype) + np_bias = get_random_ndarray((N,), out_dtype) params = {"weight": np_weight, "bias": np_bias} @@ -235,7 +260,9 @@ def verify_dense( ) return else: - rt_mod, dev, num_partition = profile_and_build_vm(mod, params, sm) + rt_mod, dev, num_partition = profile_and_build_vm( + mod, params, sm, use_3xtf32=use_3xtf32 + ) rt_mod_ref, dev = get_ref_vm(mod, params, target=ref_target) x = tvm.nd.array(np_data, device=dev) @@ -243,7 +270,7 @@ def verify_dense( ref_out = get_output_vm(rt_mod_ref, ["data"], [x]) else: rt_mod_ref, dev = get_ref_rt_mod(mod, params, target=ref_target) - rt_mod, dev, num_partition = profile_and_build(mod, params, sm) + rt_mod, dev, num_partition = profile_and_build(mod, params, sm, use_3xtf32=use_3xtf32) x = tvm.nd.array(np_data, device=dev) out = get_output(rt_mod, ["data"], [x]) ref_out = get_output(rt_mod_ref, ["data"], [x]) @@ -302,6 +329,33 @@ def test_dense(): 
verify_dense(get_dense(M, N, K, out_dtype="float32"), M, N, K) # Test align1 case verify_dense(get_dense_bias(M, N + 1, K), M, N + 1, K) + # int8 + verify_dense( + get_dense(M, N, K, "int32", "int8", "int8"), M, N, K, data_dtype="int8", weight_dtype="int8" + ) + + dense_fp32 = get_dense(M, N, K, "float32", "float32", "float32") + # tf32 + verify_dense( + dense_fp32, + M, + N, + K, + data_dtype="float32", + weight_dtype="float32", + use_3xtf32=False, + atol=1e-2, + rtol=1e-2, + ) + # 3xtf32 + verify_dense( + dense_fp32, + M, + N, + K, + data_dtype="float32", + weight_dtype="float32", + ) def test_dense_bias(): @@ -371,6 +425,14 @@ def convert_conv2d_layout(mod, desired_layouts): return seq(mod) +def get_random_ndarray(shape, dtype): + if dtype == "int8": + return np.random.randint(-128, 128, shape).astype(dtype) + elif dtype == "uint8": + return np.random.randint(0, 256, shape).astype(dtype) + return np.random.uniform(-1, 1, shape).astype(dtype) + + def verify_conv2d( expr_nchw, # can be dynamic batch expr_ref, # always static batch @@ -382,6 +444,9 @@ def verify_conv2d( use_cudnn_ref=False, run_benchmark=False, use_fast_math=False, + data_dtype="float16", + weight_dtype="float16", + ref_target="cuda", ): if not has_cutlass(): return @@ -392,9 +457,9 @@ def verify_conv2d( typ = relay.transform.InferType()(mod_nchw)["main"].body.checked_type out_dtype = typ.dtype - np_data = np.random.uniform(-1, 1, d_shape).astype("float16") - np_weight = np.random.uniform(-1, 1, w_shape).astype("float16") - np_bias = np.random.uniform(-1, 1, (w_shape[0],)).astype(out_dtype) + np_data = get_random_ndarray(d_shape, data_dtype) + np_weight = get_random_ndarray(w_shape, weight_dtype) + np_bias = get_random_ndarray((w_shape[0],), out_dtype) params = {"weight": np_weight, "bias": np_bias} @@ -426,7 +491,7 @@ def verify_conv2d( rt_mod_ref, dev = get_ref_rt_mod( convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "HWIO"]}), params, - target="cuda", + target=ref_target, ) ref_out = 
get_output(rt_mod_ref, ["data"], [np_data]) @@ -469,6 +534,34 @@ def test_conv2d(): mod_dyn, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False ) + for data_dtype, weight_dtype, out_dtype in [ + ("float32", "float32", "float32"), # 3xtf32 + ("int8", "int8", "int32"), + ("uint8", "int8", "int32"), + ]: + expr = get_conv2d_nchw( + d_shape, + w_shape, + padding, + out_dtype=out_dtype, + data_dtype=data_dtype, + weight_dtype=weight_dtype, + ) + + verify_conv2d( + expr, + expr, + d_shape, + w_shape, + sm=80, + atol=1e-5, + rtol=1e-5, + run_benchmark=False, + data_dtype=data_dtype, + weight_dtype=weight_dtype, + ref_target="llvm", + ) + def test_conv2d_fusion(): d_shape = (16, 16, 32, 32)