diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py
index f7317646265b1..ee9017439d8f9 100644
--- a/python/tvm/contrib/cutlass/gen_conv2d.py
+++ b/python/tvm/contrib/cutlass/gen_conv2d.py
@@ -18,6 +18,12 @@
 """Conv2d kernel generator and profiler for CUTLASS."""
 from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
 from .conv2d_profiler import Conv2dProfilerEmitter
+from .gemm_profiler import GemmProfilerEmitter
+from .gen_tensor_op import (
+    ProfilerEngine,
+    generate_sm75_tensor_op_1688,
+    generate_sm80_tensor_op_16816,
+)
 from .library import (
     EpilogueFunctor,
     SwizzlingFunctor,
@@ -36,7 +42,6 @@
 
 
 def create_conv2d_operator(
-    layout,
     tile_descriptions,
     data_type,
     alignment_constraints,
@@ -51,6 +56,7 @@ def create_conv2d_operator(
     element_a, element_b, element_c, element_epilogue = data_type
     iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized]
 
+    layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
     for tile in tile_descriptions:
         for alignment in alignment_constraints:
 
@@ -105,3 +111,81 @@ def create_conv2d_operator(
             ret.append(op_entry)
 
     return ret
+
+
+GENERATOR_FUNC_TABLE = {
+    75: generate_sm75_tensor_op_1688,
+    80: generate_sm80_tensor_op_16816,
+}
+
+
+# TODO(masahi): A sensible way to pick reasonable default kernels
+DEFAULT_KERNELS = {
+    75: {
+        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align4",
+        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align4",
+    },
+    80: {
+        "float16": "cutlass_tensorop_h16816gemm_128x256_32x3_tn_align4",
+        "float32": "cutlass_tensorop_s16816gemm_f16_128x128_32x3_tn_align4",
+    },
+}
+
+
+class CutlassConv2DProfiler:
+    """Profile all candidate kernels and select the best one."""
+
+    def __init__(self, sm, cutlass_path, binary_path):
+        assert sm in GENERATOR_FUNC_TABLE and sm in DEFAULT_KERNELS, "sm%d not supported yet." % sm
+        self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
+        self.sm = sm
+        self.cache = {}
+
+    def check_align(self, op_name, M):
+        """Filter out kernels that cannot be supported."""
+        # TODO
+        return True
+
+    def get_default(self, out_dtype):
+        """Return the default kernel for the requested architecture.
+        For now, the default kernel was picked arbitrarily.
+        """
+        ops = GENERATOR_FUNC_TABLE[self.sm](
+            out_dtype, op_creator=create_conv2d_operator
+        )
+        default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
+        filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops))
+        assert len(filtered) == 1
+        return filtered[0]
+
+    def profile(
+        self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False
+    ):
+        """Profile and select the best kernel from candidate kernels.
+        If profile_all is False, return immediately after the first applicable kernel is found.
+        If use_multiprocessing is True, compile all profiler executables in parallel.
+        """
+        if (M, N, K) in self.cache:
+            return self.cache[(M, N, K)]
+
+        ops = GENERATOR_FUNC_TABLE[self.sm](
+            out_dtype, op_creator=create_conv2d_operator
+        )
+        ops = list(filter(lambda op: self.check_align(op["name"], M), ops))
+
+        for op in ops:
+            op["runtime"] = -1
+
+        if profile_all:
+            self.engine.compile_all(ops, use_multiprocessing)
+
+        for op in ops:
+            out = self.engine.evaluate(op, [M, N, K])
+            op["runtime"] = out
+            if out > 0 and profile_all is False:
+                break
+
+        valid_ops = filter(lambda op: op["runtime"] > 0, ops)
+        output = sorted(valid_ops, key=lambda i: i["runtime"])
+        self.cache[(M, N, K)] = output[0]
+        return output[0]
diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py
index 525e06d447df6..9dee41c82794e 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_gemm.py
@@ -43,7 +43,6 @@
 
 
 def _create_gemm_operator(
-    layouts,
     tile_descriptions,
     data_type,
     alignment_constraints,
@@ -60,6 +59,10 @@ def _create_gemm_operator(
     if batched:
         swizzling_functor = SwizzlingFunctor.Batched
 
+    layouts = [
+        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
+    ]
+
     for layout in layouts:
         for tile_description in tile_descriptions:
             for alignment in alignment_constraints:
@@ -135,15 +138,14 @@ def _create_gemm_operator(
 
 
 def create_gemm_operator(batched):
+    # TODO: replace with partial
     def op_creator(
-        layouts,
         tile_descriptions,
         data_type,
         alignment_constraints,
         swizzling_functor=SwizzlingFunctor.Identity8,
     ):
         return _create_gemm_operator(
-            layouts,
             tile_descriptions,
             data_type,
             alignment_constraints,
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 20d7637679847..1fee256a731dd 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -43,9 +43,6 @@ def generate_tensor_op_common(
 ):
     """Common kernel generator to be used by archtecture specific generators."""
     ops = []
-    layouts = [
-        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
-    ]
     for math_inst in math_instructions:
         tile_descriptions = get_tile_descriptions(math_inst)
         data_type = [
@@ -55,7 +52,7 @@ def generate_tensor_op_common(
             math_inst.element_accumulator,
         ]
 
-        out = op_creator(layouts, tile_descriptions, data_type, alignment_constraints)
+        out = op_creator(tile_descriptions, data_type, alignment_constraints)
         ops.extend(out)