diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py index 27e2137d615d2..af81a7e864026 100644 --- a/python/tvm/contrib/cutlass/conv2d_operation.py +++ b/python/tvm/contrib/cutlass/conv2d_operation.py @@ -55,24 +55,11 @@ def __init__( self.stride_support = stride_support self.swizzling_functor = swizzling_functor - # - def is_complex(self): - complex_operators = [ - MathOperation.multiply_add_complex, - MathOperation.multiply_add_complex_gaussian, - ] - return self.tile_description.math_instruction.math_operation in complex_operators - # def accumulator_type(self): - accum = self.tile_description.math_instruction.element_accumulator + return self.tile_description.math_instruction.element_accumulator - if self.is_complex(): - return get_complex_from_real(accum) - return accum - - # def core_name(self): """ The basic operation kind is prefixed with a letter indicating the accumulation type. """ @@ -112,7 +99,7 @@ def extended_name(self): else: extended_name = "${core_name}" - extended_name = SubstituteTemplate( + extended_name = substitute_template( extended_name, { "element_a": DataTypeNames[self.A.element], @@ -145,7 +132,7 @@ def configuration_name(self): else: configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}" - return SubstituteTemplate( + return substitute_template( configuration_name, { "opcode_class": opcode_class_name, @@ -258,14 +245,12 @@ def emit(self, operation): operation.iterator_algorithm ].capitalize(), "stride_support": StrideSupportTag[operation.stride_support], - "math_operator": "cutlass::arch::OpMultiplyAddComplex" - if operation.is_complex() - else MathOperationTag[operation.tile_description.math_instruction.math_operation], + "math_operator": MathOperationTag[operation.tile_description.math_instruction.math_operation], "align_a": str(operation.A.alignment), "align_b": str(operation.B.alignment), } - return 
SubstituteTemplate(self.template, values) + return substitute_template(self.template, values) class EmitConv2dConfigurationLibrary: @@ -340,7 +325,7 @@ def __init__(self, operation_path, configuration_name): def __enter__(self): self.configuration_file = open(self.configuration_path, "w") self.configuration_file.write( - SubstituteTemplate( + substitute_template( self.header_template, {"configuration_name": self.configuration_name} ) ) @@ -351,7 +336,7 @@ def __enter__(self): def emit(self, operation): self.operations.append(operation) self.configuration_file.write( - SubstituteTemplate( + substitute_template( self.instance_template, { "configuration_name": self.configuration_name, @@ -365,14 +350,14 @@ def emit(self, operation): def __exit__(self, exception_type, exception_value, traceback): self.configuration_file.write( - SubstituteTemplate( + substitute_template( self.configuration_header, {"configuration_name": self.configuration_name} ) ) for operation in self.operations: self.configuration_file.write( - SubstituteTemplate( + substitute_template( self.configuration_instance, { "configuration_name": self.configuration_name, diff --git a/python/tvm/contrib/cutlass/conv2d_profiler.py b/python/tvm/contrib/cutlass/conv2d_profiler.py new file mode 100644 index 0000000000000..eccb13729982a --- /dev/null +++ b/python/tvm/contrib/cutlass/conv2d_profiler.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
class Conv2dProfilerEmitter:
    """Emit the source used to profile a CUTLASS conv2d kernel.

    Source generation is delegated to a ``GemmProfilerEmitter`` instance, which
    receives exactly the arguments given to :meth:`emit`.
    """

    def __init__(self):
        self.gemm_profiler_emitter = GemmProfilerEmitter()

    def emit(self, op_name, op_def, dtype_a, dtype_b, dtype_c, ld):
        """Return the profiler source produced by the wrapped GEMM emitter.

        Parameters
        ----------
        op_name, op_def, dtype_a, dtype_b, dtype_c, ld
            Forwarded unchanged, and in this order, to the ``emit`` method of
            the wrapped emitter.

        Returns
        -------
        The value returned by the wrapped emitter's ``emit``.
        """
        # Go through the emitter's ``emit`` method rather than calling the
        # instance itself, which fails unless the class defines ``__call__``.
        # NOTE(review): assumes GemmProfilerEmitter exposes ``emit`` with this
        # signature - confirm against gemm_profiler.py.
        return self.gemm_profiler_emitter.emit(op_name, op_def, dtype_a, dtype_b, dtype_c, ld)
def create_conv2d_operator(
    layout,
    tile_descriptions,
    data_type,
    alignment_constraints,
    swizzling_functor=SwizzlingFunctor.Identity4,
):
    """Exhaustively instantiate all kernels from a given configuration.

    Parameters
    ----------
    layout : sequence
        Layouts of the A, B and C tensors, read at indices 0, 1 and 2.
    tile_descriptions : iterable
        Tile descriptions; each must expose ``minimum_compute_capability``.
    data_type : sequence of four elements
        ``(element_a, element_b, element_c, element_epilogue)``.
    alignment_constraints : iterable of int
        Alignments applied to A and B. The alignment of C is capped at 8.
    swizzling_functor : SwizzlingFunctor
        Swizzling functor handed to every generated operation.

    Returns
    -------
    list of dict
        One entry per (tile, alignment, iterator algorithm) combination with
        the keys ``"opdef"``, ``"opdef_bias"``, ``"opdef_bias_relu"``, ``"op"``,
        ``"name"``, ``"src"`` and ``"runtime"``.
    """
    ret = []

    profiler_emitter = Conv2dProfilerEmitter()
    kernel_emitter = EmitConv2dInstance()

    element_a, element_b, element_c, element_epilogue = data_type
    iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized]

    # Entry key under which each epilogue variant's emitted definition is stored.
    epilogue_variants = (
        ("opdef", EpilogueFunctor.LinearCombination),
        ("opdef_bias", EpilogueFunctor.LinearCombinationBias),
        ("opdef_bias_relu", EpilogueFunctor.LinearCombinationRelu),
    )

    for tile in tile_descriptions:
        for alignment in alignment_constraints:
            alignment_c = min(8, alignment)

            A = TensorDescription(element_a, layout[0], alignment)
            B = TensorDescription(element_b, layout[1], alignment)
            C = TensorDescription(element_c, layout[2], alignment_c)

            for iterator_algorithm in iterator_algorithms:
                op_entry = {}
                operations = {}

                for opdef, epilogue in epilogue_variants:
                    operation = Conv2dOperation(
                        ConvKind.Fprop,
                        iterator_algorithm,
                        tile.minimum_compute_capability,
                        tile,
                        A,
                        B,
                        C,
                        element_epilogue,
                        StrideSupport.Strided,
                        epilogue,
                        swizzling_functor,
                    )
                    operations[opdef] = operation
                    op_entry[opdef] = kernel_emitter.emit(operation)

                # Use the operation object itself for naming and profiling. The
                # value stored under "opdef" is the emitted definition, which has
                # no procedural_name() or leading_dim().
                op = operations["opdef"]
                op_entry["op"] = op
                op_entry["name"] = op.procedural_name()
                # NOTE(review): relies on Conv2dOperation providing leading_dim();
                # that method is not visible here - confirm.
                op_entry["src"] = profiler_emitter.emit(
                    op.procedural_name(),
                    op_entry["opdef"],
                    DataTypeTag[element_a],
                    DataTypeTag[element_b],
                    DataTypeTag[element_c],
                    op.leading_dim(),
                )
                op_entry["runtime"] = 9999999

                ret.append(op_entry)

    return ret
+StrideSupportNames = { + StrideSupport.Strided: '', + StrideSupport.Unity: 'unity_stride', +} + + class IteratorAlgorithm(enum.Enum): Analytic = enum_auto() Optimized = enum_auto()