[CUTLASS] Conv2d dgrad #10110

Merged · 7 commits · Feb 1, 2022
31 changes: 27 additions & 4 deletions python/tvm/contrib/cutlass/build.py
@@ -24,6 +24,7 @@
from tvm.contrib.nvcc import find_cuda_path, get_cuda_version
from .gen_gemm import CutlassGemmProfiler
from .gen_conv2d import CutlassConv2DProfiler
from .library import ConvKind

logger = logging.getLogger("cutlass")

@@ -86,7 +87,7 @@ def visit_call(self, call):
self.signature["ret_dtype"] = op.ret_type.dtype
self.visit(op.body)

if str(op) == "nn.conv2d":
if str(op) in ["nn.conv2d", "nn.conv2d_transpose", "nn.conv2d_backward_weight"]:
self.op_attrs = call.attrs

for arg in call.args:
@@ -242,8 +243,17 @@ def handle_conv2d(
use_multiprocessing,
):
"""Profile and select a kernel for conv2d op workload."""
if "conv2d_transpose" in op_type:
conv_kind = ConvKind.Dgrad
elif "backward_weight" in op_type:
conv_kind = ConvKind.Wgrad
else:
conv_kind = ConvKind.Fprop

if any(isinstance(s, tvm.tir.Any) for s in d_shape):
out = cutlass_profiler.get_default(op_type, out_dtype, data_dtype, weight_dtype, use_3xtf32)
out = cutlass_profiler.get_default(
op_type, out_dtype, data_dtype, weight_dtype, use_3xtf32, conv_kind, strides
)
name, cutlass_op_def = out["name"], out["opdef"]
logger.info("Picked the default kernel %s", name)
else:
@@ -258,6 +268,7 @@
data_dtype,
weight_dtype,
use_3xtf32,
conv_kind,
profile_all_alignments,
find_first_valid=find_first_valid,
use_multiprocessing=use_multiprocessing,
@@ -329,6 +340,7 @@ def tune_cutlass_kernels(
if "cutlass" in fun_name:
num_cutlass_partition += 1
annotator.visit(func)
out_shape = annotator.signature["ret_shape"]
out_dtype = annotator.signature["ret_dtype"]
op_type = annotator.signature["op_type"]

@@ -344,12 +356,23 @@
new_attrs["padding"] = annotator.op_attrs.padding
new_attrs["strides"] = annotator.op_attrs.strides
new_attrs["dilation"] = annotator.op_attrs.dilation

if "conv2d_transpose" in op_type:
d_shape = out_shape
w_shape = arg1_shape
elif "conv2d_backward_weight" in op_type:
d_shape = arg1_shape
w_shape = out_shape
else:
d_shape = arg0_shape
w_shape = arg1_shape

new_attrs.update(
handle_conv2d(
conv2d_profiler,
op_type,
arg0_shape,
arg1_shape,
d_shape,
w_shape,
annotator.op_attrs.padding,
annotator.op_attrs.strides,
annotator.op_attrs.dilation,
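The d_shape/w_shape remapping above is the heart of the frontend change: CUTLASS describes every conv2d problem in fprop terms (activation (N, H, W, C), filter (K, R, S, C)), so for dgrad and wgrad the profiler must be fed the tensors that play those roles rather than the op's literal arguments. A minimal sketch of the mapping (the helper name is hypothetical, not part of this PR):

def profiler_shapes(op_type, arg0_shape, arg1_shape, out_shape):
    """Pick (d_shape, w_shape) as tune_cutlass_kernels does above."""
    if "conv2d_transpose" in op_type:
        # Dgrad: the fprop activation is the input gradient this op
        # produces, i.e. its output.
        return out_shape, arg1_shape
    if "conv2d_backward_weight" in op_type:
        # Wgrad: the fprop activation is the forward input (arg1), and
        # the fprop filter is the weight gradient this op produces.
        return arg1_shape, out_shape
    # Fprop: data and weight are the op's own arguments.
    return arg0_shape, arg1_shape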
18 changes: 13 additions & 5 deletions python/tvm/contrib/cutlass/conv2d_profiler.py
@@ -29,6 +29,8 @@ def __init__(self):
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/kernel/default_conv2d_wgrad.h"
#include "cutlass/conv/kernel/default_conv2d_dgrad.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
@@ -89,11 +91,6 @@ def __init__(self):
using ElementOutput = typename ImplicitGemm::ElementC;
using ElementInputA = typename ImplicitGemm::ElementA;
using ElementInputB = typename ImplicitGemm::ElementB;
auto oshape = options.output_size();
cutlass::HostTensor<ElementInputA, typename ImplicitGemm::LayoutA> tensor_a(options.input_size);
cutlass::HostTensor<ElementInputB, typename ImplicitGemm::LayoutB> tensor_b(options.filter_size);
cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_c(oshape);
cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_ref_c(oshape);

cutlass::conv::Conv2dProblemSize problem_size(
options.input_size,
@@ -106,7 +103,18 @@
1
);

auto conv_kind = ImplicitGemm::kConvolutionalOperator;
auto a_extent = implicit_gemm_tensor_a_extent(conv_kind, problem_size);
auto b_extent = implicit_gemm_tensor_b_extent(conv_kind, problem_size);
auto c_extent = implicit_gemm_tensor_c_extent(conv_kind, problem_size);

cutlass::HostTensor<ElementInputA, typename ImplicitGemm::LayoutA> tensor_a(a_extent);
cutlass::HostTensor<ElementInputB, typename ImplicitGemm::LayoutB> tensor_b(b_extent);
cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_c(c_extent);
cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_ref_c(c_extent);

using ElementComputeEpilogue = typename ImplicitGemm::ElementCompute;

typename ImplicitGemm::Arguments arguments{
problem_size,
tensor_a.device_ref(),
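The profiler template previously sized its host tensors as if every kernel were fprop (A = activation, B = filter, C = output). With dgrad and wgrad in the mix, the operand-to-tensor mapping depends on the convolution kind, which is exactly what the implicit_gemm_tensor_{a,b,c}_extent helpers encode. Roughly, as a Python model for illustration (not CUTLASS source):

def operand_extents(conv_kind, input_size, filter_size, output_size):
    """Extents of operands A, B, C for a problem stated in fprop terms."""
    if conv_kind == "fprop":     # output = conv(input, filter)
        return input_size, filter_size, output_size
    if conv_kind == "dgrad":     # d_input = dgrad(d_output, filter)
        return output_size, filter_size, input_size
    assert conv_kind == "wgrad"  # d_filter = wgrad(d_output, input)
    return output_size, input_size, filter_size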
71 changes: 62 additions & 9 deletions python/tvm/contrib/cutlass/gen_conv2d.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name
"""Conv2d kernel generator and profiler for CUTLASS."""
from functools import partial
from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
from .gen_gemm import CutlassGemmProfiler
from .conv2d_profiler import Conv2dProfilerEmitter
@@ -32,7 +33,13 @@


def create_conv2d_operator_with_epilogue(
op_type, tile_description, data_type, alignment, swizzling_functor
conv_kind,
stride_support,
op_type,
tile_description,
data_type,
alignment,
swizzling_functor,
):
"""
Instantiate a cutlass kernel from the given configuration,
@@ -72,15 +79,15 @@ def create_conv2d_operator_with_epilogue(
C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)

op = Conv2dOperation(
ConvKind.Fprop,
conv_kind,
IteratorAlgorithm.Optimized,
tile_description.minimum_compute_capability,
tile_description,
A,
B,
C,
element_epilogue,
StrideSupport.Strided,
stride_support,
epilogue,
swizzling_functor,
)
@@ -94,6 +101,8 @@


def enumerate_conv2d_operators(
conv_kind,
stride_support,
tile_descriptions,
data_type,
alignment_constraints,
@@ -107,6 +116,9 @@

element_a, element_b, element_c, element_epilogue = data_type

if conv_kind == ConvKind.Dgrad and stride_support == StrideSupport.Strided:
swizzling_functor = SwizzlingFunctor.StridedDgradIdentity1

for tile in tile_descriptions:
for alignment in alignment_constraints:

@@ -115,15 +127,15 @@
C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)

op = Conv2dOperation(
ConvKind.Fprop,
conv_kind,
IteratorAlgorithm.Optimized,
tile.minimum_compute_capability,
tile,
A,
B,
C,
element_epilogue,
StrideSupport.Strided,
stride_support,
EpilogueFunctor.LinearCombination,
swizzling_functor,
)
@@ -152,7 +164,16 @@ def __init__(self, sm, cutlass_path, binary_path):
self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
self.cache = {}

def get_default(self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32):
def get_default(
self,
op_type,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
conv_kind=ConvKind.Fprop,
stride=(1, 1),
):
"""Return the default kernel for the requested architecture.
For now, the default kernel is picked arbitrarily.
"""
@@ -162,8 +183,21 @@ def get_default(self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32):
tile_description = gemm_profile_result["tile_description"]
alignment = gemm_profile_result["alignment"]
data_type = gemm_profile_result["data_type"]
stride_support = StrideSupport.Strided if stride[0] > 1 else StrideSupport.Unity

if conv_kind == ConvKind.Dgrad and stride_support == StrideSupport.Strided:
swizzling_functor = SwizzlingFunctor.StridedDgradIdentity1
else:
swizzling_functor = SwizzlingFunctor.Identity4

name, opdef = create_conv2d_operator_with_epilogue(
op_type, tile_description, data_type, alignment, SwizzlingFunctor.Identity4
conv_kind,
stride_support,
op_type,
tile_description,
data_type,
alignment,
swizzling_functor,
)
return {"name": name, "opdef": opdef}
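get_default here and profile below apply the same rule when configuring a dgrad kernel: unit stride can reuse the ordinary identity swizzle, while any larger stride needs CUTLASS's dedicated strided-dgrad swizzle. As a standalone sketch (a hypothetical helper that mirrors the two code paths):

def dgrad_kernel_traits(stride):
    """Stride support and threadblock swizzle for a Dgrad kernel."""
    if stride == (1, 1):
        # Unit-stride dgrad has the same access pattern as fprop, so the
        # regular identity swizzle works.
        return StrideSupport.Unity, SwizzlingFunctor.Identity4
    # Strided dgrad scatters each output-gradient element across the
    # input gradient, so it needs the StridedDgrad swizzle.
    return StrideSupport.Strided, SwizzlingFunctor.StridedDgradIdentity1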

@@ -178,6 +212,8 @@ def select_op(
data_dtype,
weight_dtype,
use_3xtf32,
conv_kind,
stride_support,
profile_all_alignments=False,
find_first_valid=False,
use_multiprocessing=False,
@@ -188,6 +224,7 @@
"""
N, H, W, IC = d_shape
OC, R, S, _ = w_shape

workload = (
N,
H,
@@ -211,7 +248,7 @@
out_dtype,
data_dtype,
weight_dtype,
enumerate_conv2d_operators,
partial(enumerate_conv2d_operators, conv_kind, stride_support),
lambda align: all([dim % align == 0 for dim in [IC, OC]]),
use_3xtf32,
profile_all_alignments,
@@ -248,6 +285,7 @@ def profile(
data_dtype,
weight_dtype,
use_3xtf32=True,
conv_kind=ConvKind.Fprop,
profile_all_alignments=False,
find_first_valid=False,
use_multiprocessing=False,
@@ -256,6 +294,13 @@
If find_first_valid is True, return immediately after the first applicable kernel is found.
If use_multiprocessing is True, compile all profiler executables in parallel.
"""
# Dgrad uses the Unity-stride specialization when stride == (1, 1);
# all other cases fall back to Strided kernels.
stride_support = (
StrideSupport.Unity
if stride[0] == 1 and stride[1] == 1 and conv_kind == ConvKind.Dgrad
else StrideSupport.Strided
)

op = self.select_op(
d_shape,
w_shape,
@@ -266,13 +311,21 @@
data_dtype,
weight_dtype,
use_3xtf32,
conv_kind,
stride_support,
profile_all_alignments,
find_first_valid,
use_multiprocessing,
)

name, opdef = create_conv2d_operator_with_epilogue(
op_type, op["tile_description"], op["data_type"], op["alignment"], op["swizzle_functor"]
conv_kind,
stride_support,
op_type,
op["tile_description"],
op["data_type"],
op["alignment"],
op["swizzle_functor"],
)

return name, opdef, op["runtime"]
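Taken together, a dgrad profiling request might look like the following. This is illustrative only: the shapes and paths are made up, and the padding/stride/dilation argument names, which this diff elides, are assumed:

profiler = CutlassConv2DProfiler(sm=80, cutlass_path="/path/to/cutlass",
                                 binary_path="/tmp/cutlass_profiler_bin")
name, opdef, runtime = profiler.profile(
    "cutlass.conv2d_transpose",
    d_shape=(8, 32, 32, 64),  # fprop-terms activation: the dx this op produces
    w_shape=(64, 3, 3, 64),   # filter (K, R, S, C)
    padding=(1, 1),
    stride=(2, 2),            # strided, so StridedDgradIdentity1 is selected
    dilation=(1, 1),
    out_dtype="float16",
    data_dtype="float16",
    weight_dtype="float16",
    conv_kind=ConvKind.Dgrad,
)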
1 change: 1 addition & 0 deletions python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -291,6 +291,7 @@ def get_tile_descriptions(math_inst):
"cutlass.conv2d_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True),
"cutlass.conv2d_bias": (EpilogueFunctor.LinearCombinationBias, True),
"cutlass.conv2d": (EpilogueFunctor.LinearCombination, False),
"cutlass.conv2d_transpose": (EpilogueFunctor.LinearCombination, False),
}


10 changes: 10 additions & 0 deletions python/tvm/contrib/cutlass/library.py
@@ -189,6 +189,8 @@ class SwizzlingFunctor(enum.Enum):
Identity4 = enum_auto()
Identity8 = enum_auto()
Batched = enum_auto()
StridedDgradIdentity1 = enum_auto()
StridedDgradIdentity4 = enum_auto()


SwizzlingFunctorTag = {
Expand All @@ -197,20 +199,28 @@ class SwizzlingFunctor(enum.Enum):
SwizzlingFunctor.Identity4: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>",
SwizzlingFunctor.Identity8: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>",
SwizzlingFunctor.Batched: "cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle",
SwizzlingFunctor.StridedDgradIdentity1: "cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>",
SwizzlingFunctor.StridedDgradIdentity4: "cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>",
}


class ConvKind(enum.Enum):
Fprop = enum_auto()
Dgrad = enum_auto()
Wgrad = enum_auto()


ConvKindTag = {
ConvKind.Fprop: "cutlass::conv::Operator::kFprop",
ConvKind.Dgrad: "cutlass::conv::Operator::kDgrad",
ConvKind.Wgrad: "cutlass::conv::Operator::kWgrad",
}


ConvKindNames = {
ConvKind.Fprop: "fprop",
ConvKind.Dgrad: "dgrad",
ConvKind.Wgrad: "wgrad",
}
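These tag tables are what the kernel emitter consults when stamping out an instance; for a strided dgrad kernel the new entries resolve as below (a quick check, assuming the module is importable as tvm.contrib.cutlass.library):

from tvm.contrib.cutlass.library import (
    ConvKind, ConvKindTag, SwizzlingFunctor, SwizzlingFunctorTag,
)

print(ConvKindTag[ConvKind.Dgrad])
# cutlass::conv::Operator::kDgrad
print(SwizzlingFunctorTag[SwizzlingFunctor.StridedDgradIdentity1])
# cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>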

