Skip to content

Commit

Permalink
add profiler boilerplate for conv2d
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Nov 13, 2021
1 parent ca1ae27 commit f7c3b5a
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 8 deletions.
86 changes: 85 additions & 1 deletion python/tvm/contrib/cutlass/gen_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
"""Conv2d kernel generator and profiler for CUTLASS."""
from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
from .conv2d_profiler import Conv2dProfilerEmitter
from .gemm_profiler import GemmProfilerEmitter
from .gen_tensor_op import (
    ProfilerEngine,
    generate_sm75_tensor_op_1688,
    generate_sm80_tensor_op_16816,
)
from .library import (
EpilogueFunctor,
SwizzlingFunctor,
Expand All @@ -36,7 +42,6 @@


def create_conv2d_operator(
layout,
tile_descriptions,
data_type,
alignment_constraints,
Expand All @@ -51,6 +56,7 @@ def create_conv2d_operator(
element_a, element_b, element_c, element_epilogue = data_type
iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized]

layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
for tile in tile_descriptions:
for alignment in alignment_constraints:

Expand Down Expand Up @@ -105,3 +111,81 @@ def create_conv2d_operator(
ret.append(op_entry)

return ret


# Map a CUDA compute capability (sm 75 / sm 80) to the generator that emits
# all candidate tensor-op kernels for that architecture.
GENERATOR_FUNC_TABLE = {
    75: generate_sm75_tensor_op_1688,
    80: generate_sm80_tensor_op_16816,
}


# TODO(masahi): A sensible way to pick reasonable default kernels
# Fallback kernel name per (compute capability, output dtype), looked up by
# CutlassConv2DProfiler.get_default when profiling is not performed.
# NOTE(review): these names are "gemm" kernel names inside a conv2d module —
# confirm they are valid defaults for the conv2d generator.
DEFAULT_KERNELS = {
    75: {
        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align4",
        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align4",
    },
    80: {
        "float16": "cutlass_tensorop_h16816gemm_128x256_32x3_tn_align4",
        "float32": "cutlass_tensorop_s16816gemm_f16_128x128_32x3_tn_align4",
    },
}


class CutlassConv2DProfiler:
    """Profile all candidate kernels and select the best one.

    Parameters
    ----------
    sm : int
        CUDA compute capability (e.g. 75, 80). Must have an entry in both
        GENERATOR_FUNC_TABLE and DEFAULT_KERNELS.
    cutlass_path : str
        Path to the CUTLASS source tree, forwarded to ProfilerEngine.
    binary_path : str
        Directory where compiled profiler executables are placed.
    """

    def __init__(self, sm, cutlass_path, binary_path):
        assert sm in GENERATOR_FUNC_TABLE and sm in DEFAULT_KERNELS, "sm%d not supported yet." % sm
        self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
        self.sm = sm
        # Best-kernel cache keyed on (M, N, K, out_dtype): the same problem
        # size can select a different kernel for a different output dtype.
        self.cache = {}

    def check_align(self, op_name, M):
        """Filter out kernels that cannot be supported."""
        # TODO: reject kernels whose alignment requirement M does not satisfy.
        return True

    def get_default(self, out_dtype):
        """Return the default kernel for the requested architecture.

        For now, the default kernel was picked arbitrarily.
        """
        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=create_conv2d_operator)
        default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
        filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops))
        # The generated kernel set must contain the hard-coded default exactly once.
        assert len(filtered) == 1
        return filtered[0]

    def profile(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False):
        """Profile and select the best kernel from candidate kernels.

        If profile_all is False, return immediately after the first applicable
        kernel is found. If use_multiprocessing is True, compile all profiler
        executables in parallel.

        Raises
        ------
        RuntimeError
            If no candidate kernel runs successfully for this workload.
        """
        # out_dtype is part of the key: caching on (M, N, K) alone would return
        # a kernel generated for a different output dtype.
        cache_key = (M, N, K, out_dtype)
        if cache_key in self.cache:
            return self.cache[cache_key]

        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=create_conv2d_operator)
        ops = list(filter(lambda op: self.check_align(op["name"], M), ops))

        # -1 marks "not yet measured / failed"; only positive runtimes are valid.
        for op in ops:
            op["runtime"] = -1

        if profile_all:
            self.engine.compile_all(ops, use_multiprocessing)

        for op in ops:
            out = self.engine.evaluate(op, [M, N, K])
            op["runtime"] = out
            # A positive runtime means the kernel compiled and ran; stop at the
            # first working one unless an exhaustive search was requested.
            if out > 0 and not profile_all:
                break

        valid_ops = [op for op in ops if op["runtime"] > 0]
        if not valid_ops:
            raise RuntimeError("No valid kernel found for (M=%d, N=%d, K=%d)" % (M, N, K))
        best = min(valid_ops, key=lambda op: op["runtime"])
        self.cache[cache_key] = best
        return best
8 changes: 5 additions & 3 deletions python/tvm/contrib/cutlass/gen_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@


def _create_gemm_operator(
layouts,
tile_descriptions,
data_type,
alignment_constraints,
Expand All @@ -60,6 +59,10 @@ def _create_gemm_operator(
if batched:
swizzling_functor = SwizzlingFunctor.Batched

layouts = [
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
]

for layout in layouts:
for tile_description in tile_descriptions:
for alignment in alignment_constraints:
Expand Down Expand Up @@ -135,15 +138,14 @@ def _create_gemm_operator(


def create_gemm_operator(batched):
# TODO: replace with partial
def op_creator(
layouts,
tile_descriptions,
data_type,
alignment_constraints,
swizzling_functor=SwizzlingFunctor.Identity8,
):
return _create_gemm_operator(
layouts,
tile_descriptions,
data_type,
alignment_constraints,
Expand Down
5 changes: 1 addition & 4 deletions python/tvm/contrib/cutlass/gen_tensor_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ def generate_tensor_op_common(
):
"""Common kernel generator to be used by archtecture specific generators."""
ops = []
layouts = [
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
]
for math_inst in math_instructions:
tile_descriptions = get_tile_descriptions(math_inst)
data_type = [
Expand All @@ -55,7 +52,7 @@ def generate_tensor_op_common(
math_inst.element_accumulator,
]

out = op_creator(layouts, tile_descriptions, data_type, alignment_constraints)
out = op_creator(tile_descriptions, data_type, alignment_constraints)

ops.extend(out)

Expand Down

0 comments on commit f7c3b5a

Please sign in to comment.