Skip to content

Commit

Permalink
improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Feb 4, 2022
1 parent 0bce8f3 commit 2eb1cf4
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 10 deletions.
5 changes: 5 additions & 0 deletions python/tvm/contrib/cutlass/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def handle_conv2d(
out_dtype,
data_dtype,
weight_dtype,
split_k_slices,
use_3xtf32,
profile_all_alignments,
find_first_valid,
Expand Down Expand Up @@ -269,6 +270,7 @@ def handle_conv2d(
weight_dtype,
use_3xtf32,
conv_kind,
split_k_slices,
profile_all_alignments,
find_first_valid=find_first_valid,
use_multiprocessing=use_multiprocessing,
Expand Down Expand Up @@ -367,6 +369,8 @@ def tune_cutlass_kernels(
d_shape = arg0_shape
w_shape = arg1_shape

split_k_slices = [8]

new_attrs.update(
handle_conv2d(
conv2d_profiler,
Expand All @@ -380,6 +384,7 @@ def tune_cutlass_kernels(
arg0_dtype,
arg1_dtype,
use_3xtf32,
split_k_slices,
profile_all_alignments,
find_first_valid,
use_multiprocessing,
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/contrib/cutlass/conv2d_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ def procedural_name(self):
"_${layout}_align${alignment}"
)

if self.split_k_slices > 1:
configuration_name += "_splitk%d" % self.split_k_slices

return substitute_template(
configuration_name,
{
Expand Down Expand Up @@ -210,6 +213,7 @@ def __init__(self):
${reduction}
"""

self.reduction_template = """
using EpilogueOutputOp = ${epilogue};
using ReductionOp = cutlass::reduction::thread::ReduceAdd<
Expand Down
11 changes: 8 additions & 3 deletions python/tvm/contrib/cutlass/gen_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def create_conv2d_operator_with_epilogue(
data_type,
alignment,
swizzling_functor,
split_k_slices=8,
split_k_slices,
):
"""
Instantiate a cutlass kernel from the given configuration,
Expand Down Expand Up @@ -105,11 +105,11 @@ def create_conv2d_operator_with_epilogue(
def enumerate_conv2d_operators(
conv_kind,
stride_support,
split_k_slices,
tile_descriptions,
data_type,
alignment_constraints,
swizzling_functor=SwizzlingFunctor.Identity4,
split_k_slices=[8],
):
"""Exhaustively instantiate all kernels from a given configuration."""
ret = []
Expand Down Expand Up @@ -203,6 +203,7 @@ def get_default(
data_type,
alignment,
swizzling_functor,
split_k_slices=1
)
return {"name": name, "opdef": opdef}

Expand All @@ -219,6 +220,7 @@ def select_op(
use_3xtf32,
conv_kind,
stride_support,
split_k_slices,
profile_all_alignments=False,
find_first_valid=False,
use_multiprocessing=False,
Expand Down Expand Up @@ -253,7 +255,7 @@ def select_op(
out_dtype,
data_dtype,
weight_dtype,
partial(enumerate_conv2d_operators, conv_kind, stride_support),
partial(enumerate_conv2d_operators, conv_kind, stride_support, split_k_slices),
lambda align: all([dim % align == 0 for dim in [IC, OC]]),
use_3xtf32,
profile_all_alignments,
Expand Down Expand Up @@ -293,6 +295,7 @@ def profile(
weight_dtype,
use_3xtf32=True,
conv_kind=ConvKind.Fprop,
split_k_slices=[1],
profile_all_alignments=False,
find_first_valid=False,
use_multiprocessing=False,
Expand Down Expand Up @@ -320,6 +323,7 @@ def profile(
use_3xtf32,
conv_kind,
stride_support,
split_k_slices,
profile_all_alignments,
find_first_valid,
use_multiprocessing,
Expand All @@ -333,6 +337,7 @@ def profile(
op["data_type"],
op["alignment"],
op["swizzle_functor"],
op["split_k_slices"],
)

return name, opdef, op["runtime"]
21 changes: 14 additions & 7 deletions src/relay/backend/contrib/cutlass/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,12 +284,13 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
op_type != "cutlass.conv2d_bias_silu" &&
op_type != "cutlass.conv2d_bias_hardswish";

const std::string op_name = attrs.at("op_name");
std::ostringstream conv2d_decl;
CutlassPrint(conv2d_decl, attrs.at("op_def"));
CutlassPrint(conv2d_decl, "using Operation_" + attrs.at("op_name") +
CutlassPrint(conv2d_decl, "using Operation_" + op_name +
" = cutlass::conv::device::ImplicitGemmConvolution<" +
attrs.at("op_name") + ">;\n");
CutlassPrint(conv2d_decl, "using Conv2d = Operation_" + attrs.at("op_name") + ";\n");
op_name + ">;\n");
CutlassPrint(conv2d_decl, "using Conv2d = Operation_" + op_name + ";\n");
CutlassPrint(conv2d_decl, "using ElementInputA = Conv2d::ElementA;\n");
CutlassPrint(conv2d_decl, "using ElementInputB = Conv2d::ElementB;\n");
CutlassPrint(conv2d_decl, "using ElementComputeEpilogue = Conv2d::ElementAccumulator;\n");
Expand Down Expand Up @@ -317,17 +318,23 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl, "int stride_w = " + attrs.at("stride_w") + ";\n");
CutlassPrint(conv2d_decl, "int dilation_h = " + attrs.at("dilation_h") + ";\n");
CutlassPrint(conv2d_decl, "int dilation_w = " + attrs.at("dilation_w") + ";\n");
// TODO
const int split_k_slices = 8;
CutlassPrint(conv2d_decl, "int split_k_slices = " + std::to_string(split_k_slices) + ";\n");

// Split-K kernels are recognised by their name: the selected kernel's op_name contains "splitk"
// followed by the slice count as a trailing run of decimal digits (e.g. "..._splitk8").
const bool use_split_k = op_name.find("splitk") != std::string::npos;

if (use_split_k) {
  // find_last_not_of() returns the index of the last NON-digit character (the 'k' of "splitk"),
  // so the numeric suffix begins one position after it. Without the +1 the emitted code would be
  // "int split_k_slices = k8;" instead of "int split_k_slices = 8;".
  const std::string split_k_slices = op_name.substr(op_name.find_last_not_of("0123456789") + 1);
  LOG(INFO) << "split_k : " << split_k_slices;
  CutlassPrint(conv2d_decl, "int split_k_slices = " + split_k_slices + ";\n");
} else {
  // No split-K marker in the kernel name: emit a single slice.
  CutlassPrint(conv2d_decl, "int split_k_slices = 1;\n");
}

CutlassPrint(
conv2d_decl,
"cutlass::conv::Conv2dProblemSize problem_size(N, H, W, C, K, R, S, P, Q, pad_h, pad_w, "
"stride_h, stride_w, dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, "
"split_k_slices);\n");

const bool use_split_k = split_k_slices > 1;
const std::string split_k_mode = use_split_k ? "kParallel" : "kSerial";
CutlassPrint(conv2d_decl,
"const cutlass::conv::SplitKMode split_k_mode = cutlass::conv::SplitKMode::" +
Expand Down

0 comments on commit 2eb1cf4

Please sign in to comment.