Merge pull request #6064 from wangbluo/fix_attn
[sp] : fix the attention kernel for sp
wangbluo authored Sep 18, 2024
2 parents 37e3523 + 10e4f7d commit 63314ce
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion colossalai/shardformer/layer/attn.py
@@ -121,7 +121,8 @@ def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size
             )
 
         if size >= MEMORY_BOUND:
-            ColoAttention._flash_kernel_dispatch = ColoAttention._flash_kernel_dispatch.load()
+            if isinstance(ColoAttention._flash_kernel_dispatch, KernelLoader):
+                ColoAttention._flash_kernel_dispatch = ColoAttention._flash_kernel_dispatch.load()
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
             ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
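
Note: the added isinstance guard makes the dispatch idempotent: `.load()` is only invoked while `_flash_kernel_dispatch` still holds a `KernelLoader`; once it has been replaced by the loaded kernel, later dispatch calls skip the load. A minimal sketch of this lazy-load guard, with placeholder `KernelLoader` and `FlashAttentionKernel` classes and a hypothetical `memory_bound` threshold standing in for the real colossalai internals:

    # Sketch only: these classes are stand-ins, not the real colossalai.kernel API.

    class FlashAttentionKernel:
        """Placeholder for the loaded kernel object; note it has no load() method."""
        def __call__(self, *args):
            ...

    class KernelLoader:
        """Placeholder loader: load() returns the concrete kernel."""
        def load(self):
            return FlashAttentionKernel()

    class ColoAttention:
        _flash_kernel_dispatch = KernelLoader()  # starts out as a loader

        @staticmethod
        def _dispatch_kernel(size: int, memory_bound: int = 1024):
            if size >= memory_bound:
                # Without this guard, a second call would invoke load() on the
                # already-loaded kernel, which in this sketch raises AttributeError.
                if isinstance(ColoAttention._flash_kernel_dispatch, KernelLoader):
                    ColoAttention._flash_kernel_dispatch = ColoAttention._flash_kernel_dispatch.load()
            return ColoAttention._flash_kernel_dispatch

    # Calling twice is now safe; the loader is swapped for the kernel exactly once.
    ColoAttention._dispatch_kernel(4096)
    ColoAttention._dispatch_kernel(4096)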
