From adf560ebed8465c22bf58f406d0a8d20663cdd1d Mon Sep 17 00:00:00 2001
From: masahi 
Date: Fri, 26 Nov 2021 08:53:55 +0900
Subject: [PATCH] [CUTLASS] Refactor GEMM generator in preparation for conv2d (#9571)

* split non-gemm specific generator code to gen_tensor_op.py

commit 250f915652e72e0012e9aa6ce0b6ef337d3da845
Author: Masahiro Masuda 
Date:   Sun Nov 14 06:44:52 2021 +0900

    remove conv2d stuff

commit 1a6b27c438472f13acd4a0f466d78f293415e076
Author: Masahiro Masuda 
Date:   Sun Nov 14 06:41:31 2021 +0900

    remove unused import

commit f7c3b5a191b8c73e8b178c32f6d3182fb0f697d6
Author: Masahiro Masuda 
Date:   Sun Nov 14 06:37:07 2021 +0900

    add profiler boilerplate for conv2d

commit ca1ae274fb8f96a1dcde688deaf15339fe5604fb
Author: Masahiro Masuda 
Date:   Sun Nov 14 06:22:06 2021 +0900

    introduce gen_tensor_op.py

commit 37bb918e0873f04457c29479eb21a530b7052217
Author: Masahiro Masuda 
Date:   Sun Nov 14 05:45:41 2021 +0900

    more conv2d code

commit 5c00398892c99cb2a03be51f75878992663432dd
Author: Masahiro Masuda 
Date:   Sun Nov 14 05:13:30 2021 +0900

    Begin conv2d support

* fix

* use functools.partial

* remove unused import
---
 python/tvm/contrib/cutlass/gen_gemm.py      | 230 ++------------------
 python/tvm/contrib/cutlass/gen_tensor_op.py | 217 ++++++++++++++++++
 tests/python/contrib/test_cutlass.py        |   2 +-
 3 files changed, 238 insertions(+), 211 deletions(-)
 create mode 100644 python/tvm/contrib/cutlass/gen_tensor_op.py

diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py
index 1ed4bfe5fc4c..4025354dc739 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_gemm.py
@@ -15,37 +15,29 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=invalid-name
-"""Kernel generator and profiler for CUTLASS."""
-import logging
-import os
+"""GEMM kernel generator and profiler for CUTLASS."""
+from functools import partial
 import re
-import tempfile
-import subprocess
-import multiprocessing
 from .gemm_operation import GemmOperation, EmitGemmInstance
 from .gemm_profiler import GemmProfilerEmitter
+from .gen_tensor_op import (
+    ProfilerEngine,
+    generate_sm75_tensor_op_1688,
+    generate_sm80_tensor_op_16816,
+)
 from .library import (
     EpilogueFunctor,
     SwizzlingFunctor,
     TensorDescription,
     DataTypeTag,
     LayoutType,
-    MathInstruction,
-    DataType,
-    OpcodeClass,
-    MathOperation,
-    TileDescription,
 )
 
-logger = logging.getLogger("cutlass")
-
 
 def create_gemm_operator(
-    layouts,
     tile_descriptions,
     data_type,
     alignment_constraints,
-    epilogue_functor=EpilogueFunctor.LinearCombination,
     swizzling_functor=SwizzlingFunctor.Identity8,
     batched=False,
 ):
@@ -59,6 +51,10 @@ def create_gemm_operator(
     if batched:
         swizzling_functor = SwizzlingFunctor.Batched
 
+    layouts = [
+        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
+    ]
+
     for layout in layouts:
         for tile_description in tile_descriptions:
             for alignment in alignment_constraints:
@@ -76,7 +72,7 @@ def create_gemm_operator(
                     B,
                     C,
                     element_epilogue,
-                    epilogue_functor,
+                    EpilogueFunctor.LinearCombination,
                     swizzling_functor,
                 )
                 op_bias = GemmOperation(
@@ -110,7 +106,6 @@ def create_gemm_operator(
                     swizzling_functor,
                 )
 
-                kernel_emitter = EmitGemmInstance()
                 op_entry["op"] = op
                 op_entry["name"] = op.procedural_name()
                 op_entry["opdef"] = kernel_emitter.emit(op, batched=batched)
@@ -134,141 +129,12 @@ def create_gemm_operator(
     return ret
 
 
-def generate_tensor_op_common(
-    math_instructions, alignment_constraints, get_tile_descriptions, batched=False
-):
-    """Common kernel 
generator to be used by archtecture specific generators.""" - ops = [] - layouts = [ - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), - ] - for math_inst in math_instructions: - tile_descriptions = get_tile_descriptions(math_inst) - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - out = create_gemm_operator( - layouts, tile_descriptions, data_type, alignment_constraints, batched=batched - ) - - ops.extend(out) - - return ops - - -def generate_sm75_tensor_op_1688(out_dtype, batched=False): - """Generate GEMM kernels for Turing.""" - assert out_dtype in ["float32", "float16"] - math_instructions = { - "float32": [ - MathInstruction( - [16, 8, 8], - DataType.f16, - DataType.f16, - DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add, - ) - ], - "float16": [ - MathInstruction( - [16, 8, 8], - DataType.f16, - DataType.f16, - DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add, - ) - ], - }[out_dtype] - - alignment_constraints = [8, 4, 2, 1] - - def get_tile_descriptions(math_inst): - min_cc = 75 - max_cc = 1024 - return [ - TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 64], 2, [1, 2, 2], math_inst, min_cc, max_cc), - ] - - return generate_tensor_op_common( - math_instructions, alignment_constraints, get_tile_descriptions, batched - ) - - -def generate_sm80_tensor_op_16816(out_dtype, batched=False): - """Generate GEMM kernels for Ampere.""" - assert out_dtype in ["float32", "float16"] - math_instructions = { - "float32": [ - MathInstruction( - [16, 8, 16], - DataType.f16, - DataType.f16, - DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add, - ) - ], - "float16": [ - MathInstruction( - [16, 8, 16], - DataType.f16, - DataType.f16, - DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add, - ) - ], - }[out_dtype] - - alignment_constraints = [8, 4, 2] - - def get_tile_descriptions(math_inst): - min_cc = 80 - max_cc = 1024 - max_cc_smem_limited = 80 - return [ - TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([64, 256, 64], 4, [1, 4, 1], math_inst, 
min_cc, max_cc_smem_limited), - TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), - ] - - return generate_tensor_op_common( - math_instructions, alignment_constraints, get_tile_descriptions, batched - ) - - GENERATOR_FUNC_TABLE = { 75: generate_sm75_tensor_op_1688, 80: generate_sm80_tensor_op_16816, } + # TODO(masahi): A sensible way to pick reasonable default kernels DEFAULT_KERNELS = { 75: { @@ -282,67 +148,7 @@ def get_tile_descriptions(math_inst): } -class ProfilerEngine: - """Compile and run a given profiler executable.""" - - def __init__(self, cuda_arch, cutlass_path, binary_prefix): - self.cuda_arch = cuda_arch - self.binary_prefix = binary_prefix - self.cutlass = cutlass_path - self.cflags = "-I{cutlass}/include -I{cutlass}/tools/util/include -O3 -std=c++11".format( - cutlass=cutlass_path - ) - self.cflags += " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1" - self.cflags += " -gencode=arch=compute_{arch},code=[sm_{arch},compute_{arch}]".format( - arch=cuda_arch - ) - self.cflags += " -Xcompiler=-Wconversion -Xcompiler=-fno-strict-aliasing" - self.cmd = "nvcc {cflags} {src} -o {output}" - - def _compile(self, op): - os.makedirs(self.binary_prefix, exist_ok=True) - opath = os.path.join(self.binary_prefix, op["name"]) - if os.path.exists(opath): - return - fi = tempfile.NamedTemporaryFile("w", delete=False, suffix=".cu") - fi.write(op["src"]) - fi.close() - cmd = self.cmd.format(cflags=self.cflags, src=fi.name, output=opath) - os.system(cmd) - os.unlink(fi.name) - - def compile_all(self, ops, use_multiprocessing=False): - """Compile all profiler executables.""" - if use_multiprocessing: - pool = multiprocessing.Pool(multiprocessing.cpu_count()) - pool.map(self._compile, ops) - else: - for op in ops: - self._compile(op) - - def evaluate(self, op, args): - """Run the profiler executable corresponding to op_name with args.""" - op_name = op["name"] - opath = os.path.join(self.binary_prefix, op_name) - if not os.path.exists(opath): - self._compile(op) - cmd = [opath] - if args is not None: - cmd.append(str(args[0])) - cmd.append(str(args[1])) - cmd.append(str(args[2])) - if len(args) > 3: - cmd.append(str(args[3])) - try: - sp = subprocess.run(cmd, capture_output=True, check=True) - rt = float(sp.stdout) - logger.info("%s, %f", op_name, rt) - except subprocess.CalledProcessError: - rt = -1 - return rt - - -class CutlassGemmProfiler(object): +class CutlassGemmProfiler: """Profile all candidate kernels and select the best one.""" def __init__(self, sm, cutlass_path, binary_path): @@ -364,7 +170,9 @@ def get_default(self, out_dtype, batched=False): """Return the default kernel for the requested architecture. For now, the default kernel was picked arbitrary. 
""" - ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched) + ops = GENERATOR_FUNC_TABLE[self.sm]( + out_dtype, op_creator=partial(create_gemm_operator, batched=batched) + ) default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype] filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops)) assert len(filtered) == 1 @@ -380,7 +188,9 @@ def profile( if (M, N, K) in self.cache: return self.cache[(M, N, K)] - ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched) + ops = GENERATOR_FUNC_TABLE[self.sm]( + out_dtype, op_creator=partial(create_gemm_operator, batched=batched) + ) ops = list(filter(lambda op: self.check_align(op["name"], M), ops)) for op in ops: diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py new file mode 100644 index 000000000000..c8221514ce0a --- /dev/null +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -0,0 +1,217 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Common functions and classes for CUTLASS GEMM and Conv2d geneator.""" +import logging +import os +import tempfile +import subprocess +import multiprocessing +from .library import ( + MathInstruction, + DataType, + OpcodeClass, + MathOperation, + TileDescription, +) + +logger = logging.getLogger("cutlass") + + +def generate_tensor_op_common( + math_instructions, alignment_constraints, get_tile_descriptions, op_creator +): + """Common kernel generator to be used by archtecture specific generators.""" + ops = [] + for math_inst in math_instructions: + tile_descriptions = get_tile_descriptions(math_inst) + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + out = op_creator(tile_descriptions, data_type, alignment_constraints) + + ops.extend(out) + + return ops + + +def generate_sm75_tensor_op_1688(out_dtype, op_creator): + """Generate GEMM or Conv2D kernels for Turing.""" + assert out_dtype in ["float32", "float16"] + math_instructions = { + "float32": [ + MathInstruction( + [16, 8, 8], + DataType.f16, + DataType.f16, + DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ) + ], + "float16": [ + MathInstruction( + [16, 8, 8], + DataType.f16, + DataType.f16, + DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ) + ], + }[out_dtype] + + alignment_constraints = [8, 4, 2, 1] + + def get_tile_descriptions(math_inst): + min_cc = 75 + max_cc = 1024 + return [ + TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + 
TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 64], 2, [1, 2, 2], math_inst, min_cc, max_cc), + ] + + return generate_tensor_op_common( + math_instructions, alignment_constraints, get_tile_descriptions, op_creator + ) + + +def generate_sm80_tensor_op_16816(out_dtype, op_creator): + """Generate GEMM or Conv2D kernels for Ampere.""" + assert out_dtype in ["float32", "float16"] + math_instructions = { + "float32": [ + MathInstruction( + [16, 8, 16], + DataType.f16, + DataType.f16, + DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ) + ], + "float16": [ + MathInstruction( + [16, 8, 16], + DataType.f16, + DataType.f16, + DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ) + ], + }[out_dtype] + + alignment_constraints = [8, 4, 2] + + def get_tile_descriptions(math_inst): + min_cc = 80 + max_cc = 1024 + max_cc_smem_limited = 80 + return [ + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + return generate_tensor_op_common( + math_instructions, alignment_constraints, get_tile_descriptions, op_creator + ) + + +class ProfilerEngine: + """Compile and run a given profiler executable.""" + + def __init__(self, cuda_arch, cutlass_path, binary_prefix): + self.cuda_arch = cuda_arch + self.binary_prefix = binary_prefix + self.cutlass = cutlass_path + self.cflags = "-I{cutlass}/include -I{cutlass}/tools/util/include -O3 -std=c++11".format( + cutlass=cutlass_path + ) + self.cflags += " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1" + self.cflags += " -gencode=arch=compute_{arch},code=[sm_{arch},compute_{arch}]".format( + arch=cuda_arch + ) + self.cflags += " -Xcompiler=-Wconversion -Xcompiler=-fno-strict-aliasing" + self.cmd = "nvcc {cflags} {src} -o {output}" + + def _compile(self, op): + os.makedirs(self.binary_prefix, exist_ok=True) + opath = os.path.join(self.binary_prefix, op["name"]) + if os.path.exists(opath): + return + fi = tempfile.NamedTemporaryFile("w", delete=False, suffix=".cu") + fi.write(op["src"]) + fi.close() + cmd = self.cmd.format(cflags=self.cflags, src=fi.name, output=opath) + os.system(cmd) + 
os.unlink(fi.name) + + def compile_all(self, ops, use_multiprocessing=False): + """Compile all profiler executables.""" + if use_multiprocessing: + pool = multiprocessing.Pool(multiprocessing.cpu_count()) + pool.map(self._compile, ops) + else: + for op in ops: + self._compile(op) + + def evaluate(self, op, args): + """Run the profiler executable corresponding to op_name with args.""" + op_name = op["name"] + opath = os.path.join(self.binary_prefix, op_name) + if not os.path.exists(opath): + self._compile(op) + cmd = [opath] + if args is not None: + cmd.append(str(args[0])) + cmd.append(str(args[1])) + cmd.append(str(args[2])) + if len(args) > 3: + cmd.append(str(args[3])) + try: + sp = subprocess.run(cmd, capture_output=True, check=True) + rt = float(sp.stdout) + logger.info("%s, %f", op_name, rt) + except subprocess.CalledProcessError: + rt = -1 + return rt diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index 5a1ff8b2c17d..6f27d57d95d7 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -213,7 +213,7 @@ def verify_batch_matmul( np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol) - if True: + if run_benchmark: print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600)) print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600))
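
-- 
Usage note (an illustrative sketch, not part of the patch): after this
refactor, the architecture-keyed generators in gen_tensor_op.py are
op-agnostic. They take an op_creator callback, and gen_gemm.py binds
create_gemm_operator with functools.partial, as CutlassGemmProfiler.profile
does above; a conv2d op_creator is expected to plug in the same way. The sm
and dtype values below are illustrative, and running this requires a TVM
build with CUTLASS support.

    from functools import partial

    from tvm.contrib.cutlass.gen_gemm import GENERATOR_FUNC_TABLE, create_gemm_operator

    sm = 80  # 80 selects generate_sm80_tensor_op_16816, 75 the Turing generator
    ops = GENERATOR_FUNC_TABLE[sm](
        "float16", op_creator=partial(create_gemm_operator, batched=False)
    )
    # Each entry is a dict carrying, among other fields, the kernel's
    # procedural name ("name"), its emitted definition ("opdef"), and the
    # standalone profiler source that ProfilerEngine compiles ("src").
    print(ops[0]["name"])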