Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CUDA] Support multiple TIR-level dynamic shared memory allocations #8571

Merged
merged 21 commits into from
Jul 31, 2021

Conversation

masahi
Copy link
Member

@masahi masahi commented Jul 28, 2021

A follow-up to #8466

A new pass is added to merge multiple TIR-level dynamic shared memory allocations, whose sizes may not be a constant. This case is not handled by storage_rewrite pass. Rather than updating storage_rewrite pass, I added a new pass since the logic is simpler (we MUST merge and we know which alloc to merge).

Hetero-dtype is supported per discussion #8466 (comment)

@tqchen @vinx13 @yzh119

@tqchen
Copy link
Member

tqchen commented Jul 28, 2021

cc @Hzfengsy @vinx13 would be great if you can help to manage this PR

@masahi
Copy link
Member Author

masahi commented Jul 29, 2021

For the dyn shmem matmul test, the generated kernel looks like:

extern "C" __global__ void default_function_kernel0(half* __restrict__ A, half* __restrict__ B, float* __restrict__ reduce) {
  extern __shared__ uchar buf_dyn_shmem[];
  ((float*)buf_dyn_shmem)[((((((int)threadIdx.y) * 16) + ((int)threadIdx.x)) + 512))] = 0.000000e+00f;
  for (int i = 0; i < 64; ++i) {
    ((half*)buf_dyn_shmem)[((((((int)threadIdx.y) * 16) + ((int)threadIdx.x)) + 512))] = A[(((((((int)blockIdx.y) * 16384) + (((int)threadIdx.y) * 1024)) + (i * 16)) + ((int)threadIdx.x)))];
    ((half*)buf_dyn_shmem)[(((((int)threadIdx.y) * 16) + ((int)threadIdx.x)))] = B[(((((i * 16384) + (((int)threadIdx.y) * 1024)) + (((int)blockIdx.x) * 16)) + ((int)threadIdx.x)))];
    __syncthreads();
    for (int k = 0; k < 16; ++k) {
      ((float*)buf_dyn_shmem)[((((((int)threadIdx.y) * 16) + ((int)threadIdx.x)) + 512))] = (((float*)buf_dyn_shmem)[((((((int)threadIdx.y) * 16) + ((int)threadIdx.x)) + 512))] + ((float)(((half*)buf_dyn_shmem)[((((((int)threadIdx.y) * 16) + k) + 512))] * ((half*)buf_dyn_shmem)[(((k * 16) + ((int)threadIdx.x)))])));
    }
    __syncthreads();
  }
  reduce[(((((((int)blockIdx.y) * 16384) + (((int)threadIdx.y) * 1024)) + (((int)blockIdx.x) * 16)) + ((int)threadIdx.x)))] = ((float*)buf_dyn_shmem)[((((((int)threadIdx.y) * 16) + ((int)threadIdx.x)) + 512))];
}

Copy link
Member

@Hzfengsy Hzfengsy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we reuse buffers that are out of the lifetime in the new pass? To be specific, please see the following example:

A_shared[i] = A[i]
A_local[i] = A_shared[i]
C_local[i] = A_local[i] + 1
C_shared[i] = C_local[i]

Since A_shared[i] is never used when we store to C_shared. We can directly store the data into A_shared[i] to reduce memory usage. It is supported in storage_rewrite

tests/python/unittest/test_tir_ir_builder.py Outdated Show resolved Hide resolved
tests/python/unittest/test_tir_ir_builder.py Outdated Show resolved Hide resolved
@masahi
Copy link
Member Author

masahi commented Jul 29, 2021

Can we reuse buffers that are out of the lifetime in the new pass? To be specific, please see the following example:

A_shared[i] = A[i]
A_local[i] = A_shared[i]
C_local[i] = A_local[i] + 1
C_shared[i] = C_local[i]

Since A_shared[i] is never used when we store to C_shared. We can directly store the data into A_shared[i] to reduce memory usage. It is supported in storage_rewrite

Thanks, I didn't think about reuse support. To support this, I think it is better to drop the new pass in this PR and merge the new functionality to storage_rewrite pass. I'll consider both possibilities and try to find the simplest solution.

One difficulty I can imagine is that, dynamic shared memory in general has unknown alloc size. So for the general cases I don't think reuse analysis would work just like it does in storage_rewrite. For special cases where dynamic shared memory happens to have a constant size, it is probably worth supporting reuse. What do you think? @Hzfengsy @vinx13

@masahi
Copy link
Member Author

masahi commented Jul 30, 2021

I think we can use storage_rewrite to support buffer reuse on dynamic shared memory with constant sizes, and then use the new pass in this PR to merge remaining buffers. I'll pursue this approach.

@vinx13
Copy link
Member

vinx13 commented Jul 30, 2021

I think we can use storage_rewrite to support buffer reuse on dynamic shared memory with constant sizes, and then use the new pass in this PR to merge remaining buffers. I'll pursue this approach.

I agree. For constant sizes, storage_rewrite should be able to eliminate buffer allocations, running MergeDynamicMemoryAlloc after StorageRewritewill work

@masahi
Copy link
Member Author

masahi commented Jul 30, 2021

ok @Hzfengsy @vinx13 I added a new test that demonstrates storage_rewrite and the new merge pass working seamlessly.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants