From d1dafbd6d050ddcd673f567c5ff720a6b6419cfe Mon Sep 17 00:00:00 2001 From: masahi Date: Tue, 14 Dec 2021 12:58:23 +0900 Subject: [PATCH] [CUTLASS] More robust support for pattern matching and alignment (#9698) * bug fix in im2col encoding * skip legalize when batch size is dynamic * add sm75 kernels to sm80 profilings * add dtype and layout check in parttern match * use align1 kernel for unusual channel cases (IC = 3 etc) * test IC=3 convolution * fixed check functions for fused cases, run infer type before mergecomposite * check align on N dim * add comment on IC == 3 case * lint fix * do not offload depthwise conv2d * lint * trigger CI --- python/tvm/contrib/cutlass/gen_conv2d.py | 12 +-- python/tvm/contrib/cutlass/gen_gemm.py | 21 +++--- python/tvm/contrib/cutlass/gen_tensor_op.py | 4 +- python/tvm/relay/op/contrib/cutlass.py | 81 ++++++++++++++++++--- python/tvm/topi/cuda/conv2d_alter_op.py | 4 + tests/python/contrib/test_cutlass.py | 71 +++++++++++------- 6 files changed, 139 insertions(+), 54 deletions(-) diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index d24e988ebe35..5a616c9b6e02 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -128,16 +128,16 @@ def profile( If profile_all is False, return immediately after the first applicable kernel is found. If use_multiprocessing is True, compile all profiler executables in parallel. """ - B, H, W, C = d_shape - K, R, S, _ = w_shape + B, _, _, IC = d_shape + OC, R, S, _ = w_shape _, P, Q, _ = out_shape - M = B * H * W - K = R * S * C - N = B * P * Q + M = B * P * Q + N = OC + K = R * S * IC gemm_profile_result = self.gemm_profiler.profile( - M, K, N, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing + M, N, K, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing ) tile_description = gemm_profile_result["tile_description"] diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index cec64f0af974..58d690f8191b 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -141,12 +141,13 @@ def create_gemm_operator( # TODO(masahi): A sensible way to pick reasonable default kernels DEFAULT_KERNELS = { 75: { - "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align4", - "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align4", + "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1", + "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1", }, + # align1 variants do not seem to be available for sm80 80: { - "float16": "cutlass_tensorop_h16816gemm_128x256_32x3_tn_align4", - "float32": "cutlass_tensorop_s16816gemm_f16_128x128_32x3_tn_align4", + "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1", + "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1", }, } @@ -160,14 +161,16 @@ def __init__(self, sm, cutlass_path, binary_path): self.sm = sm self.cache = {} - def check_align(self, op_name, M): + def check_align(self, op_name, M, N, K): """Filter out kernels that cannot be supported.""" aligns = re.findall(r"align[1|2|4|8]", op_name) assert len(aligns) == 1 + # The same alignment is used for all axes align = int(aligns[0][-1]) - if M % align != 0: - return False - return True + # TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive. + # See https://github.com/NVIDIA/cutlass/issues/362. + # When the above issue is resolved, we can remove the alignment check on M below. + return all([dim % align == 0 for dim in [M, N, K]]) def get_default(self, out_dtype, batched=False): """Return the default kernel for the requested architecture. @@ -194,7 +197,7 @@ def profile( ops = GENERATOR_FUNC_TABLE[self.sm]( out_dtype, op_creator=partial(create_gemm_operator, batched=batched) ) - ops = list(filter(lambda op: self.check_align(op["name"], M), ops)) + ops = list(filter(lambda op: self.check_align(op["name"], M, N, K), ops)) for op in ops: op["runtime"] = -1 diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index c8221514ce0a..cc228737cefc 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -152,9 +152,11 @@ def get_tile_descriptions(math_inst): TileDescription([64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), ] - return generate_tensor_op_common( + sm75_kernels = generate_sm75_tensor_op_1688(out_dtype, op_creator) + sm80_kernels = generate_tensor_op_common( math_instructions, alignment_constraints, get_tile_descriptions, op_creator ) + return sm75_kernels + sm80_kernels class ProfilerEngine: diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py index 4ae529e18dc2..0a67581400ed 100644 --- a/python/tvm/relay/op/contrib/cutlass.py +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name """Patterns supported CUTLASS.""" +from tvm.ir.transform import Sequential from tvm.relay import transform from ...dataflow_pattern import wildcard, is_op, is_constant @@ -56,19 +58,71 @@ def make_batch_matmul_pattern(): def make_conv2d_pattern(): - # TODO(masahi): Check layout and alignment return is_op("nn.conv2d")(wildcard(), wildcard()) +def check_dtype(lhs, rhs): + """Check if dtypes in the given workload are supported by CUTLASS.""" + # Only fp16 inputs are supported for now. + return lhs.dtype == rhs.dtype and lhs.dtype == "float16" and rhs.dtype == "float16" + + +def get_root_call(call, root_op_name): + if str(call.op) == root_op_name: + return call + return get_root_call(call.args[0], root_op_name) + + +def check_gemm(call): + """Check if the given dense workload can be offloaded to CUTLASS.""" + dense = get_root_call(call, "nn.dense") + lhs = dense.args[0].checked_type + rhs = dense.args[1].checked_type + return check_dtype(lhs, rhs) + + +def check_batch_matmul(call): + """Check if the given batch_matmul workload can be offloaded to CUTLASS.""" + batch_matmul = get_root_call(call, "nn.batch_matmul") + lhs = batch_matmul.args[0].checked_type + rhs = batch_matmul.args[1].checked_type + transpose_a = batch_matmul.attrs.transpose_a + transpose_b = batch_matmul.attrs.transpose_b + return check_dtype(lhs, rhs) and not transpose_a and transpose_b + + +def is_depthwise_conv2d(ic, oc, groups): + return ic == oc == groups + + +def check_conv2d(call): + """Check if the given conv2d workload can be offloaded to CUTLASS.""" + conv2d = get_root_call(call, "nn.conv2d") + data_layout = conv2d.attrs.data_layout + kernel_layout = conv2d.attrs.kernel_layout + data = conv2d.args[0].checked_type + weight = conv2d.args[1].checked_type + if data_layout != "NHWC" or kernel_layout != "OHWI" or not check_dtype(data, weight): + return False + IC = data.shape[3] + OC = weight.shape[0] + return not is_depthwise_conv2d(IC, OC, conv2d.attrs.groups) + + def partition_for_cutlass(mod): """Partition the input module into CUTLASS-supported subgraphs.""" - dense_pat = ("cutlass.dense", make_gemm_pattern(False, None)) - dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None)) - dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu")) - dense_bias_gelu_fp16_pat = ("cutlass.dense_bias_gelu_fp16", make_gemm_pattern(True, "gelu")) + dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm) + dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None), check_gemm) + dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"), check_gemm) + dense_bias_gelu_fp16_pat = ( + "cutlass.dense_bias_gelu_fp16", + make_gemm_pattern(True, "gelu"), + check_gemm, + ) dense_bias_gelu_fp32_pat = ( "cutlass.dense_bias_gelu_fp32", make_gemm_pattern(True, "gelu", out_dtype="float32"), + check_gemm, ) cutlass_patterns = [ dense_bias_gelu_fp16_pat, @@ -76,11 +130,16 @@ def partition_for_cutlass(mod): dense_bias_relu_pat, dense_bias_pat, dense_pat, - ("cutlass.batch_matmul", make_batch_matmul_pattern()), + ("cutlass.batch_matmul", make_batch_matmul_pattern(), check_batch_matmul), # TODO(masahi): Add more conv2d patterns - ("cutlass.conv2d", make_conv2d_pattern()), + ("cutlass.conv2d", make_conv2d_pattern(), check_conv2d), ] - mod = transform.MergeComposite(cutlass_patterns)(mod) - mod = transform.AnnotateTarget(["cutlass"])(mod) - mod = transform.PartitionGraph()(mod) - return mod + seq = Sequential( + [ + transform.InferType(), + transform.MergeComposite(cutlass_patterns), + transform.AnnotateTarget(["cutlass"]), + transform.PartitionGraph(), + ] + ) + return seq(mod) diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py index 3d05058ff52c..e6631d57b29e 100644 --- a/python/tvm/topi/cuda/conv2d_alter_op.py +++ b/python/tvm/topi/cuda/conv2d_alter_op.py @@ -450,6 +450,10 @@ def _conv2d_legalize(attrs, inputs, arg_types): elif data_dtype in ["float16"]: if data_layout == "NHWC" and kernel_layout == "HWIO": + if isinstance(data_tensor.shape[0], tvm.tir.expr.Any): + # Skip legalize when the batch size is dynamic + return None + batch = data_tensor.shape[0].value in_channel = data_tensor.shape[3].value out_channel = kernel_tensor.shape[3].value diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index a258da3c5d78..585b42a21425 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -242,6 +242,8 @@ def verify_batch_matmul( def test_dense(): verify_dense(get_dense(M, N, K), M, N, K) verify_dense(get_dense(M, N, K, out_dtype="float32"), M, N, K) + # Test align1 case + verify_dense(get_dense_bias(M, N + 1, K), M, N + 1, K) def test_dense_bias(): @@ -312,13 +314,14 @@ def convert_conv2d_layout(mod, desired_layouts): def verify_conv2d( - mod_nchw, - mod_ref, + mod_nchw, # can be dynamic batch + mod_ref, # always static batch d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, + use_cudnn_ref=False, run_benchmark=False, ): if not has_cutlass(): @@ -332,14 +335,14 @@ def verify_conv2d( typ = relay.transform.InferType()(mod_nchw)["main"].body.checked_type use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape) + mod_weight_ohwi = convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]}) + if use_vm: - rt_mod, dev, num_cutlass_partition = profile_and_build_vm( - convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]}), params, sm - ) + rt_mod, _, num_cutlass_partition = profile_and_build_vm(mod_weight_ohwi, params, sm) out = get_output_vm(rt_mod, ["data"], [np_data]) else: - rt_mod, dev, num_cutlass_partition = profile_and_build( - convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]}), + rt_mod, _, num_cutlass_partition = profile_and_build( + mod_weight_ohwi, params, sm, ) @@ -347,37 +350,51 @@ def verify_conv2d( assert num_cutlass_partition > 0 - rt_mod_ref, _ = get_ref_rt_mod( - convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "HWIO"]}), - params, - target="cuda", - ) - ref_out = get_output(rt_mod_ref, ["data"], [np_data]) + if use_cudnn_ref: + rt_mod_ref, dev = get_ref_rt_mod( + convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "OHWI"]}), + params, + target="cuda -libs=cudnn", + ) + else: + rt_mod_ref, dev = get_ref_rt_mod( + convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "HWIO"]}), + params, + target="cuda", + ) - np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol) + ref_out = get_output(rt_mod_ref, ["data"], [np_data]) if run_benchmark: print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600)) print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600)) + np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol) + def test_conv2d(): + for IC in [3, 16]: + d_shape = (16, IC, 32, 32) + w_shape = (32, IC, 3, 3) + mod_nchw = get_conv2d_nchw(d_shape, w_shape) + + verify_conv2d( + mod_nchw, + mod_nchw, + d_shape, + w_shape, + sm=80, + atol=1e-5, + rtol=1e-5, + use_cudnn_ref=(IC == 3), # The autotvm kernel has an accuracy issue with IC == 3 case + run_benchmark=False, + ) + d_shape = (16, 16, 32, 32) w_shape = (32, 16, 3, 3) - mod_nchw = get_conv2d_nchw(d_shape, w_shape) - - verify_conv2d( - mod_nchw, - mod_nchw, - d_shape, - w_shape, - sm=80, - atol=1e-5, - rtol=1e-5, - run_benchmark=False, - ) - dyn_batch_shape = (relay.Any(),) + d_shape[1:] + + mod_nchw = get_conv2d_nchw(d_shape, w_shape) mod_dyn = get_conv2d_nchw(dyn_batch_shape, w_shape) verify_conv2d(