Skip to content

Commit

Permalink
more conv2d code
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Nov 13, 2021
1 parent 5c00398 commit 37bb918
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 29 deletions.
33 changes: 9 additions & 24 deletions python/tvm/contrib/cutlass/conv2d_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,24 +55,11 @@ def __init__(
self.stride_support = stride_support
self.swizzling_functor = swizzling_functor

#
def is_complex(self):
complex_operators = [
MathOperation.multiply_add_complex,
MathOperation.multiply_add_complex_gaussian,
]
return self.tile_description.math_instruction.math_operation in complex_operators

#
def accumulator_type(self):
accum = self.tile_description.math_instruction.element_accumulator
return self.tile_description.math_instruction.element_accumulator

if self.is_complex():
return get_complex_from_real(accum)

return accum

#
def core_name(self):
""" The basic operation kind is prefixed with a letter indicating the accumulation type. """

Expand Down Expand Up @@ -112,7 +99,7 @@ def extended_name(self):
else:
extended_name = "${core_name}"

extended_name = SubstituteTemplate(
extended_name = substitute_template(
extended_name,
{
"element_a": DataTypeNames[self.A.element],
Expand Down Expand Up @@ -145,7 +132,7 @@ def configuration_name(self):
else:
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}"

return SubstituteTemplate(
return substitute_template(
configuration_name,
{
"opcode_class": opcode_class_name,
Expand Down Expand Up @@ -258,14 +245,12 @@ def emit(self, operation):
operation.iterator_algorithm
].capitalize(),
"stride_support": StrideSupportTag[operation.stride_support],
"math_operator": "cutlass::arch::OpMultiplyAddComplex"
if operation.is_complex()
else MathOperationTag[operation.tile_description.math_instruction.math_operation],
"math_operator": MathOperationTag[operation.tile_description.math_instruction.math_operation],
"align_a": str(operation.A.alignment),
"align_b": str(operation.B.alignment),
}

return SubstituteTemplate(self.template, values)
return substitute_template(self.template, values)


class EmitConv2dConfigurationLibrary:
Expand Down Expand Up @@ -340,7 +325,7 @@ def __init__(self, operation_path, configuration_name):
def __enter__(self):
self.configuration_file = open(self.configuration_path, "w")
self.configuration_file.write(
SubstituteTemplate(
substitute_template(
self.header_template, {"configuration_name": self.configuration_name}
)
)
Expand All @@ -351,7 +336,7 @@ def __enter__(self):
def emit(self, operation):
self.operations.append(operation)
self.configuration_file.write(
SubstituteTemplate(
substitute_template(
self.instance_template,
{
"configuration_name": self.configuration_name,
Expand All @@ -365,14 +350,14 @@ def emit(self, operation):
def __exit__(self, exception_type, exception_value, traceback):

self.configuration_file.write(
SubstituteTemplate(
substitute_template(
self.configuration_header, {"configuration_name": self.configuration_name}
)
)

for operation in self.operations:
self.configuration_file.write(
SubstituteTemplate(
substitute_template(
self.configuration_instance,
{
"configuration_name": self.configuration_name,
Expand Down
27 changes: 27 additions & 0 deletions python/tvm/contrib/cutlass/conv2d_profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-outside-toplevel, invalid-name
"""Instantiate a C++ source for profiling CUTLASS kernels."""
from .gemm_profiler import GemmProfilerEmitter


class Conv2dProfilerEmitter:
    """Emit profiler source for conv2d kernels by delegating to the GEMM profiler emitter."""

    def __init__(self):
        # All of the actual source generation is delegated to this emitter.
        self.gemm_profiler_emitter = GemmProfilerEmitter()

    def emit(self, op_name, op_def, dtype_a, dtype_b, dtype_c, ld):
        """Return the profiler source generated by the wrapped GEMM emitter.

        Every argument is forwarded unchanged, in the same order, to
        ``GemmProfilerEmitter.emit``.

        NOTE(review): ``GemmProfilerEmitter`` is defined in ``gemm_profiler.py``,
        outside this view; confirm it exposes ``emit`` with this signature.
        """
        # The emitter object is not invoked directly; its entry point is emit().
        return self.gemm_profiler_emitter.emit(op_name, op_def, dtype_a, dtype_b, dtype_c, ld)
107 changes: 107 additions & 0 deletions python/tvm/contrib/cutlass/gen_conv2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Conv2d kernel generator and profiler for CUTLASS."""
from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
from .conv2d_profiler import Conv2dProfilerEmitter
from .library import (
EpilogueFunctor,
SwizzlingFunctor,
TensorDescription,
DataTypeTag,
LayoutType,
MathInstruction,
DataType,
OpcodeClass,
MathOperation,
TileDescription,
ConvKind,
IteratorAlgorithm,
StrideSupport,
)


def create_conv2d_operator(
    layout,
    tile_descriptions,
    data_type,
    alignment_constraints,
    swizzling_functor=SwizzlingFunctor.Identity4,
):
    """Exhaustively instantiate all kernels from a given configuration.

    Parameters
    ----------
    layout : sequence
        Indexable with 0, 1 and 2; the entries are used as the layouts of the
        A, B and C tensor descriptions respectively.
    tile_descriptions : iterable
        Tile descriptions; each must provide ``minimum_compute_capability``.
    data_type : sequence of length 4
        Unpacked as ``(element_a, element_b, element_c, element_epilogue)``.
    alignment_constraints : iterable of int
        Alignments to instantiate. A and B use the value as is; C uses
        ``min(8, alignment)``.
    swizzling_functor : SwizzlingFunctor
        Swizzling functor passed to every generated operation.

    Returns
    -------
    list of dict
        One entry per (tile, alignment, iterator algorithm) combination with
        the keys ``"opdef"``, ``"opdef_bias"``, ``"opdef_bias_relu"`` (emitted
        kernel definitions), ``"op"`` (the ``Conv2dOperation`` built with the
        plain linear-combination epilogue), ``"name"``, ``"src"`` and
        ``"runtime"``.
    """
    ret = []

    profiler_emitter = Conv2dProfilerEmitter()
    kernel_emitter = EmitConv2dInstance()

    element_a, element_b, element_c, element_epilogue = data_type
    iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized]

    # Each epilogue variant is emitted under its own key in the result entry.
    epilogue_variants = [
        (EpilogueFunctor.LinearCombination, "opdef"),
        (EpilogueFunctor.LinearCombinationBias, "opdef_bias"),
        (EpilogueFunctor.LinearCombinationRelu, "opdef_bias_relu"),
    ]

    for tile in tile_descriptions:
        for alignment in alignment_constraints:

            alignment_c = min(8, alignment)

            A = TensorDescription(element_a, layout[0], alignment)
            B = TensorDescription(element_b, layout[1], alignment)
            C = TensorDescription(element_c, layout[2], alignment_c)

            for iterator_algorithm in iterator_algorithms:
                op_entry = {}
                # Keep the operation objects apart from the emitted source
                # strings: the name and leading dimension must be queried on
                # the operation, not on the generated kernel definition.
                operations = {}

                for epilogue, opdef in epilogue_variants:
                    operation = Conv2dOperation(
                        ConvKind.Fprop,
                        iterator_algorithm,
                        tile.minimum_compute_capability,
                        tile,
                        A,
                        B,
                        C,
                        element_epilogue,
                        StrideSupport.Strided,
                        epilogue,
                        swizzling_functor,
                    )

                    operations[opdef] = operation
                    op_entry[opdef] = kernel_emitter.emit(operation)

                # The plain linear-combination kernel represents this entry.
                op = operations["opdef"]
                op_entry["op"] = op
                op_entry["name"] = op.procedural_name()
                # NOTE(review): assumes Conv2dOperation provides leading_dim();
                # confirm against conv2d_operation.py.
                op_entry["src"] = profiler_emitter.emit(
                    op.procedural_name(),
                    op_entry["opdef"],
                    DataTypeTag[element_a],
                    DataTypeTag[element_b],
                    DataTypeTag[element_c],
                    op.leading_dim(),
                )
                # Large placeholder value; presumably replaced once the kernel
                # has been profiled -- confirm against the profiler code.
                op_entry["runtime"] = 9999999

                ret.append(op_entry)

    return ret
8 changes: 3 additions & 5 deletions python/tvm/contrib/cutlass/gen_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Kernel generator and profiler for CUTLASS."""
"""GEMM kernel generator and profiler for CUTLASS."""
import logging
import os
import re
Expand Down Expand Up @@ -45,7 +45,6 @@ def create_gemm_operator(
tile_descriptions,
data_type,
alignment_constraints,
epilogue_functor=EpilogueFunctor.LinearCombination,
swizzling_functor=SwizzlingFunctor.Identity8,
batched=False,
):
Expand Down Expand Up @@ -76,7 +75,7 @@ def create_gemm_operator(
B,
C,
element_epilogue,
epilogue_functor,
EpilogueFunctor.LinearCombination,
swizzling_functor,
)
op_bias = GemmOperation(
Expand Down Expand Up @@ -110,7 +109,6 @@ def create_gemm_operator(
swizzling_functor,
)

kernel_emitter = EmitGemmInstance()
op_entry["op"] = op
op_entry["name"] = op.procedural_name()
op_entry["opdef"] = kernel_emitter.emit(op, batched=batched)
Expand Down Expand Up @@ -342,7 +340,7 @@ def evaluate(self, op, args):
return rt


class CutlassGemmProfiler(object):
class CutlassGemmProfiler:
"""Profile all candidate kernels and select the best one."""

def __init__(self, sm, cutlass_path, binary_path):
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/contrib/cutlass/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,22 @@ class ConvKind(enum.Enum):
}


class StrideSupport(enum.Enum):
    """Stride-support modes selectable for a conv2d kernel.

    NOTE(review): the member names mirror ``cutlass::conv::StrideSupport``
    (``kStrided`` / ``kUnity``); presumably they select between kernels that
    handle arbitrary strides and kernels restricted to unit stride -- confirm
    against the CUTLASS documentation.
    """

    Strided = enum_auto()
    Unity = enum_auto()


# Maps each StrideSupport member to the fully qualified CUTLASS C++ enumerator
# name, as a string, for substitution into generated kernel source.
StrideSupportTag = {
StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided',
StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity',
}

# Maps each StrideSupport member to a short name string; Strided maps to the
# empty string. NOTE(review): presumably used when composing kernel names --
# confirm against the callers.
StrideSupportNames = {
StrideSupport.Strided: '',
StrideSupport.Unity: 'unity_stride',
}


class IteratorAlgorithm(enum.Enum):
Analytic = enum_auto()
Optimized = enum_auto()
Expand Down

0 comments on commit 37bb918

Please sign in to comment.