
[CUDA] Improve adaptive and global pool schedule #8936

Merged 4 commits into apache:main on Sep 6, 2021

Conversation

@masahi (Member) commented Sep 6, 2021

The current GPU adaptive pool schedule doesn't parallelize across the H and W dimensions. This is fine for spatially small inputs as in resnet, such as (1, 1024, 7, 7). But recent architectures such as efficientnet apply adaptive (global) pooling to large inputs. In particular, efficientdet models from the TF2 detection zoo have workloads such as (1, 32, 378, 378). This results in the following sequential kernel, and hence efficientdet models run extremely slowly when converted to TVM.

```cuda
extern "C" __global__ void __launch_bounds__(64) tvmgen_default_fused_nn_adaptive_avg_pool2d_kernel0(float* __restrict__ placeholder, float* __restrict__ tensor) {
  float tensor1[1];
  tensor1[(0)] = 0.000000e+00f;
  for (int rv0 = 0; rv0 < 378; ++rv0) {
    for (int rv1 = 0; rv1 < 378; ++rv1) {
      if (((int)threadIdx.y) < 1) {
        tensor1[(0)] = (tensor1[(0)] + placeholder[((((((((int)threadIdx.y) * 4572288) + (((int)blockIdx.x) * 1143072)) + (((int)threadIdx.x) * 142884)) + (rv0 * 378)) + rv1))]);
      }
    }
  }
  if (((int)threadIdx.y) < 1) {
    tensor[((((((int)threadIdx.y) * 32) + (((int)blockIdx.x) * 8)) + ((int)threadIdx.x)))] = (tensor1[(0)] * 6.998684e-06f);
  }
}
```

I made two modifications to the adaptive pool schedule:

  • For the common case where the output size is (1, 1) (aka global pooling), adaptive pooling is equivalent to, for example, sum(..., axis=[2, 3]). So the existing GPU reduction schedule can be used as is.
  • For other cases, simply parallelize across all axes. For example, when the output shape is (N, C, pool_x, pool_y), we create N * C * pool_x * pool_y units of parallel work. Each thread computes the reduction over its corresponding input subwindow.
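The decomposition above can be sketched in pure Python (this is an illustrative reference model, not the actual TVM schedule; the function name `adaptive_avg_pool2d_ref` and the window formula are assumptions based on the standard adaptive-pooling definition):

```python
# Reference model of adaptive average pooling over an NCHW nested list.
# Each (n, c, oh, ow) output element reduces its own input subwindow --
# exactly the unit of parallel work the new schedule creates:
# N * C * pool_x * pool_y independent reductions.
import math

def adaptive_avg_pool2d_ref(data, out_h, out_w):
    n_, c_ = len(data), len(data[0])
    h, w = len(data[0][0]), len(data[0][0][0])
    out = [[[[0.0] * out_w for _ in range(out_h)]
            for _ in range(c_)] for _ in range(n_)]
    for n in range(n_):
        for c in range(c_):
            for oh in range(out_h):
                # Subwindow bounds per the usual adaptive-pool definition.
                h0, h1 = (oh * h) // out_h, math.ceil((oh + 1) * h / out_h)
                for ow in range(out_w):
                    w0, w1 = (ow * w) // out_w, math.ceil((ow + 1) * w / out_w)
                    s = sum(data[n][c][i][j]
                            for i in range(h0, h1) for j in range(w0, w1))
                    out[n][c][oh][ow] = s / ((h1 - h0) * (w1 - w0))
    return out
```

When out_h == out_w == 1, the subwindow is the whole H x W plane, so the result equals sum(..., axis=[2, 3]) / (H * W), which is why the existing reduction schedule can be reused for the global-pooling case.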

Performance results on CUDA + GeForce MX250 (laptop GPU). All numbers are in milliseconds.

| Workload | Out size | TVM before | TVM after | Torch |
|---|---|---|---|---|
| (1, 1024, 7, 7) | (1, 1) | 0.031 | 0.029 | 0.059 |
| (1, 32, 378, 378) | (1, 1) | 4.48 | 0.84 | 0.66 |
| (1, 32, 378, 378) | (16, 16) | 1.31 | 0.39 | 1.51 |

@comaniac comaniac merged commit 054e2bb into apache:main Sep 6, 2021
@comaniac (Contributor) commented Sep 6, 2021

Thanks @masahi @vinx13
