Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CUTLASS] Support more kernels: int8, tf32, and 3xtf32 #9899

Merged
merged 22 commits into from
Jan 13, 2022

Conversation

masahi
Copy link
Member

@masahi masahi commented Jan 11, 2022

@comaniac @Laurawly @hwu36

The change is mostly boilerplate. Supports int8, tf32, and 3xtf32 for both gemm and conv2d. TF32 and 3xtf32 are meant for training and not expected to be fast on Geforce cards (see NVIDIA/cutlass#390). But I found that

  • FasterRCNN runs reasonably fast on RTX 3070 with TF32 (about 28 msec on (1, 3, 512, 512) input).
  • MaskRCNN with 3xtf32 results match with PyTorch fp32 reference (atol = rtol = 1e-5).

For int8, I tested on quantized resnet50. But for some reason, the performance is worse than fp16. I haven't investigated deeply yet. In general, int8 models available for benchmarking are quite limited.

Copy link
Contributor

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the late review. LGTM.

@comaniac comaniac merged commit ff2c434 into apache:main Jan 13, 2022
@comaniac
Copy link
Contributor

Thanks @masahi

@hwu36
Copy link

hwu36 commented Jan 14, 2022

@masahi, cutlass profiler misses many 256x64, 64x256 tile sizes for turing/volta kernels. It needs some change like

--- a/tools/library/scripts/generator.py
+++ b/tools/library/scripts/generator.py
@@ -724,6 +724,8 @@ def GenerateSM70_TensorOp_884(manifest, cuda_version):
       TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
       TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
       TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc),
       TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
       TileDescription([128,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
       TileDescription([ 64,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
@@ -941,6 +943,8 @@ def GenerateSM75_TensorOp_1688(manifest, cuda_version):
       TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
       TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
       TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc),
       TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
       TileDescription([128,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
       TileDescription([ 64,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
@@ -1218,6 +1222,8 @@ def GenerateSM75_TensorOp_8832_TN(manifest, cuda_version):
       TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc),
       TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc),
       TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64, 128], 2, [4, 1, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256, 128], 2, [1, 4, 1], math_inst, min_cc, max_cc),
       TileDescription([ 64, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc),
       TileDescription([128,  64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc),
       TileDescription([ 64,  64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc),
@@ -1355,6 +1361,8 @@ def GenerateSM75_TensorOp_88128(manifest, cuda_version):
       TileDescription([256, 128, 512], 2, [4, 2, 1], math_inst, min_cc, max_cc),
       TileDescription([128, 256, 512], 2, [2, 4, 1], math_inst, min_cc, max_cc),
       TileDescription([128, 128, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+      TileDescription([ 64, 256, 512], 2, [1, 4, 1], math_inst, min_cc, max_cc),
+      TileDescription([256,  64, 512], 2, [4, 1, 1], math_inst, min_cc, max_cc),
       TileDescription([ 64, 128, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc),
       TileDescription([128,  64, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc),
       TileDescription([ 64,  64, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc),

Also, cutlass works the best with cuda 11.3+.

crazydemo pushed a commit to crazydemo/tvm that referenced this pull request Jan 27, 2022
* add int8 type in library

* wip

* adding test and plumbing data and weight dtype

* adding 3xtf32 support and refactor tile description enum

* add 3xtf32 test

* update gemm generator too

* int8 test worked

* 3xtf32 also works

* int8 and 3xtf32 gemm works

* clean up test

* support int8 in sm75

* refined int8 alignment constraints

* black

* support 3xtf32 in default kernel

* remove log

* refine dtype check

* support tf32

* leave TODO for alignment modification on int8 kernels

* tf32 test working

* fix default kernel for tf32

* workaround for compilation failure

* lint
Raghav-Chakravarthy pushed a commit to Raghav-Chakravarthy/tvm that referenced this pull request Jan 28, 2022
* add int8 type in library

* wip

* adding test and plumbing data and weight dtype

* adding 3xtf32 support and refactor tile description enum

* add 3xtf32 test

* update gemm generator too

* int8 test worked

* 3xtf32 also works

* int8 and 3xtf32 gemm works

* clean up test

* support int8 in sm75

* refined int8 alignment constraints

* black

* support 3xtf32 in default kernel

* remove log

* refine dtype check

* support tf32

* leave TODO for alignment modification on int8 kernels

* tf32 test working

* fix default kernel for tf32

* workaround for compilation failure

* lint
ylc pushed a commit to ylc/tvm that referenced this pull request Feb 16, 2022
* add int8 type in library

* wip

* adding test and plumbing data and weight dtype

* adding 3xtf32 support and refactor tile description enum

* add 3xtf32 test

* update gemm generator too

* int8 test worked

* 3xtf32 also works

* int8 and 3xtf32 gemm works

* clean up test

* support int8 in sm75

* refined int8 alignment constraints

* black

* support 3xtf32 in default kernel

* remove log

* refine dtype check

* support tf32

* leave TODO for alignment modification on int8 kernels

* tf32 test working

* fix default kernel for tf32

* workaround for compilation failure

* lint
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants